f
This commit is contained in:
@@ -10,7 +10,8 @@
|
|||||||
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" examine_dataset_structure.py)",
|
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" examine_dataset_structure.py)",
|
||||||
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" debug_alignment_step_by_step.py)",
|
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" debug_alignment_step_by_step.py)",
|
||||||
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" CRNN_pretag.py --max_sessions 2)",
|
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" CRNN_pretag.py --max_sessions 2)",
|
||||||
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" CRNN_pretag.py --max_sessions 3)"
|
"Bash(\"D:/SoftWare/Anaconda3/envs/b2txt25/python.exe\" CRNN_pretag.py --max_sessions 3)",
|
||||||
|
"Bash(del \"f:\\BRAIN-TO-TEXT\\nejm-brain-to-text.worktrees\\dev2\\test_padding_fix.py\")"
|
||||||
],
|
],
|
||||||
"deny": [],
|
"deny": [],
|
||||||
"ask": []
|
"ask": []
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@@ -15,3 +15,4 @@ model_training_lstm/trained_models_history
|
|||||||
*.pkl
|
*.pkl
|
||||||
|
|
||||||
.idea
|
.idea
|
||||||
|
.claude
|
@@ -773,16 +773,35 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
|
|||||||
'max_time_steps': 0,
|
'max_time_steps': 0,
|
||||||
'max_phone_seq_len': 0,
|
'max_phone_seq_len': 0,
|
||||||
'max_transcription_len': 0,
|
'max_transcription_len': 0,
|
||||||
'n_features': 512 # Fixed for neural features
|
'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
|
||||||
}
|
}
|
||||||
|
|
||||||
# Sample a subset of data to determine max sizes
|
# Sample a subset of data to determine max sizes
|
||||||
count = 0
|
count = 0
|
||||||
for batch_idx in list(dataset_tf.batch_indices.keys())[:min(10, len(dataset_tf.batch_indices))]:
|
batch_keys = list(dataset_tf.batch_indices.keys())
|
||||||
|
|
||||||
|
# If sample_size is -1, analyze all data
|
||||||
|
if sample_size == -1:
|
||||||
|
batches_to_check = batch_keys
|
||||||
|
max_trials_per_batch = float('inf')
|
||||||
|
else:
|
||||||
|
# Sample a reasonable number of batches
|
||||||
|
batches_to_check = batch_keys[:min(max(10, sample_size // 10), len(batch_keys))]
|
||||||
|
max_trials_per_batch = max(1, sample_size // len(batches_to_check))
|
||||||
|
|
||||||
|
for batch_idx in batches_to_check:
|
||||||
|
if count >= sample_size and sample_size > 0:
|
||||||
|
break
|
||||||
|
|
||||||
batch_index = dataset_tf.batch_indices[batch_idx]
|
batch_index = dataset_tf.batch_indices[batch_idx]
|
||||||
|
|
||||||
for day in batch_index.keys():
|
for day in batch_index.keys():
|
||||||
for trial in batch_index[day][:min(10, len(batch_index[day]))]:
|
if count >= sample_size and sample_size > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
trials_to_check = batch_index[day][:min(int(max_trials_per_batch), len(batch_index[day]))]
|
||||||
|
|
||||||
|
for trial in trials_to_check:
|
||||||
if count >= sample_size and sample_size > 0:
|
if count >= sample_size and sample_size > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -803,24 +822,28 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
|
|||||||
|
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
|
# Show progress for large analyses
|
||||||
|
if count % 50 == 0:
|
||||||
|
print(f" Analyzed {count} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
|
logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if count >= sample_size and sample_size > 0:
|
|
||||||
break
|
|
||||||
if count >= sample_size and sample_size > 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Add safety margins (20% buffer) to handle edge cases
|
# Add safety margins (20% buffer) to handle edge cases
|
||||||
|
original_time_steps = max_shapes['max_time_steps']
|
||||||
|
original_phone_seq_len = max_shapes['max_phone_seq_len']
|
||||||
|
original_transcription_len = max_shapes['max_transcription_len']
|
||||||
|
|
||||||
max_shapes['max_time_steps'] = int(max_shapes['max_time_steps'] * 1.2)
|
max_shapes['max_time_steps'] = int(max_shapes['max_time_steps'] * 1.2)
|
||||||
max_shapes['max_phone_seq_len'] = int(max_shapes['max_phone_seq_len'] * 1.2)
|
max_shapes['max_phone_seq_len'] = int(max_shapes['max_phone_seq_len'] * 1.2)
|
||||||
max_shapes['max_transcription_len'] = int(max_shapes['max_transcription_len'] * 1.2)
|
max_shapes['max_transcription_len'] = int(max_shapes['max_transcription_len'] * 1.2)
|
||||||
|
|
||||||
print(f"📊 Dataset analysis complete (analyzed {count} samples):")
|
print(f"📊 Dataset analysis complete (analyzed {count} samples):")
|
||||||
print(f" Max time steps: {max_shapes['max_time_steps']}")
|
print(f" Original max time steps: {original_time_steps} → Padded: {max_shapes['max_time_steps']}")
|
||||||
print(f" Max phone sequence length: {max_shapes['max_phone_seq_len']}")
|
print(f" Original max phone sequence length: {original_phone_seq_len} → Padded: {max_shapes['max_phone_seq_len']}")
|
||||||
print(f" Max transcription length: {max_shapes['max_transcription_len']}")
|
print(f" Original max transcription length: {original_transcription_len} → Padded: {max_shapes['max_transcription_len']}")
|
||||||
|
print(f" Number of features: {max_shapes['n_features']}")
|
||||||
|
|
||||||
return max_shapes
|
return max_shapes
|
||||||
|
|
||||||
@@ -899,14 +922,39 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
|||||||
max_phone_seq_len = shape_info['max_phone_seq_len']
|
max_phone_seq_len = shape_info['max_phone_seq_len']
|
||||||
max_transcription_len = shape_info['max_transcription_len']
|
max_transcription_len = shape_info['max_transcription_len']
|
||||||
n_features = shape_info['n_features']
|
n_features = shape_info['n_features']
|
||||||
|
print(f"🔧 Using auto-analyzed shapes: time_steps={max_time_steps}, phone_seq={max_phone_seq_len}, transcription={max_transcription_len}")
|
||||||
else:
|
else:
|
||||||
# Use conservative fixed shapes for TPU compatibility
|
# Use dynamic shapes for maximum compatibility - let TensorFlow handle padding automatically
|
||||||
# Increased sizes to handle larger data - adjust based on your actual dataset
|
# This avoids the "pad to a smaller size" error by allowing dynamic sizing
|
||||||
max_time_steps = 8192 # Increased from 4096 - adjust based on your data
|
print(f"🔧 Using dynamic shapes for maximum compatibility")
|
||||||
max_phone_seq_len = 512 # Increased from 256 - adjust based on your data
|
# Calculate number of features based on subset
|
||||||
max_transcription_len = 1024 # Increased from 512 - adjust based on your data
|
n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
|
||||||
n_features = 512 # Number of neural features
|
|
||||||
|
|
||||||
|
padded_shapes = {
|
||||||
|
'input_features': tf.TensorShape([None, n_features]),
|
||||||
|
'seq_class_ids': tf.TensorShape([None]),
|
||||||
|
'n_time_steps': tf.TensorShape([]), # Scalar
|
||||||
|
'phone_seq_lens': tf.TensorShape([]), # Scalar
|
||||||
|
'day_indices': tf.TensorShape([]), # Scalar
|
||||||
|
'transcriptions': tf.TensorShape([None]),
|
||||||
|
'block_nums': tf.TensorShape([]), # Scalar
|
||||||
|
'trial_nums': tf.TensorShape([]) # Scalar
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create fixed-shape batches with dynamic padding
|
||||||
|
dataset = dataset.padded_batch(
|
||||||
|
batch_size=dataset_tf.batch_size,
|
||||||
|
padded_shapes=padded_shapes,
|
||||||
|
padding_values=padding_values,
|
||||||
|
drop_remainder=True # Critical for TPU: ensures all batches have same size
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prefetch for optimal performance
|
||||||
|
dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
# If using auto-analyzed shapes, create fixed-size padded shapes
|
||||||
padded_shapes = {
|
padded_shapes = {
|
||||||
'input_features': [max_time_steps, n_features],
|
'input_features': [max_time_steps, n_features],
|
||||||
'seq_class_ids': [max_phone_seq_len],
|
'seq_class_ids': [max_phone_seq_len],
|
||||||
|
Reference in New Issue
Block a user