refactor: streamline model building and ensure dtype consistency in L2 loss calculation
@@ -90,12 +90,8 @@ class BrainToTextDecoderTrainerTF:
         with self.strategy.scope():
             print("🔨 Building model within TPU strategy scope...")
             self.model = self._build_model()
-            print("✅ Model built successfully")
-
             print("⚙️ Creating optimizer...")
             self.optimizer = self._create_optimizer()
-            print("✅ Optimizer created")
-
             print("🔧 Pre-building optimizer state for TPU...")
             # For TPU, we must ensure optimizer is completely ready before training
             # since @tf.function doesn't allow dynamic building
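Note on the hunk above: the removed prints were progress noise; the substantive constraint is the one in the surviving comments, namely that on TPU every optimizer variable must already exist before the @tf.function train step is traced, since variables cannot be created inside the traced graph. A minimal sketch of that pattern, assuming a TF 2.11+/Keras-style optimizer whose public build(var_list) method creates slot variables eagerly (Adam and the toy model here are illustrative, not the trainer's actual setup):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(8)])
model.build(input_shape=(None, 16))

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# Create all slot variables (e.g. Adam's m/v accumulators) up front,
# before any graph tracing happens.
optimizer.build(model.trainable_variables)

@tf.function
def train_step(x, y):
    # No variable creation occurs in here, so the step traces cleanly on TPU.
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

In the real trainer both the model and optimizer construction sit under strategy.scope(), as the hunk shows.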
@@ -595,7 +591,10 @@ class BrainToTextDecoderTrainerTF:
         if self.manual_weight_decay:
             l2_loss = tf.constant(0.0, dtype=loss.dtype)
             for var in self.model.trainable_variables:
-                l2_loss += tf.nn.l2_loss(var)
+                # Ensure dtype consistency for mixed precision training
+                var_l2 = tf.nn.l2_loss(var)
+                var_l2 = tf.cast(var_l2, dtype=loss.dtype)  # Cast to match loss dtype
+                l2_loss += var_l2
             loss += self.weight_decay_rate * l2_loss

         # TensorFlow mixed precision: no manual loss scaling needed, the Keras policy handles it automatically
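Why the cast matters: tf.nn.l2_loss returns a scalar in its input's dtype, and TensorFlow does not implicitly promote dtypes on addition, so accumulating a float32 per-variable term into a bfloat16 or float16 loss fails under mixed precision. A minimal standalone sketch of the fixed pattern (the bfloat16 loss and toy variables are hypothetical, chosen to reproduce the mismatch):

import tensorflow as tf

loss = tf.constant(1.5, dtype=tf.bfloat16)  # e.g. a bfloat16 loss on TPU
variables = [tf.Variable(tf.ones((4, 4))),  # float32 trainable variables
             tf.Variable(tf.ones((4,)))]

l2_loss = tf.constant(0.0, dtype=loss.dtype)
for var in variables:
    var_l2 = tf.nn.l2_loss(var)                 # float32, matching the variable
    var_l2 = tf.cast(var_l2, dtype=loss.dtype)  # align with the loss dtype
    l2_loss += var_l2                           # addition is now dtype-consistent

weight_decay_rate = 0.01
loss += weight_decay_rate * l2_loss  # Python scalar adopts loss.dtype

Without the tf.cast, the `l2_loss += var_l2` line would raise a dtype-mismatch error on the first float32 term, which is exactly what the hunk above guards against.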