Zchen
2025-10-19 13:18:20 +08:00
parent 4328114ed6
commit 40d0fc50de
3 changed files with 69 additions and 19 deletions


@@ -10,7 +10,8 @@
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" examine_dataset_structure.py)",
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" debug_alignment_step_by_step.py)",
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" CRNN_pretag.py --max_sessions 2)",
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" CRNN_pretag.py --max_sessions 3)"
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" CRNN_pretag.py --max_sessions 3)",
"Bash(del \"f:\\BRAIN-TO-TEXT\\nejm-brain-to-text.worktrees\\dev2\\test_padding_fix.py\")"
],
"deny": [],
"ask": []

.gitignore

@@ -15,3 +15,4 @@ model_training_lstm/trained_models_history
*.pkl
.idea
.claude


@@ -773,16 +773,35 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
         'max_time_steps': 0,
         'max_phone_seq_len': 0,
         'max_transcription_len': 0,
-        'n_features': 512 # Fixed for neural features
+        'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
     }
     # Sample a subset of data to determine max sizes
     count = 0
-    for batch_idx in list(dataset_tf.batch_indices.keys())[:min(10, len(dataset_tf.batch_indices))]:
+    batch_keys = list(dataset_tf.batch_indices.keys())
+    # If sample_size is -1, analyze all data
+    if sample_size == -1:
+        batches_to_check = batch_keys
+        max_trials_per_batch = float('inf')
+    else:
+        # Sample a reasonable number of batches
+        batches_to_check = batch_keys[:min(max(10, sample_size // 10), len(batch_keys))]
+        max_trials_per_batch = max(1, sample_size // len(batches_to_check))
+    for batch_idx in batches_to_check:
         if count >= sample_size and sample_size > 0:
             break
         batch_index = dataset_tf.batch_indices[batch_idx]
         for day in batch_index.keys():
-            for trial in batch_index[day][:min(10, len(batch_index[day]))]:
             if count >= sample_size and sample_size > 0:
                 break
+            trials_to_check = batch_index[day][:min(int(max_trials_per_batch), len(batch_index[day]))]
+            for trial in trials_to_check:
+                if count >= sample_size and sample_size > 0:
+                    break
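
As an aside on the sampling logic added above (purely illustrative numbers, not from the repository): with sample_size=200 and 40 stored batches, the new code would inspect the first 20 batches and at most 10 trials per batch, while sample_size=-1 walks everything.

# Worked example of the new sampling arithmetic (hypothetical sizes).
sample_size = 200
batch_keys = list(range(40))  # pretend the dataset holds 40 batches
batches_to_check = batch_keys[:min(max(10, sample_size // 10), len(batch_keys))]
max_trials_per_batch = max(1, sample_size // len(batches_to_check))
print(len(batches_to_check), max_trials_per_batch)  # 20 10
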
@@ -803,24 +822,28 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
                     count += 1
                     # Show progress for large analyses
                     if count % 50 == 0:
                         print(f"   Analyzed {count} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
                 except Exception as e:
                     logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
                     continue
             if count >= sample_size and sample_size > 0:
                 break
         if count >= sample_size and sample_size > 0:
             break
     # Add safety margins (20% buffer) to handle edge cases
+    original_time_steps = max_shapes['max_time_steps']
+    original_phone_seq_len = max_shapes['max_phone_seq_len']
+    original_transcription_len = max_shapes['max_transcription_len']
     max_shapes['max_time_steps'] = int(max_shapes['max_time_steps'] * 1.2)
     max_shapes['max_phone_seq_len'] = int(max_shapes['max_phone_seq_len'] * 1.2)
     max_shapes['max_transcription_len'] = int(max_shapes['max_transcription_len'] * 1.2)
     print(f"📊 Dataset analysis complete (analyzed {count} samples):")
-    print(f"   Max time steps: {max_shapes['max_time_steps']}")
-    print(f"   Max phone sequence length: {max_shapes['max_phone_seq_len']}")
-    print(f"   Max transcription length: {max_shapes['max_transcription_len']}")
+    print(f"   Original max time steps: {original_time_steps} → Padded: {max_shapes['max_time_steps']}")
+    print(f"   Original max phone sequence length: {original_phone_seq_len} → Padded: {max_shapes['max_phone_seq_len']}")
+    print(f"   Original max transcription length: {original_transcription_len} → Padded: {max_shapes['max_transcription_len']}")
     print(f"   Number of features: {max_shapes['n_features']}")
     return max_shapes
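
A minimal usage sketch of the revised analysis; dataset_tf is assumed to be an already-constructed BrainToTextDatasetTF instance (its construction is not part of this diff):

# Assumed usage, mirroring the signature shown in the hunk header above.
shape_info = analyze_dataset_shapes(dataset_tf, sample_size=-1)  # -1: walk every batch and trial
# shape_info = analyze_dataset_shapes(dataset_tf, sample_size=200)  # or cap the number of trials inspected

# The returned maxima already include the 20% safety buffer applied above.
print(shape_info['max_time_steps'],
      shape_info['max_phone_seq_len'],
      shape_info['max_transcription_len'],
      shape_info['n_features'])
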
@@ -899,14 +922,39 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         max_phone_seq_len = shape_info['max_phone_seq_len']
         max_transcription_len = shape_info['max_transcription_len']
         n_features = shape_info['n_features']
         print(f"🔧 Using auto-analyzed shapes: time_steps={max_time_steps}, phone_seq={max_phone_seq_len}, transcription={max_transcription_len}")
     else:
-        # Use conservative fixed shapes for TPU compatibility
-        # Increased sizes to handle larger data - adjust based on your actual dataset
-        max_time_steps = 8192 # Increased from 4096 - adjust based on your data
-        max_phone_seq_len = 512 # Increased from 256 - adjust based on your data
-        max_transcription_len = 1024 # Increased from 512 - adjust based on your data
-        n_features = 512 # Number of neural features
+        # Use dynamic shapes for maximum compatibility - let TensorFlow handle padding automatically
+        # This avoids the "pad to a smaller size" error by allowing dynamic sizing
+        print(f"🔧 Using dynamic shapes for maximum compatibility")
+        # Calculate number of features based on subset
+        n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
+        padded_shapes = {
+            'input_features': tf.TensorShape([None, n_features]),
+            'seq_class_ids': tf.TensorShape([None]),
+            'n_time_steps': tf.TensorShape([]), # Scalar
+            'phone_seq_lens': tf.TensorShape([]), # Scalar
+            'day_indices': tf.TensorShape([]), # Scalar
+            'transcriptions': tf.TensorShape([None]),
+            'block_nums': tf.TensorShape([]), # Scalar
+            'trial_nums': tf.TensorShape([]) # Scalar
+        }
+        # Create fixed-shape batches with dynamic padding
+        dataset = dataset.padded_batch(
+            batch_size=dataset_tf.batch_size,
+            padded_shapes=padded_shapes,
+            padding_values=padding_values,
+            drop_remainder=True # Critical for TPU: ensures all batches have same size
+        )
+        # Prefetch for optimal performance
+        dataset = dataset.prefetch(tf.data.AUTOTUNE)
+        return dataset
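
To see the dynamic-shape branch above in isolation, here is a self-contained sketch with synthetic data; the feature width, lengths, and dict keys are assumptions for the example, not values from this repository:

import numpy as np
import tensorflow as tf

# Three variable-length "trials" with 3 features per time step (synthetic).
def gen():
    for n in (5, 9, 7):
        yield {'input_features': np.random.randn(n, 3).astype('float32'),
               'n_time_steps': np.int32(n)}

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature={'input_features': tf.TensorSpec([None, 3], tf.float32),
                      'n_time_steps': tf.TensorSpec([], tf.int32)})

# None dims are padded to the longest element of each batch;
# drop_remainder=True keeps the batch dimension static (the short final batch is dropped).
ds = ds.padded_batch(
    batch_size=2,
    padded_shapes={'input_features': [None, 3], 'n_time_steps': []},
    padding_values={'input_features': 0.0, 'n_time_steps': np.int32(0)},
    drop_remainder=True
).prefetch(tf.data.AUTOTUNE)

for batch in ds:
    print(batch['input_features'].shape)  # (2, 9, 3): padded to the longest trial in the batch

Padding only to the longest element of each batch keeps batches small, while the static batch dimension from drop_remainder=True matches the TPU constraint noted in the diff's own comment.
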
+    # If using auto-analyzed shapes, create fixed-size padded shapes
+    padded_shapes = {
+        'input_features': [max_time_steps, n_features],
+        'seq_class_ids': [max_phone_seq_len],