f
This commit is contained in:
@@ -10,7 +10,8 @@
|
||||
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" examine_dataset_structure.py)",
|
||||
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" debug_alignment_step_by_step.py)",
|
||||
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" CRNN_pretag.py --max_sessions 2)",
|
||||
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" CRNN_pretag.py --max_sessions 3)"
|
||||
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" CRNN_pretag.py --max_sessions 3)",
|
||||
"Bash(del \"f:\\BRAIN-TO-TEXT\\nejm-brain-to-text.worktrees\\dev2\\test_padding_fix.py\")"
|
||||
],
|
||||
"deny": [],
|
||||
"ask": []
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@@ -15,3 +15,4 @@ model_training_lstm/trained_models_history
|
||||
*.pkl
|
||||
|
||||
.idea
|
||||
.claude
|
@@ -773,16 +773,35 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
|
||||
'max_time_steps': 0,
|
||||
'max_phone_seq_len': 0,
|
||||
'max_transcription_len': 0,
|
||||
'n_features': 512 # Fixed for neural features
|
||||
'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
|
||||
}
|
||||
|
||||
# Sample a subset of data to determine max sizes
|
||||
count = 0
|
||||
for batch_idx in list(dataset_tf.batch_indices.keys())[:min(10, len(dataset_tf.batch_indices))]:
|
||||
batch_keys = list(dataset_tf.batch_indices.keys())
|
||||
|
||||
# If sample_size is -1, analyze all data
|
||||
if sample_size == -1:
|
||||
batches_to_check = batch_keys
|
||||
max_trials_per_batch = float('inf')
|
||||
else:
|
||||
# Sample a reasonable number of batches
|
||||
batches_to_check = batch_keys[:min(max(10, sample_size // 10), len(batch_keys))]
|
||||
max_trials_per_batch = max(1, sample_size // len(batches_to_check))
|
||||
|
||||
for batch_idx in batches_to_check:
|
||||
if count >= sample_size and sample_size > 0:
|
||||
break
|
||||
|
||||
batch_index = dataset_tf.batch_indices[batch_idx]
|
||||
|
||||
for day in batch_index.keys():
|
||||
for trial in batch_index[day][:min(10, len(batch_index[day]))]:
|
||||
if count >= sample_size and sample_size > 0:
|
||||
break
|
||||
|
||||
trials_to_check = batch_index[day][:min(int(max_trials_per_batch), len(batch_index[day]))]
|
||||
|
||||
for trial in trials_to_check:
|
||||
if count >= sample_size and sample_size > 0:
|
||||
break
|
||||
|
||||
@@ -803,24 +822,28 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
|
||||
|
||||
count += 1
|
||||
|
||||
# Show progress for large analyses
|
||||
if count % 50 == 0:
|
||||
print(f" Analyzed {count} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
|
||||
continue
|
||||
|
||||
if count >= sample_size and sample_size > 0:
|
||||
break
|
||||
if count >= sample_size and sample_size > 0:
|
||||
break
|
||||
|
||||
# Add safety margins (20% buffer) to handle edge cases
|
||||
original_time_steps = max_shapes['max_time_steps']
|
||||
original_phone_seq_len = max_shapes['max_phone_seq_len']
|
||||
original_transcription_len = max_shapes['max_transcription_len']
|
||||
|
||||
max_shapes['max_time_steps'] = int(max_shapes['max_time_steps'] * 1.2)
|
||||
max_shapes['max_phone_seq_len'] = int(max_shapes['max_phone_seq_len'] * 1.2)
|
||||
max_shapes['max_transcription_len'] = int(max_shapes['max_transcription_len'] * 1.2)
|
||||
|
||||
print(f"📊 Dataset analysis complete (analyzed {count} samples):")
|
||||
print(f" Max time steps: {max_shapes['max_time_steps']}")
|
||||
print(f" Max phone sequence length: {max_shapes['max_phone_seq_len']}")
|
||||
print(f" Max transcription length: {max_shapes['max_transcription_len']}")
|
||||
print(f" Original max time steps: {original_time_steps} → Padded: {max_shapes['max_time_steps']}")
|
||||
print(f" Original max phone sequence length: {original_phone_seq_len} → Padded: {max_shapes['max_phone_seq_len']}")
|
||||
print(f" Original max transcription length: {original_transcription_len} → Padded: {max_shapes['max_transcription_len']}")
|
||||
print(f" Number of features: {max_shapes['n_features']}")
|
||||
|
||||
return max_shapes
|
||||
|
||||
@@ -899,14 +922,39 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
||||
max_phone_seq_len = shape_info['max_phone_seq_len']
|
||||
max_transcription_len = shape_info['max_transcription_len']
|
||||
n_features = shape_info['n_features']
|
||||
print(f"🔧 Using auto-analyzed shapes: time_steps={max_time_steps}, phone_seq={max_phone_seq_len}, transcription={max_transcription_len}")
|
||||
else:
|
||||
# Use conservative fixed shapes for TPU compatibility
|
||||
# Increased sizes to handle larger data - adjust based on your actual dataset
|
||||
max_time_steps = 8192 # Increased from 4096 - adjust based on your data
|
||||
max_phone_seq_len = 512 # Increased from 256 - adjust based on your data
|
||||
max_transcription_len = 1024 # Increased from 512 - adjust based on your data
|
||||
n_features = 512 # Number of neural features
|
||||
# Use dynamic shapes for maximum compatibility - let TensorFlow handle padding automatically
|
||||
# This avoids the "pad to a smaller size" error by allowing dynamic sizing
|
||||
print(f"🔧 Using dynamic shapes for maximum compatibility")
|
||||
# Calculate number of features based on subset
|
||||
n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
|
||||
|
||||
padded_shapes = {
|
||||
'input_features': tf.TensorShape([None, n_features]),
|
||||
'seq_class_ids': tf.TensorShape([None]),
|
||||
'n_time_steps': tf.TensorShape([]), # Scalar
|
||||
'phone_seq_lens': tf.TensorShape([]), # Scalar
|
||||
'day_indices': tf.TensorShape([]), # Scalar
|
||||
'transcriptions': tf.TensorShape([None]),
|
||||
'block_nums': tf.TensorShape([]), # Scalar
|
||||
'trial_nums': tf.TensorShape([]) # Scalar
|
||||
}
|
||||
|
||||
# Create fixed-shape batches with dynamic padding
|
||||
dataset = dataset.padded_batch(
|
||||
batch_size=dataset_tf.batch_size,
|
||||
padded_shapes=padded_shapes,
|
||||
padding_values=padding_values,
|
||||
drop_remainder=True # Critical for TPU: ensures all batches have same size
|
||||
)
|
||||
|
||||
# Prefetch for optimal performance
|
||||
dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
||||
|
||||
return dataset
|
||||
|
||||
# If using auto-analyzed shapes, create fixed-size padded shapes
|
||||
padded_shapes = {
|
||||
'input_features': [max_time_steps, n_features],
|
||||
'seq_class_ids': [max_phone_seq_len],
|
||||
|
Reference in New Issue
Block a user