f
This commit is contained in:
@@ -98,6 +98,7 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
# This ensures we're in the correct replica context
|
# This ensures we're in the correct replica context
|
||||||
print("🔧 Creating optimizer slot variables within TPU replica context...")
|
print("🔧 Creating optimizer slot variables within TPU replica context...")
|
||||||
|
|
||||||
|
@tf.function
|
||||||
def init_optimizer_slots():
|
def init_optimizer_slots():
|
||||||
dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables]
|
dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables]
|
||||||
self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables))
|
self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables))
|
||||||
|
Reference in New Issue
Block a user