diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py
index 8149dc4..af5b735 100644
--- a/model_training_nnn/rnn_trainer.py
+++ b/model_training_nnn/rnn_trainer.py
@@ -22,7 +22,7 @@ torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on som
 torch.backends.cudnn.deterministic = True # makes training more reproducible
 torch._dynamo.config.cache_size_limit = 64
 
-from rnn_model import GRUDecoder
+from rnn_model import TripleGRUDecoder
 
 class BrainToTextDecoder_Trainer:
     """
@@ -116,22 +116,23 @@ class BrainToTextDecoder_Trainer:
         random.seed(self.args['seed'])
         torch.manual_seed(self.args['seed'])
 
-        # Initialize the model
-        self.model = GRUDecoder(
+        # Initialize the model
+        self.model = TripleGRUDecoder(
             neural_dim = self.args['model']['n_input_features'],
             n_units = self.args['model']['n_units'],
             n_days = len(self.args['dataset']['sessions']),
             n_classes = self.args['dataset']['n_classes'],
-            rnn_dropout = self.args['model']['rnn_dropout'],
-            input_dropout = self.args['model']['input_network']['input_layer_dropout'],
-            n_layers = self.args['model']['n_layers'],
+            rnn_dropout = self.args['model']['rnn_dropout'],
+            input_dropout = self.args['model']['input_network']['input_layer_dropout'],
             patch_size = self.args['model']['patch_size'],
             patch_stride = self.args['model']['patch_stride'],
         )
 
-        # Call torch.compile to speed up training
-        self.logger.info("Using torch.compile")
-        self.model = torch.compile(self.model)
+        # Temporarily disable torch.compile for compatibility with new model architecture
+        # TODO: Re-enable torch.compile once model is stable
+        # self.logger.info("Using torch.compile")
+        # self.model = torch.compile(self.model)
+        self.logger.info("torch.compile disabled for new TripleGRUDecoder compatibility")
 
         self.logger.info(f"Initialized RNN decoding model")
 
@@ -531,8 +532,9 @@ class BrainToTextDecoder_Trainer:
 
             adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
 
-            # Get phoneme predictions
-            logits = self.model(features, day_indicies)
+            # Get phoneme predictions using inference mode during training
+            # (We use inference mode for simplicity - only clean logits are used for CTC loss)
+            logits = self.model(features, day_indicies, None, False, 'inference')
             # Calculate CTC Loss
             loss = self.ctc_loss(
                 torch.permute(logits.log_softmax(2), [1, 0, 2]),
@@ -706,7 +708,7 @@ class BrainToTextDecoder_Trainer:
 
             adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
 
-            logits = self.model(features, day_indicies)
+            logits = self.model(features, day_indicies, None, False, 'inference')
 
             loss = self.ctc_loss(
                 torch.permute(logits.log_softmax(2), [1, 0, 2]),
diff --git a/model_training_nnn/trained_models/baseline_rnn/training_log b/model_training_nnn/trained_models/baseline_rnn/training_log
new file mode 100644
index 0000000..c835d6e
--- /dev/null
+++ b/model_training_nnn/trained_models/baseline_rnn/training_log
@@ -0,0 +1,213 @@
+2025-10-12 09:32:47,330: Using device: cpu
+2025-10-12 09:33:20,119: torch.compile disabled for new TripleGRUDecoder compatibility
+2025-10-12 09:33:20,120: Initialized RNN decoding model
+2025-10-12 09:33:20,120: Model type: 
+2025-10-12 09:33:20,120: Model is callable: True
+2025-10-12 09:33:20,120: Model has forward method: True
+2025-10-12 09:33:20,120: TripleGRUDecoder(
+  (noise_model): NoiseModel(
+    (day_layer_activation): Softsign()
+    (day_weights): ParameterList(
+        (0): Parameter containing: [torch.float32 of size 512x512]
+        (1): Parameter containing: [torch.float32 of size 512x512]
+        (2): Parameter containing: [torch.float32 of size 512x512]
+        (3): Parameter containing: [torch.float32 of size 512x512]
+        (4): Parameter containing: [torch.float32 of size 512x512]
+        (5): Parameter containing: [torch.float32 of size 512x512]
+        (6): Parameter containing: [torch.float32 of size 512x512]
+        (7): Parameter containing: [torch.float32 of size 512x512]
+        (8): Parameter containing: [torch.float32 of size 512x512]
+        (9): Parameter containing: [torch.float32 of size 512x512]
+        (10): Parameter containing: [torch.float32 of size 512x512]
+        (11): Parameter containing: [torch.float32 of size 512x512]
+        (12): Parameter containing: [torch.float32 of size 512x512]
+        (13): Parameter containing: [torch.float32 of size 512x512]
+        (14): Parameter containing: [torch.float32 of size 512x512]
+        (15): Parameter containing: [torch.float32 of size 512x512]
+        (16): Parameter containing: [torch.float32 of size 512x512]
+        (17): Parameter containing: [torch.float32 of size 512x512]
+        (18): Parameter containing: [torch.float32 of size 512x512]
+        (19): Parameter containing: [torch.float32 of size 512x512]
+        (20): Parameter containing: [torch.float32 of size 512x512]
+        (21): Parameter containing: [torch.float32 of size 512x512]
+        (22): Parameter containing: [torch.float32 of size 512x512]
+        (23): Parameter containing: [torch.float32 of size 512x512]
+        (24): Parameter containing: [torch.float32 of size 512x512]
+        (25): Parameter containing: [torch.float32 of size 512x512]
+        (26): Parameter containing: [torch.float32 of size 512x512]
+        (27): Parameter containing: [torch.float32 of size 512x512]
+        (28): Parameter containing: [torch.float32 of size 512x512]
+        (29): Parameter containing: [torch.float32 of size 512x512]
+        (30): Parameter containing: [torch.float32 of size 512x512]
+        (31): Parameter containing: [torch.float32 of size 512x512]
+        (32): Parameter containing: [torch.float32 of size 512x512]
+        (33): Parameter containing: [torch.float32 of size 512x512]
+        (34): Parameter containing: [torch.float32 of size 512x512]
+        (35): Parameter containing: [torch.float32 of size 512x512]
+        (36): Parameter containing: [torch.float32 of size 512x512]
+        (37): Parameter containing: [torch.float32 of size 512x512]
+        (38): Parameter containing: [torch.float32 of size 512x512]
+        (39): Parameter containing: [torch.float32 of size 512x512]
+        (40): Parameter containing: [torch.float32 of size 512x512]
+        (41): Parameter containing: [torch.float32 of size 512x512]
+        (42): Parameter containing: [torch.float32 of size 512x512]
+        (43): Parameter containing: [torch.float32 of size 512x512]
+        (44): Parameter containing: [torch.float32 of size 512x512]
+    )
+    (day_biases): ParameterList(
+        (0): Parameter containing: [torch.float32 of size 1x512]
+        (1): Parameter containing: [torch.float32 of size 1x512]
+        (2): Parameter containing: [torch.float32 of size 1x512]
+        (3): Parameter containing: [torch.float32 of size 1x512]
+        (4): Parameter containing: [torch.float32 of size 1x512]
+        (5): Parameter containing: [torch.float32 of size 1x512]
+        (6): Parameter containing: [torch.float32 of size 1x512]
+        (7): Parameter containing: [torch.float32 of size 1x512]
+        (8): Parameter containing: [torch.float32 of size 1x512]
+        (9): Parameter containing: [torch.float32 of size 1x512]
+        (10): Parameter containing: [torch.float32 of size 1x512]
+        (11): Parameter containing: [torch.float32 of size 1x512]
+        (12): Parameter containing: [torch.float32 of size 1x512]
+        (13): Parameter containing: [torch.float32 of size 1x512]
+        (14): Parameter containing: [torch.float32 of size 1x512]
+        (15): Parameter containing: [torch.float32 of size 1x512]
+        (16): Parameter containing: [torch.float32 of size 1x512]
+        (17): Parameter containing: [torch.float32 of size 1x512]
+        (18): Parameter containing: [torch.float32 of size 1x512]
+        (19): Parameter containing: [torch.float32 of size 1x512]
+        (20): Parameter containing: [torch.float32 of size 1x512]
+        (21): Parameter containing: [torch.float32 of size 1x512]
+        (22): Parameter containing: [torch.float32 of size 1x512]
+        (23): Parameter containing: [torch.float32 of size 1x512]
+        (24): Parameter containing: [torch.float32 of size 1x512]
+        (25): Parameter containing: [torch.float32 of size 1x512]
+        (26): Parameter containing: [torch.float32 of size 1x512]
+        (27): Parameter containing: [torch.float32 of size 1x512]
+        (28): Parameter containing: [torch.float32 of size 1x512]
+        (29): Parameter containing: [torch.float32 of size 1x512]
+        (30): Parameter containing: [torch.float32 of size 1x512]
+        (31): Parameter containing: [torch.float32 of size 1x512]
+        (32): Parameter containing: [torch.float32 of size 1x512]
+        (33): Parameter containing: [torch.float32 of size 1x512]
+        (34): Parameter containing: [torch.float32 of size 1x512]
+        (35): Parameter containing: [torch.float32 of size 1x512]
+        (36): Parameter containing: [torch.float32 of size 1x512]
+        (37): Parameter containing: [torch.float32 of size 1x512]
+        (38): Parameter containing: [torch.float32 of size 1x512]
+        (39): Parameter containing: [torch.float32 of size 1x512]
+        (40): Parameter containing: [torch.float32 of size 1x512]
+        (41): Parameter containing: [torch.float32 of size 1x512]
+        (42): Parameter containing: [torch.float32 of size 1x512]
+        (43): Parameter containing: [torch.float32 of size 1x512]
+        (44): Parameter containing: [torch.float32 of size 1x512]
+    )
+    (day_layer_dropout): Dropout(p=0.2, inplace=False)
+    (gru): GRU(7168, 7168, num_layers=2, batch_first=True, dropout=0.4)
+  )
+  (clean_speech_model): CleanSpeechModel(
+    (day_layer_activation): Softsign()
+    (day_weights): ParameterList(
+        (0): Parameter containing: [torch.float32 of size 512x512]
+        (1): Parameter containing: [torch.float32 of size 512x512]
+        (2): Parameter containing: [torch.float32 of size 512x512]
+        (3): Parameter containing: [torch.float32 of size 512x512]
+        (4): Parameter containing: [torch.float32 of size 512x512]
+        (5): Parameter containing: [torch.float32 of size 512x512]
+        (6): Parameter containing: [torch.float32 of size 512x512]
+        (7): Parameter containing: [torch.float32 of size 512x512]
+        (8): Parameter containing: [torch.float32 of size 512x512]
+        (9): Parameter containing: [torch.float32 of size 512x512]
+        (10): Parameter containing: [torch.float32 of size 512x512]
+        (11): Parameter containing: [torch.float32 of size 512x512]
+        (12): Parameter containing: [torch.float32 of size 512x512]
+        (13): Parameter containing: [torch.float32 of size 512x512]
+        (14): Parameter containing: [torch.float32 of size 512x512]
+        (15): Parameter containing: [torch.float32 of size 512x512]
+        (16): Parameter containing: [torch.float32 of size 512x512]
+        (17): Parameter containing: [torch.float32 of size 512x512]
+        (18): Parameter containing: [torch.float32 of size 512x512]
+        (19): Parameter containing: [torch.float32 of size 512x512]
+        (20): Parameter containing: [torch.float32 of size 512x512]
+        (21): Parameter containing: [torch.float32 of size 512x512]
+        (22): Parameter containing: [torch.float32 of size 512x512]
+        (23): Parameter containing: [torch.float32 of size 512x512]
+        (24): Parameter containing: [torch.float32 of size 512x512]
+        (25): Parameter containing: [torch.float32 of size 512x512]
+        (26): Parameter containing: [torch.float32 of size 512x512]
+        (27): Parameter containing: [torch.float32 of size 512x512]
+        (28): Parameter containing: [torch.float32 of size 512x512]
+        (29): Parameter containing: [torch.float32 of size 512x512]
+        (30): Parameter containing: [torch.float32 of size 512x512]
+        (31): Parameter containing: [torch.float32 of size 512x512]
+        (32): Parameter containing: [torch.float32 of size 512x512]
+        (33): Parameter containing: [torch.float32 of size 512x512]
+        (34): Parameter containing: [torch.float32 of size 512x512]
+        (35): Parameter containing: [torch.float32 of size 512x512]
+        (36): Parameter containing: [torch.float32 of size 512x512]
+        (37): Parameter containing: [torch.float32 of size 512x512]
+        (38): Parameter containing: [torch.float32 of size 512x512]
+        (39): Parameter containing: [torch.float32 of size 512x512]
+        (40): Parameter containing: [torch.float32 of size 512x512]
+        (41): Parameter containing: [torch.float32 of size 512x512]
+        (42): Parameter containing: [torch.float32 of size 512x512]
+        (43): Parameter containing: [torch.float32 of size 512x512]
+        (44): Parameter containing: [torch.float32 of size 512x512]
+    )
+    (day_biases): ParameterList(
+        (0): Parameter containing: [torch.float32 of size 1x512]
+        (1): Parameter containing: [torch.float32 of size 1x512]
+        (2): Parameter containing: [torch.float32 of size 1x512]
+        (3): Parameter containing: [torch.float32 of size 1x512]
+        (4): Parameter containing: [torch.float32 of size 1x512]
+        (5): Parameter containing: [torch.float32 of size 1x512]
+        (6): Parameter containing: [torch.float32 of size 1x512]
+        (7): Parameter containing: [torch.float32 of size 1x512]
+        (8): Parameter containing: [torch.float32 of size 1x512]
+        (9): Parameter containing: [torch.float32 of size 1x512]
+        (10): Parameter containing: [torch.float32 of size 1x512]
+        (11): Parameter containing: [torch.float32 of size 1x512]
+        (12): Parameter containing: [torch.float32 of size 1x512]
+        (13): Parameter containing: [torch.float32 of size 1x512]
+        (14): Parameter containing: [torch.float32 of size 1x512]
+        (15): Parameter containing: [torch.float32 of size 1x512]
+        (16): Parameter containing: [torch.float32 of size 1x512]
+        (17): Parameter containing: [torch.float32 of size 1x512]
+        (18): Parameter containing: [torch.float32 of size 1x512]
+        (19): Parameter containing: [torch.float32 of size 1x512]
+        (20): Parameter containing: [torch.float32 of size 1x512]
+        (21): Parameter containing: [torch.float32 of size 1x512]
+        (22): Parameter containing: [torch.float32 of size 1x512]
+        (23): Parameter containing: [torch.float32 of size 1x512]
+        (24): Parameter containing: [torch.float32 of size 1x512]
+        (25): Parameter containing: [torch.float32 of size 1x512]
+        (26): Parameter containing: [torch.float32 of size 1x512]
+        (27): Parameter containing: [torch.float32 of size 1x512]
+        (28): Parameter containing: [torch.float32 of size 1x512]
+        (29): Parameter containing: [torch.float32 of size 1x512]
+        (30): Parameter containing: [torch.float32 of size 1x512]
+        (31): Parameter containing: [torch.float32 of size 1x512]
+        (32): Parameter containing: [torch.float32 of size 1x512]
+        (33): Parameter containing: [torch.float32 of size 1x512]
+        (34): Parameter containing: [torch.float32 of size 1x512]
+        (35): Parameter containing: [torch.float32 of size 1x512]
+        (36): Parameter containing: [torch.float32 of size 1x512]
+        (37): Parameter containing: [torch.float32 of size 1x512]
+        (38): Parameter containing: [torch.float32 of size 1x512]
+        (39): Parameter containing: [torch.float32 of size 1x512]
+        (40): Parameter containing: [torch.float32 of size 1x512]
+        (41): Parameter containing: [torch.float32 of size 1x512]
+        (42): Parameter containing: [torch.float32 of size 1x512]
+        (43): Parameter containing: [torch.float32 of size 1x512]
+        (44): Parameter containing: [torch.float32 of size 1x512]
+    )
+    (day_layer_dropout): Dropout(p=0.2, inplace=False)
+    (gru): GRU(7168, 768, num_layers=3, batch_first=True, dropout=0.4)
+    (out): Linear(in_features=768, out_features=41, bias=True)
+  )
+  (noisy_speech_model): NoisySpeechModel(
+    (gru): GRU(7168, 768, num_layers=2, batch_first=True, dropout=0.4)
+    (out): Linear(in_features=768, out_features=41, bias=True)
+  )
+)
+2025-10-12 09:33:20,124: Model has 687,568,466 parameters
+2025-10-12 09:33:20,124: Model has 23,639,040 day-specific parameters | 3.44% of total parameters
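
Note: for reference, below is a minimal sketch of the CTC input-length arithmetic and the new call convention used in the diff above. The batch size, time-step count, and patch_size/patch_stride values are illustrative assumptions rather than values from the repository config, and the model call is shown as a comment that simply mirrors the positional arguments from rnn_trainer.py (per the in-diff comment, only the clean logits are used for the CTC loss in 'inference' mode).

import torch

# Illustrative values only (assumptions); the trainer reads these from self.args['model'].
patch_size, patch_stride = 14, 4
batch_size, n_time_steps = 2, 512

# Per-trial time-step counts for a batch.
n_time_steps_batch = torch.full((batch_size,), n_time_steps)

# Same arithmetic as rnn_trainer.py: number of patches produced per trial,
# which is passed to the CTC loss as the input lengths.
adjusted_lens = ((n_time_steps_batch - patch_size) / patch_stride + 1).to(torch.int32)
print(adjusted_lens)  # tensor([125, 125], dtype=torch.int32)

# Call convention from the diff (commented out because TripleGRUDecoder is defined in rnn_model.py):
# logits = model(features, day_indicies, None, False, 'inference')
# loss = ctc_loss(torch.permute(logits.log_softmax(2), [1, 0, 2]),
#                 targets, adjusted_lens, target_lens)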