Zchen
2025-10-12 09:35:26 +08:00
parent 21d901cfd8
commit 6828ff536e
2 changed files with 227 additions and 12 deletions


@@ -22,7 +22,7 @@ torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on som
 torch.backends.cudnn.deterministic = True # makes training more reproducible
 torch._dynamo.config.cache_size_limit = 64
 
-from rnn_model import GRUDecoder
+from rnn_model import TripleGRUDecoder
 
 class BrainToTextDecoder_Trainer:
     """
@@ -116,22 +116,23 @@ class BrainToTextDecoder_Trainer:
         random.seed(self.args['seed'])
         torch.manual_seed(self.args['seed'])
 
-        # Initialize the model
-        self.model = GRUDecoder(
+        # Initialize the model
+        self.model = TripleGRUDecoder(
             neural_dim = self.args['model']['n_input_features'],
             n_units = self.args['model']['n_units'],
             n_days = len(self.args['dataset']['sessions']),
             n_classes = self.args['dataset']['n_classes'],
-            rnn_dropout = self.args['model']['rnn_dropout'],
-            input_dropout = self.args['model']['input_network']['input_layer_dropout'],
             n_layers = self.args['model']['n_layers'],
+            rnn_dropout = self.args['model']['rnn_dropout'],
+            input_dropout = self.args['model']['input_network']['input_layer_dropout'],
             patch_size = self.args['model']['patch_size'],
             patch_stride = self.args['model']['patch_stride'],
         )
 
-        # Call torch.compile to speed up training
-        self.logger.info("Using torch.compile")
-        self.model = torch.compile(self.model)
+        # Temporarily disable torch.compile for compatibility with new model architecture
+        # TODO: Re-enable torch.compile once model is stable
+        # self.logger.info("Using torch.compile")
+        # self.model = torch.compile(self.model)
+        self.logger.info("torch.compile disabled for new TripleGRUDecoder compatibility")
 
         self.logger.info(f"Initialized RNN decoding model")
@@ -531,8 +532,9 @@ class BrainToTextDecoder_Trainer:
             adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
 
-            # Get phoneme predictions
-            logits = self.model(features, day_indicies)
+            # Get phoneme predictions using inference mode during training
+            # (We use inference mode for simplicity - only clean logits are used for CTC loss)
+            logits = self.model(features, day_indicies, None, False, 'inference')
 
             # Calculate CTC Loss
             loss = self.ctc_loss(
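
For reference, the adjusted_lens expression above is the usual output-length formula for strided patching, floor((n_time_steps - patch_size) / patch_stride) + 1, giving the number of patches the model emits logits for and hence the input lengths passed to CTC. A standalone sketch with illustrative numbers; patch_size=14 and patch_stride=4 are assumptions, not values read from this repo's config:

    import torch

    patch_size, patch_stride = 14, 4                 # illustrative values only
    n_time_steps = torch.tensor([512, 300])          # raw time steps for two trials in a batch
    adjusted_lens = ((n_time_steps - patch_size) / patch_stride + 1).to(torch.int32)
    print(adjusted_lens)                             # tensor([125, 72], dtype=torch.int32)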
@@ -706,7 +708,7 @@ class BrainToTextDecoder_Trainer:
             adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
 
-            logits = self.model(features, day_indicies)
+            logits = self.model(features, day_indicies, None, False, 'inference')
 
             loss = self.ctc_loss(
                 torch.permute(logits.log_softmax(2), [1, 0, 2]),
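
As context for the permute on the last line: torch.nn.CTCLoss expects log-probabilities shaped (T, N, C), while the model output is (N, T, C), so the logits are log-softmaxed over the class dimension and then permuted. A minimal self-contained sketch; the blank index, class count, and lengths below are illustrative assumptions, not values taken from this repo's config:

    import torch
    import torch.nn as nn

    ctc = nn.CTCLoss(blank=0, zero_infinity=True)    # blank index 0 is an assumption
    N, T, C = 2, 125, 41                             # batch, patched time steps, classes (illustrative)
    logits = torch.randn(N, T, C)                    # stand-in for the model's (batch, time, classes) output
    log_probs = torch.permute(logits.log_softmax(2), [1, 0, 2])  # -> (T, N, C), as CTCLoss expects
    targets = torch.randint(1, C, (N, 30))           # padded label ids (non-blank)
    input_lengths = torch.tensor([125, 72])          # e.g. the adjusted_lens values
    target_lengths = torch.tensor([30, 18])
    loss = ctc(log_probs, targets, input_lengths, target_lengths)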