beifen
This commit is contained in:
@@ -22,7 +22,7 @@ torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on som
|
||||
torch.backends.cudnn.deterministic = True # makes training more reproducible
|
||||
torch._dynamo.config.cache_size_limit = 64
|
||||
|
||||
from rnn_model import GRUDecoder
|
||||
from rnn_model import TripleGRUDecoder
|
||||
|
||||
class BrainToTextDecoder_Trainer:
|
||||
"""
|
||||
@@ -116,22 +116,23 @@ class BrainToTextDecoder_Trainer:
|
||||
random.seed(self.args['seed'])
|
||||
torch.manual_seed(self.args['seed'])
|
||||
|
||||
# Initialize the model
|
||||
self.model = GRUDecoder(
|
||||
# Initialize the model
|
||||
self.model = TripleGRUDecoder(
|
||||
neural_dim = self.args['model']['n_input_features'],
|
||||
n_units = self.args['model']['n_units'],
|
||||
n_days = len(self.args['dataset']['sessions']),
|
||||
n_classes = self.args['dataset']['n_classes'],
|
||||
rnn_dropout = self.args['model']['rnn_dropout'],
|
||||
input_dropout = self.args['model']['input_network']['input_layer_dropout'],
|
||||
n_layers = self.args['model']['n_layers'],
|
||||
rnn_dropout = self.args['model']['rnn_dropout'],
|
||||
input_dropout = self.args['model']['input_network']['input_layer_dropout'],
|
||||
patch_size = self.args['model']['patch_size'],
|
||||
patch_stride = self.args['model']['patch_stride'],
|
||||
)
|
||||
|
||||
# Call torch.compile to speed up training
|
||||
self.logger.info("Using torch.compile")
|
||||
self.model = torch.compile(self.model)
|
||||
# Temporarily disable torch.compile for compatibility with new model architecture
|
||||
# TODO: Re-enable torch.compile once model is stable
|
||||
# self.logger.info("Using torch.compile")
|
||||
# self.model = torch.compile(self.model)
|
||||
self.logger.info("torch.compile disabled for new TripleGRUDecoder compatibility")
|
||||
|
||||
self.logger.info(f"Initialized RNN decoding model")
|
||||
|
||||
@@ -531,8 +532,9 @@ class BrainToTextDecoder_Trainer:
|
||||
|
||||
adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
||||
|
||||
# Get phoneme predictions
|
||||
logits = self.model(features, day_indicies)
|
||||
# Get phoneme predictions using inference mode during training
|
||||
# (We use inference mode for simplicity - only clean logits are used for CTC loss)
|
||||
logits = self.model(features, day_indicies, None, False, 'inference')
|
||||
|
||||
# Calculate CTC Loss
|
||||
loss = self.ctc_loss(
|
||||
@@ -706,7 +708,7 @@ class BrainToTextDecoder_Trainer:
|
||||
|
||||
adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
||||
|
||||
logits = self.model(features, day_indicies)
|
||||
logits = self.model(features, day_indicies, None, False, 'inference')
|
||||
|
||||
loss = self.ctc_loss(
|
||||
torch.permute(logits.log_softmax(2), [1, 0, 2]),
|
||||
|
213
model_training_nnn/trained_models/baseline_rnn/training_log
Normal file
213
model_training_nnn/trained_models/baseline_rnn/training_log
Normal file
@@ -0,0 +1,213 @@
|
||||
2025-10-12 09:32:47,330: Using device: cpu
|
||||
2025-10-12 09:33:20,119: torch.compile disabled for new TripleGRUDecoder compatibility
|
||||
2025-10-12 09:33:20,120: Initialized RNN decoding model
|
||||
2025-10-12 09:33:20,120: Model type: <class 'rnn_model.TripleGRUDecoder'>
|
||||
2025-10-12 09:33:20,120: Model is callable: True
|
||||
2025-10-12 09:33:20,120: Model has forward method: True
|
||||
2025-10-12 09:33:20,120: TripleGRUDecoder(
|
||||
(noise_model): NoiseModel(
|
||||
(day_layer_activation): Softsign()
|
||||
(day_weights): ParameterList(
|
||||
(0): Parameter containing: [torch.float32 of size 512x512]
|
||||
(1): Parameter containing: [torch.float32 of size 512x512]
|
||||
(2): Parameter containing: [torch.float32 of size 512x512]
|
||||
(3): Parameter containing: [torch.float32 of size 512x512]
|
||||
(4): Parameter containing: [torch.float32 of size 512x512]
|
||||
(5): Parameter containing: [torch.float32 of size 512x512]
|
||||
(6): Parameter containing: [torch.float32 of size 512x512]
|
||||
(7): Parameter containing: [torch.float32 of size 512x512]
|
||||
(8): Parameter containing: [torch.float32 of size 512x512]
|
||||
(9): Parameter containing: [torch.float32 of size 512x512]
|
||||
(10): Parameter containing: [torch.float32 of size 512x512]
|
||||
(11): Parameter containing: [torch.float32 of size 512x512]
|
||||
(12): Parameter containing: [torch.float32 of size 512x512]
|
||||
(13): Parameter containing: [torch.float32 of size 512x512]
|
||||
(14): Parameter containing: [torch.float32 of size 512x512]
|
||||
(15): Parameter containing: [torch.float32 of size 512x512]
|
||||
(16): Parameter containing: [torch.float32 of size 512x512]
|
||||
(17): Parameter containing: [torch.float32 of size 512x512]
|
||||
(18): Parameter containing: [torch.float32 of size 512x512]
|
||||
(19): Parameter containing: [torch.float32 of size 512x512]
|
||||
(20): Parameter containing: [torch.float32 of size 512x512]
|
||||
(21): Parameter containing: [torch.float32 of size 512x512]
|
||||
(22): Parameter containing: [torch.float32 of size 512x512]
|
||||
(23): Parameter containing: [torch.float32 of size 512x512]
|
||||
(24): Parameter containing: [torch.float32 of size 512x512]
|
||||
(25): Parameter containing: [torch.float32 of size 512x512]
|
||||
(26): Parameter containing: [torch.float32 of size 512x512]
|
||||
(27): Parameter containing: [torch.float32 of size 512x512]
|
||||
(28): Parameter containing: [torch.float32 of size 512x512]
|
||||
(29): Parameter containing: [torch.float32 of size 512x512]
|
||||
(30): Parameter containing: [torch.float32 of size 512x512]
|
||||
(31): Parameter containing: [torch.float32 of size 512x512]
|
||||
(32): Parameter containing: [torch.float32 of size 512x512]
|
||||
(33): Parameter containing: [torch.float32 of size 512x512]
|
||||
(34): Parameter containing: [torch.float32 of size 512x512]
|
||||
(35): Parameter containing: [torch.float32 of size 512x512]
|
||||
(36): Parameter containing: [torch.float32 of size 512x512]
|
||||
(37): Parameter containing: [torch.float32 of size 512x512]
|
||||
(38): Parameter containing: [torch.float32 of size 512x512]
|
||||
(39): Parameter containing: [torch.float32 of size 512x512]
|
||||
(40): Parameter containing: [torch.float32 of size 512x512]
|
||||
(41): Parameter containing: [torch.float32 of size 512x512]
|
||||
(42): Parameter containing: [torch.float32 of size 512x512]
|
||||
(43): Parameter containing: [torch.float32 of size 512x512]
|
||||
(44): Parameter containing: [torch.float32 of size 512x512]
|
||||
)
|
||||
(day_biases): ParameterList(
|
||||
(0): Parameter containing: [torch.float32 of size 1x512]
|
||||
(1): Parameter containing: [torch.float32 of size 1x512]
|
||||
(2): Parameter containing: [torch.float32 of size 1x512]
|
||||
(3): Parameter containing: [torch.float32 of size 1x512]
|
||||
(4): Parameter containing: [torch.float32 of size 1x512]
|
||||
(5): Parameter containing: [torch.float32 of size 1x512]
|
||||
(6): Parameter containing: [torch.float32 of size 1x512]
|
||||
(7): Parameter containing: [torch.float32 of size 1x512]
|
||||
(8): Parameter containing: [torch.float32 of size 1x512]
|
||||
(9): Parameter containing: [torch.float32 of size 1x512]
|
||||
(10): Parameter containing: [torch.float32 of size 1x512]
|
||||
(11): Parameter containing: [torch.float32 of size 1x512]
|
||||
(12): Parameter containing: [torch.float32 of size 1x512]
|
||||
(13): Parameter containing: [torch.float32 of size 1x512]
|
||||
(14): Parameter containing: [torch.float32 of size 1x512]
|
||||
(15): Parameter containing: [torch.float32 of size 1x512]
|
||||
(16): Parameter containing: [torch.float32 of size 1x512]
|
||||
(17): Parameter containing: [torch.float32 of size 1x512]
|
||||
(18): Parameter containing: [torch.float32 of size 1x512]
|
||||
(19): Parameter containing: [torch.float32 of size 1x512]
|
||||
(20): Parameter containing: [torch.float32 of size 1x512]
|
||||
(21): Parameter containing: [torch.float32 of size 1x512]
|
||||
(22): Parameter containing: [torch.float32 of size 1x512]
|
||||
(23): Parameter containing: [torch.float32 of size 1x512]
|
||||
(24): Parameter containing: [torch.float32 of size 1x512]
|
||||
(25): Parameter containing: [torch.float32 of size 1x512]
|
||||
(26): Parameter containing: [torch.float32 of size 1x512]
|
||||
(27): Parameter containing: [torch.float32 of size 1x512]
|
||||
(28): Parameter containing: [torch.float32 of size 1x512]
|
||||
(29): Parameter containing: [torch.float32 of size 1x512]
|
||||
(30): Parameter containing: [torch.float32 of size 1x512]
|
||||
(31): Parameter containing: [torch.float32 of size 1x512]
|
||||
(32): Parameter containing: [torch.float32 of size 1x512]
|
||||
(33): Parameter containing: [torch.float32 of size 1x512]
|
||||
(34): Parameter containing: [torch.float32 of size 1x512]
|
||||
(35): Parameter containing: [torch.float32 of size 1x512]
|
||||
(36): Parameter containing: [torch.float32 of size 1x512]
|
||||
(37): Parameter containing: [torch.float32 of size 1x512]
|
||||
(38): Parameter containing: [torch.float32 of size 1x512]
|
||||
(39): Parameter containing: [torch.float32 of size 1x512]
|
||||
(40): Parameter containing: [torch.float32 of size 1x512]
|
||||
(41): Parameter containing: [torch.float32 of size 1x512]
|
||||
(42): Parameter containing: [torch.float32 of size 1x512]
|
||||
(43): Parameter containing: [torch.float32 of size 1x512]
|
||||
(44): Parameter containing: [torch.float32 of size 1x512]
|
||||
)
|
||||
(day_layer_dropout): Dropout(p=0.2, inplace=False)
|
||||
(gru): GRU(7168, 7168, num_layers=2, batch_first=True, dropout=0.4)
|
||||
)
|
||||
(clean_speech_model): CleanSpeechModel(
|
||||
(day_layer_activation): Softsign()
|
||||
(day_weights): ParameterList(
|
||||
(0): Parameter containing: [torch.float32 of size 512x512]
|
||||
(1): Parameter containing: [torch.float32 of size 512x512]
|
||||
(2): Parameter containing: [torch.float32 of size 512x512]
|
||||
(3): Parameter containing: [torch.float32 of size 512x512]
|
||||
(4): Parameter containing: [torch.float32 of size 512x512]
|
||||
(5): Parameter containing: [torch.float32 of size 512x512]
|
||||
(6): Parameter containing: [torch.float32 of size 512x512]
|
||||
(7): Parameter containing: [torch.float32 of size 512x512]
|
||||
(8): Parameter containing: [torch.float32 of size 512x512]
|
||||
(9): Parameter containing: [torch.float32 of size 512x512]
|
||||
(10): Parameter containing: [torch.float32 of size 512x512]
|
||||
(11): Parameter containing: [torch.float32 of size 512x512]
|
||||
(12): Parameter containing: [torch.float32 of size 512x512]
|
||||
(13): Parameter containing: [torch.float32 of size 512x512]
|
||||
(14): Parameter containing: [torch.float32 of size 512x512]
|
||||
(15): Parameter containing: [torch.float32 of size 512x512]
|
||||
(16): Parameter containing: [torch.float32 of size 512x512]
|
||||
(17): Parameter containing: [torch.float32 of size 512x512]
|
||||
(18): Parameter containing: [torch.float32 of size 512x512]
|
||||
(19): Parameter containing: [torch.float32 of size 512x512]
|
||||
(20): Parameter containing: [torch.float32 of size 512x512]
|
||||
(21): Parameter containing: [torch.float32 of size 512x512]
|
||||
(22): Parameter containing: [torch.float32 of size 512x512]
|
||||
(23): Parameter containing: [torch.float32 of size 512x512]
|
||||
(24): Parameter containing: [torch.float32 of size 512x512]
|
||||
(25): Parameter containing: [torch.float32 of size 512x512]
|
||||
(26): Parameter containing: [torch.float32 of size 512x512]
|
||||
(27): Parameter containing: [torch.float32 of size 512x512]
|
||||
(28): Parameter containing: [torch.float32 of size 512x512]
|
||||
(29): Parameter containing: [torch.float32 of size 512x512]
|
||||
(30): Parameter containing: [torch.float32 of size 512x512]
|
||||
(31): Parameter containing: [torch.float32 of size 512x512]
|
||||
(32): Parameter containing: [torch.float32 of size 512x512]
|
||||
(33): Parameter containing: [torch.float32 of size 512x512]
|
||||
(34): Parameter containing: [torch.float32 of size 512x512]
|
||||
(35): Parameter containing: [torch.float32 of size 512x512]
|
||||
(36): Parameter containing: [torch.float32 of size 512x512]
|
||||
(37): Parameter containing: [torch.float32 of size 512x512]
|
||||
(38): Parameter containing: [torch.float32 of size 512x512]
|
||||
(39): Parameter containing: [torch.float32 of size 512x512]
|
||||
(40): Parameter containing: [torch.float32 of size 512x512]
|
||||
(41): Parameter containing: [torch.float32 of size 512x512]
|
||||
(42): Parameter containing: [torch.float32 of size 512x512]
|
||||
(43): Parameter containing: [torch.float32 of size 512x512]
|
||||
(44): Parameter containing: [torch.float32 of size 512x512]
|
||||
)
|
||||
(day_biases): ParameterList(
|
||||
(0): Parameter containing: [torch.float32 of size 1x512]
|
||||
(1): Parameter containing: [torch.float32 of size 1x512]
|
||||
(2): Parameter containing: [torch.float32 of size 1x512]
|
||||
(3): Parameter containing: [torch.float32 of size 1x512]
|
||||
(4): Parameter containing: [torch.float32 of size 1x512]
|
||||
(5): Parameter containing: [torch.float32 of size 1x512]
|
||||
(6): Parameter containing: [torch.float32 of size 1x512]
|
||||
(7): Parameter containing: [torch.float32 of size 1x512]
|
||||
(8): Parameter containing: [torch.float32 of size 1x512]
|
||||
(9): Parameter containing: [torch.float32 of size 1x512]
|
||||
(10): Parameter containing: [torch.float32 of size 1x512]
|
||||
(11): Parameter containing: [torch.float32 of size 1x512]
|
||||
(12): Parameter containing: [torch.float32 of size 1x512]
|
||||
(13): Parameter containing: [torch.float32 of size 1x512]
|
||||
(14): Parameter containing: [torch.float32 of size 1x512]
|
||||
(15): Parameter containing: [torch.float32 of size 1x512]
|
||||
(16): Parameter containing: [torch.float32 of size 1x512]
|
||||
(17): Parameter containing: [torch.float32 of size 1x512]
|
||||
(18): Parameter containing: [torch.float32 of size 1x512]
|
||||
(19): Parameter containing: [torch.float32 of size 1x512]
|
||||
(20): Parameter containing: [torch.float32 of size 1x512]
|
||||
(21): Parameter containing: [torch.float32 of size 1x512]
|
||||
(22): Parameter containing: [torch.float32 of size 1x512]
|
||||
(23): Parameter containing: [torch.float32 of size 1x512]
|
||||
(24): Parameter containing: [torch.float32 of size 1x512]
|
||||
(25): Parameter containing: [torch.float32 of size 1x512]
|
||||
(26): Parameter containing: [torch.float32 of size 1x512]
|
||||
(27): Parameter containing: [torch.float32 of size 1x512]
|
||||
(28): Parameter containing: [torch.float32 of size 1x512]
|
||||
(29): Parameter containing: [torch.float32 of size 1x512]
|
||||
(30): Parameter containing: [torch.float32 of size 1x512]
|
||||
(31): Parameter containing: [torch.float32 of size 1x512]
|
||||
(32): Parameter containing: [torch.float32 of size 1x512]
|
||||
(33): Parameter containing: [torch.float32 of size 1x512]
|
||||
(34): Parameter containing: [torch.float32 of size 1x512]
|
||||
(35): Parameter containing: [torch.float32 of size 1x512]
|
||||
(36): Parameter containing: [torch.float32 of size 1x512]
|
||||
(37): Parameter containing: [torch.float32 of size 1x512]
|
||||
(38): Parameter containing: [torch.float32 of size 1x512]
|
||||
(39): Parameter containing: [torch.float32 of size 1x512]
|
||||
(40): Parameter containing: [torch.float32 of size 1x512]
|
||||
(41): Parameter containing: [torch.float32 of size 1x512]
|
||||
(42): Parameter containing: [torch.float32 of size 1x512]
|
||||
(43): Parameter containing: [torch.float32 of size 1x512]
|
||||
(44): Parameter containing: [torch.float32 of size 1x512]
|
||||
)
|
||||
(day_layer_dropout): Dropout(p=0.2, inplace=False)
|
||||
(gru): GRU(7168, 768, num_layers=3, batch_first=True, dropout=0.4)
|
||||
(out): Linear(in_features=768, out_features=41, bias=True)
|
||||
)
|
||||
(noisy_speech_model): NoisySpeechModel(
|
||||
(gru): GRU(7168, 768, num_layers=2, batch_first=True, dropout=0.4)
|
||||
(out): Linear(in_features=768, out_features=41, bias=True)
|
||||
)
|
||||
)
|
||||
2025-10-12 09:33:20,124: Model has 687,568,466 parameters
|
||||
2025-10-12 09:33:20,124: Model has 23,639,040 day-specific parameters | 3.44% of total parameters
|
Reference in New Issue
Block a user