diff --git a/.claude/settings.local.json b/.claude/settings.local.json
index fb44bb3..6a57db4 100644
--- a/.claude/settings.local.json
+++ b/.claude/settings.local.json
@@ -10,7 +10,8 @@
       "Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" examine_dataset_structure.py)",
       "Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" debug_alignment_step_by_step.py)",
       "Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" CRNN_pretag.py --max_sessions 2)",
-      "Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" CRNN_pretag.py --max_sessions 3)"
+      "Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" CRNN_pretag.py --max_sessions 3)",
+      "Bash(del \"f:\\BRAIN-TO-TEXT\\nejm-brain-to-text.worktrees\\dev2\\test_padding_fix.py\")"
     ],
     "deny": [],
     "ask": []
diff --git a/.gitignore b/.gitignore
index c0f2dde..59ef2fa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,4 +14,5 @@
 model_training_lstm/trained_models_history
 
 *.pkl
-.idea
\ No newline at end of file
+.idea
+.claude
\ No newline at end of file
diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py
index 5fbc977..2b2fd53 100644
--- a/model_training_nnn_tpu/dataset_tf.py
+++ b/model_training_nnn_tpu/dataset_tf.py
@@ -773,16 +773,35 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
         'max_time_steps': 0,
         'max_phone_seq_len': 0,
         'max_transcription_len': 0,
-        'n_features': 512  # Fixed for neural features
+        'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
     }
 
     # Sample a subset of data to determine max sizes
     count = 0
-    for batch_idx in list(dataset_tf.batch_indices.keys())[:min(10, len(dataset_tf.batch_indices))]:
+    batch_keys = list(dataset_tf.batch_indices.keys())
+
+    # If sample_size is -1, analyze all data
+    if sample_size == -1:
+        batches_to_check = batch_keys
+        max_trials_per_batch = float('inf')
+    else:
+        # Sample a reasonable number of batches
+        batches_to_check = batch_keys[:min(max(10, sample_size // 10), len(batch_keys))]
+        max_trials_per_batch = max(1, sample_size // len(batches_to_check))
+
+    for batch_idx in batches_to_check:
+        if count >= sample_size and sample_size > 0:
+            break
+
         batch_index = dataset_tf.batch_indices[batch_idx]
         for day in batch_index.keys():
-            for trial in batch_index[day][:min(10, len(batch_index[day]))]:
+            if count >= sample_size and sample_size > 0:
+                break
+
+            trials_to_check = batch_index[day][:min(int(max_trials_per_batch), len(batch_index[day]))]
+
+            for trial in trials_to_check:
                 if count >= sample_size and sample_size > 0:
                     break
@@ -803,24 +803,28 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
                     count += 1
 
+                    # Show progress for large analyses
+                    if count % 50 == 0:
+                        print(f"   Analyzed {count} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
+
                 except Exception as e:
                     logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
                     continue
 
-            if count >= sample_size and sample_size > 0:
-                break
-        if count >= sample_size and sample_size > 0:
-            break
-
     # Add safety margins (20% buffer) to handle edge cases
+    original_time_steps = max_shapes['max_time_steps']
+    original_phone_seq_len = max_shapes['max_phone_seq_len']
+    original_transcription_len = max_shapes['max_transcription_len']
+
     max_shapes['max_time_steps'] = int(max_shapes['max_time_steps'] * 1.2)
     max_shapes['max_phone_seq_len'] = int(max_shapes['max_phone_seq_len'] * 1.2)
     max_shapes['max_transcription_len'] = int(max_shapes['max_transcription_len'] * 1.2)
 
     print(f"📊 Dataset analysis complete (analyzed {count} samples):")
-    print(f"   Max time steps: {max_shapes['max_time_steps']}")
-    print(f"   Max phone sequence length: {max_shapes['max_phone_seq_len']}")
-    print(f"   Max transcription length: {max_shapes['max_transcription_len']}")
+    print(f"   Original max time steps: {original_time_steps} → Padded: {max_shapes['max_time_steps']}")
+    print(f"   Original max phone sequence length: {original_phone_seq_len} → Padded: {max_shapes['max_phone_seq_len']}")
+    print(f"   Original max transcription length: {original_transcription_len} → Padded: {max_shapes['max_transcription_len']}")
+    print(f"   Number of features: {max_shapes['n_features']}")
 
     return max_shapes
@@ -899,14 +922,39 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         max_phone_seq_len = shape_info['max_phone_seq_len']
         max_transcription_len = shape_info['max_transcription_len']
         n_features = shape_info['n_features']
+        print(f"🔧 Using auto-analyzed shapes: time_steps={max_time_steps}, phone_seq={max_phone_seq_len}, transcription={max_transcription_len}")
     else:
-        # Use conservative fixed shapes for TPU compatibility
-        # Increased sizes to handle larger data - adjust based on your actual dataset
-        max_time_steps = 8192  # Increased from 4096 - adjust based on your data
-        max_phone_seq_len = 512  # Increased from 256 - adjust based on your data
-        max_transcription_len = 1024  # Increased from 512 - adjust based on your data
-        n_features = 512  # Number of neural features
+        # Use dynamic shapes for maximum compatibility - let TensorFlow handle padding automatically
+        # This avoids the "pad to a smaller size" error by allowing dynamic sizing
+        print(f"🔧 Using dynamic shapes for maximum compatibility")
+        # Calculate number of features based on subset
+        n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
+        padded_shapes = {
+            'input_features': tf.TensorShape([None, n_features]),
+            'seq_class_ids': tf.TensorShape([None]),
+            'n_time_steps': tf.TensorShape([]),       # Scalar
+            'phone_seq_lens': tf.TensorShape([]),     # Scalar
+            'day_indices': tf.TensorShape([]),        # Scalar
+            'transcriptions': tf.TensorShape([None]),
+            'block_nums': tf.TensorShape([]),         # Scalar
+            'trial_nums': tf.TensorShape([])          # Scalar
+        }
+
+        # Create fixed-shape batches with dynamic padding
+        dataset = dataset.padded_batch(
+            batch_size=dataset_tf.batch_size,
+            padded_shapes=padded_shapes,
+            padding_values=padding_values,
+            drop_remainder=True  # Critical for TPU: ensures all batches have same size
+        )
+
+        # Prefetch for optimal performance
+        dataset = dataset.prefetch(tf.data.AUTOTUNE)
+
+        return dataset
+
+    # If using auto-analyzed shapes, create fixed-size padded shapes
     padded_shapes = {
         'input_features': [max_time_steps, n_features],
         'seq_class_ids': [max_phone_seq_len],