Remove test scripts for data loading and TensorFlow implementation
@@ -1,254 +0,0 @@
#!/usr/bin/env python3
"""
Test script for optimized data loading pipeline performance
"""

import os
import time

import psutil
import tensorflow as tf
from omegaconf import OmegaConf

from dataset_tf import BrainToTextDatasetTF, train_test_split_indices, create_input_fn


def get_memory_usage():
    """Get the current memory usage of this process in MB"""
    process = psutil.Process()
    memory_info = process.memory_info()
    return memory_info.rss / 1024 / 1024  # MB
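

# Note: psutil reports whole-process RSS, so the memory deltas printed below
# include TensorFlow allocator overhead as well as the cached trial data.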
def test_data_loading_performance():
    """Compare data loading performance with and without caching"""

    # Load configuration
    config_path = "../rnn_args.yaml"
    if not os.path.exists(config_path):
        print("❌ Configuration file not found. Creating minimal test config...")
        # Create minimal test config
        args = {
            'dataset': {
                'dataset_dir': '../data/hdf5_data_final',
                'sessions': ['t15.2022.03.14', 't15.2022.03.16'],
                'batch_size': 32,
                'days_per_batch': 1,
                'seed': 42,
                'data_transforms': {
                    'smooth_data': False,
                    'white_noise_std': 0.0,
                    'constant_offset_std': 0.0,
                    'random_walk_std': 0.0,
                    'static_gain_std': 0.0,
                    'random_cut': 0
                }
            },
            'num_training_batches': 10  # only test 10 batches
        }
    else:
        args = OmegaConf.load(config_path)
        args = OmegaConf.to_container(args, resolve=True)
        # Limit the number of test batches
        args['num_training_batches'] = 10

    print("🔍 Starting data loading performance test...")
    print(f"📊 Test configuration: {args['num_training_batches']} batches, batch_size={args['dataset']['batch_size']}")

    # Get the file paths
    train_file_paths = [
        os.path.join(args["dataset"]["dataset_dir"], s, 'data_train.hdf5')
        for s in args['dataset']['sessions']
    ]

    print(f"📁 Testing with files: {train_file_paths}")

    # Check whether the files exist
    missing_files = [f for f in train_file_paths if not os.path.exists(f)]
    if missing_files:
        print(f"❌ Missing files: {missing_files}")
        print("⚠️ Creating dummy test data...")
        return test_with_dummy_data(args)

    # Split the data
    print("🔄 Splitting data...")
    train_trials, _ = train_test_split_indices(
        file_paths=train_file_paths,
        test_percentage=0,
        seed=args['dataset']['seed']
    )

    print(f"📈 Found {sum(len(trials['trials']) for trials in train_trials.values())} training trials")

    # Test 1: without caching
    print("\n" + "="*60)
    print("🐌 TEST 1: Standard data loading (no cache)")
    print("="*60)

    initial_memory = get_memory_usage()
    start_time = time.time()

    dataset_no_cache = BrainToTextDatasetTF(
        trial_indices=train_trials,
        n_batches=args['num_training_batches'],
        split='train',
        batch_size=args['dataset']['batch_size'],
        days_per_batch=args['dataset']['days_per_batch'],
        random_seed=args['dataset']['seed'],
        cache_data=False,        # disable caching
        preload_all_data=False   # disable preloading
    )

    tf_dataset_no_cache = create_input_fn(
        dataset_no_cache,
        args['dataset']['data_transforms'],
        training=True
    )

    # Time the loading of the first 3 batches
    batch_times = []
    for i, batch in enumerate(tf_dataset_no_cache.take(3)):
        batch_start = time.time()
        # Trigger the actual data load
        _ = batch['input_features'].numpy()
        batch_time = time.time() - batch_start
        batch_times.append(batch_time)
        print(f"  Batch {i}: {batch_time:.3f}s")

    no_cache_time = time.time() - start_time
    no_cache_memory = get_memory_usage() - initial_memory

    print(f"💾 Memory usage: +{no_cache_memory:.1f} MB")
    print(f"⏱️ Total time: {no_cache_time:.3f}s")
    print(f"📊 Avg batch time: {sum(batch_times)/len(batch_times):.3f}s")

    # Test 2: with preloaded cache
    print("\n" + "="*60)
    print("🚀 TEST 2: Optimized data loading (fully preloaded cache)")
    print("="*60)

    initial_memory = get_memory_usage()
    start_time = time.time()

    dataset_with_cache = BrainToTextDatasetTF(
        trial_indices=train_trials,
        n_batches=args['num_training_batches'],
        split='train',
        batch_size=args['dataset']['batch_size'],
        days_per_batch=args['dataset']['days_per_batch'],
        random_seed=args['dataset']['seed'],
        cache_data=True,        # enable caching
        preload_all_data=True   # enable preloading
    )

    preload_time = time.time() - start_time
    preload_memory = get_memory_usage() - initial_memory

    print(f"📝 Preloading completed in {preload_time:.3f}s")
    print(f"💾 Preloading memory: +{preload_memory:.1f} MB")

    tf_dataset_with_cache = create_input_fn(
        dataset_with_cache,
        args['dataset']['data_transforms'],
        training=True
    )

    # Time the loading of the first 3 batches
    batch_start_time = time.time()
    batch_times_cached = []
    for i, batch in enumerate(tf_dataset_with_cache.take(3)):
        batch_start = time.time()
        # Trigger the actual data load
        _ = batch['input_features'].numpy()
        batch_time = time.time() - batch_start
        batch_times_cached.append(batch_time)
        print(f"  Batch {i}: {batch_time:.3f}s")

    cached_batch_time = time.time() - batch_start_time
    cached_memory = get_memory_usage() - initial_memory

    print(f"💾 Total memory usage: +{cached_memory:.1f} MB")
    print(f"⏱️ Batch loading time: {cached_batch_time:.3f}s")
    print(f"📊 Avg batch time: {sum(batch_times_cached)/len(batch_times_cached):.3f}s")

    # Performance comparison
    print("\n" + "="*60)
    print("📈 PERFORMANCE COMPARISON")
    print("="*60)

    speedup = (sum(batch_times) / len(batch_times)) / (sum(batch_times_cached) / len(batch_times_cached))
    memory_cost = cached_memory - no_cache_memory

    print(f"🚀 Speed improvement: {speedup:.1f}x faster")
    print(f"💾 Memory cost: +{memory_cost:.1f} MB for caching")
    print(f"⚡ First batch time: {batch_times[0]:.3f}s → {batch_times_cached[0]:.3f}s")

    if speedup > 2:
        print("✅ Excellent! Caching significantly improved data loading speed")
    elif speedup > 1.5:
        print("✅ Good! Caching effectively improved data loading speed")
    else:
        print("⚠️ Warning: caching showed little benefit; the dataset may be too small")

    return True


def test_with_dummy_data(args):
    """Run the test with simulated data"""
    print("🔧 Creating dummy data for testing...")

    # Create simulated trial indices
    dummy_trials = {
        0: {
            'trials': list(range(100)),  # 100 simulated trials
            'session_path': 'dummy_path'
        }
    }

    print("📊 Testing with dummy data (100 trials)...")

    # Compare initialization time with vs. without caching
    print("\n🐌 Testing without cache...")
    start_time = time.time()
    dataset_no_cache = BrainToTextDatasetTF(
        trial_indices=dummy_trials,
        n_batches=5,
        split='train',
        batch_size=32,
        days_per_batch=1,
        random_seed=42,
        cache_data=False,
        preload_all_data=False
    )
    no_cache_time = time.time() - start_time
    print(f"  Initialization time: {no_cache_time:.3f}s")

    print("\n🚀 Testing with cache...")
    start_time = time.time()
    dataset_with_cache = BrainToTextDatasetTF(
        trial_indices=dummy_trials,
        n_batches=5,
        split='train',
        batch_size=32,
        days_per_batch=1,
        random_seed=42,
        cache_data=True,
        preload_all_data=True
    )
    cache_time = time.time() - start_time
    print(f"  Initialization time: {cache_time:.3f}s")

    print("\n✅ The caching mechanism has been integrated into the data loading pipeline")
    print("📝 Real performance needs to be measured with actual HDF5 data")

    return True


if __name__ == "__main__":
    print("🧪 Data Loading Performance Test")
    print("="*60)

    try:
        success = test_data_loading_performance()
        if success:
            print("\n🎉 Data loading optimization test completed successfully!")
            print("💡 You can now run train_model_tf.py and enjoy the faster data loading")
    except Exception as e:
        print(f"\n❌ Test failed with error: {e}")
        import traceback
        traceback.print_exc()
@@ -1,564 +0,0 @@
#!/usr/bin/env python3
"""
Test Script for TensorFlow Brain-to-Text Implementation

Validates model architecture, data pipeline, and training functionality.

Usage:
    python test_tensorflow_implementation.py [--full_test]

This script runs comprehensive tests to ensure the TensorFlow implementation
is working correctly before starting full training runs.
"""

import argparse
import os
import shutil
import sys
import tempfile

import numpy as np
import tensorflow as tf
from omegaconf import OmegaConf

# Add current directory to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from rnn_model_tf import (
    TripleGRUDecoder,
    NoiseModel,
    CleanSpeechModel,
    NoisySpeechModel,
    CTCLoss,
    create_tpu_strategy,
    configure_mixed_precision
)
from dataset_tf import BrainToTextDatasetTF, DataAugmentationTF, train_test_split_indices
from trainer_tf import BrainToTextDecoderTrainerTF
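
# rnn_model_tf, dataset_tf and trainer_tf are the local modules under test,
# so this script must be run from the directory that contains them.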


class TensorFlowImplementationTester:
    """Comprehensive tester for TensorFlow brain-to-text implementation"""

    def __init__(self, use_tpu: bool = False, verbose: bool = True):
        """Initialize tester"""
        self.use_tpu = use_tpu
        self.verbose = verbose
        self.passed_tests = 0
        self.total_tests = 0

        # Create test configuration
        self.config = self._create_test_config()

        # Initialize strategy
        if use_tpu:
            self.strategy = create_tpu_strategy()
            if self.verbose:
                print(f"Using TPU strategy with {self.strategy.num_replicas_in_sync} cores")
        else:
            self.strategy = tf.distribute.get_strategy()
            if self.verbose:
                print("Using default strategy (CPU/GPU)")

    def _create_test_config(self):
        """Create minimal test configuration"""
        return {
            'model': {
                'n_input_features': 512,
                'n_units': 64,  # smaller for testing
                'rnn_dropout': 0.1,
                'patch_size': 4,
                'patch_stride': 2,
                'input_network': {
                    'input_layer_dropout': 0.1
                }
            },
            'dataset': {
                'sessions': ['test_session_1', 'test_session_2'],
                'n_classes': 41,
                'batch_size': 4,
                'days_per_batch': 2,
                'seed': 42,
                'data_transforms': {
                    'white_noise_std': 0.1,
                    'constant_offset_std': 0.05,
                    'random_walk_std': 0.0,
                    'static_gain_std': 0.0,
                    'random_cut': 2,
                    'smooth_data': True,
                    'smooth_kernel_std': 1.0,
                    'smooth_kernel_size': 50
                }
            },
            'num_training_batches': 10,
            'lr_max': 0.001,
            'lr_min': 0.0001,
            'lr_decay_steps': 100,
            'lr_warmup_steps': 5,
            'lr_scheduler_type': 'cosine',
            'beta0': 0.9,
            'beta1': 0.999,
            'epsilon': 1e-7,
            'weight_decay': 0.001,
            'seed': 42,
            'grad_norm_clip_value': 1.0,
            'batches_per_train_log': 2,
            'batches_per_val_step': 5,
            'output_dir': tempfile.mkdtemp(),
            'checkpoint_dir': tempfile.mkdtemp(),
            'mode': 'train',
            'use_amp': False,  # disable for testing
            'adversarial': {
                'enabled': True,
                'grl_lambda': 0.5,
                'noisy_loss_weight': 0.2,
                'noise_l2_weight': 0.001,
                'warmup_steps': 2
            }
        }
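
    # Note: 'beta0' and 'beta1' above are presumably the AdamW beta_1/beta_2
    # coefficients consumed by the trainer; 'use_amp' is left False so that
    # only test_mixed_precision changes the global precision policy.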

    def log_test(self, test_name: str, passed: bool, details: str = ""):
        """Log test result"""
        self.total_tests += 1
        if passed:
            self.passed_tests += 1
            status = "PASS"
        else:
            status = "FAIL"

        if self.verbose:
            print(f"[{status}] {test_name}")
            if details:
                print(f"  {details}")

    def test_model_architecture(self):
        """Test individual model components"""
        print("\n=== Testing Model Architecture ===")

        with self.strategy.scope():
            # Test NoiseModel
            try:
                noise_model = NoiseModel(
                    neural_dim=512,
                    n_units=64,
                    n_days=2,
                    rnn_dropout=0.1,
                    input_dropout=0.1,
                    patch_size=4,
                    patch_stride=2
                )

                # Test forward pass
                batch_size = 2
                time_steps = 20
                x = tf.random.normal((batch_size, time_steps, 512))
                day_idx = tf.constant([0, 1], dtype=tf.int32)

                output, states = noise_model(x, day_idx, training=False)

                expected_time_steps = (time_steps - 4) // 2 + 1
                expected_features = 512 * 4
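                # Patch arithmetic (assuming the model unfolds the input into
                # overlapping patches): T' = (T - patch_size) // patch_stride + 1
                # time steps of patch_size * F features each, i.e.
                # (20 - 4) // 2 + 1 = 9 steps of 512 * 4 = 2048 features.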

                assert output.shape == (batch_size, expected_time_steps, expected_features)
                assert len(states) == 2  # two GRU layers

                self.log_test("NoiseModel forward pass", True,
                              f"Output shape: {output.shape}")

            except Exception as e:
                self.log_test("NoiseModel forward pass", False, str(e))

            # Test CleanSpeechModel
            try:
                clean_model = CleanSpeechModel(
                    neural_dim=512,
                    n_units=64,
                    n_days=2,
                    n_classes=41,
                    rnn_dropout=0.1,
                    input_dropout=0.1,
                    patch_size=4,
                    patch_stride=2
                )

                output = clean_model(x, day_idx, training=False)
                assert output.shape == (batch_size, expected_time_steps, 41)

                self.log_test("CleanSpeechModel forward pass", True,
                              f"Output shape: {output.shape}")

            except Exception as e:
                self.log_test("CleanSpeechModel forward pass", False, str(e))

            # Test NoisySpeechModel
            try:
                # First calculate expected dimensions from the NoiseModel test
                expected_time_steps = (20 - 4) // 2 + 1
                expected_features = 512 * 4

                noisy_model = NoisySpeechModel(
                    neural_dim=expected_features,  # takes processed input
                    n_units=64,
                    n_days=2,
                    n_classes=41,
                    rnn_dropout=0.1
                )

                # Use processed input (same shape as the noise model output)
                processed_input = tf.random.normal((batch_size, expected_time_steps, expected_features))
                output = noisy_model(processed_input, training=False)
                assert output.shape == (batch_size, expected_time_steps, 41)

                self.log_test("NoisySpeechModel forward pass", True,
                              f"Output shape: {output.shape}")

            except Exception as e:
                self.log_test("NoisySpeechModel forward pass", False, str(e))

    def test_triple_gru_decoder(self):
        """Test the complete TripleGRUDecoder"""
        print("\n=== Testing TripleGRUDecoder ===")

        with self.strategy.scope():
            try:
                model = TripleGRUDecoder(
                    neural_dim=512,
                    n_units=64,
                    n_days=2,
                    n_classes=41,
                    rnn_dropout=0.1,
                    input_dropout=0.1,
                    patch_size=4,
                    patch_stride=2
                )

                batch_size = 2
                time_steps = 20
                x = tf.random.normal((batch_size, time_steps, 512))
                day_idx = tf.constant([0, 1], dtype=tf.int32)

                # Test inference mode
                clean_logits = model(x, day_idx, mode='inference', training=False)
                expected_time_steps = (time_steps - 4) // 2 + 1
                assert clean_logits.shape == (batch_size, expected_time_steps, 41)

                self.log_test("TripleGRUDecoder inference mode", True,
                              f"Output shape: {clean_logits.shape}")

                # Test full mode (adversarial training)
                clean_logits, noisy_logits, noise_output = model(
                    x, day_idx, mode='full', grl_lambda=0.5, training=True
                )

                assert clean_logits.shape == (batch_size, expected_time_steps, 41)
                assert noisy_logits.shape == (batch_size, expected_time_steps, 41)
                assert noise_output.shape[0] == batch_size

                self.log_test("TripleGRUDecoder full mode", True,
                              f"Clean: {clean_logits.shape}, Noisy: {noisy_logits.shape}")

            except Exception as e:
                self.log_test("TripleGRUDecoder", False, str(e))

    def test_ctc_loss(self):
        """Test CTC loss function"""
        print("\n=== Testing CTC Loss ===")

        try:
            ctc_loss = CTCLoss(blank_index=0, reduction='none')

            batch_size = 2
            time_steps = 10
            n_classes = 41

            # Create test data
            logits = tf.random.normal((batch_size, time_steps, n_classes))
            labels = tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]], dtype=tf.int32)
            input_lengths = tf.constant([time_steps, time_steps], dtype=tf.int32)
            label_lengths = tf.constant([3, 2], dtype=tf.int32)
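            # The label rows are zero-padded; label_lengths ([3, 2]) tells the
            # loss how many entries are real, so the padding (which collides
            # with blank_index=0) is never read as a target symbol.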

            loss_input = {
                'labels': labels,
                'input_lengths': input_lengths,
                'label_lengths': label_lengths
            }

            loss = ctc_loss(loss_input, logits)
            assert loss.shape == (batch_size,)
            assert tf.reduce_all(tf.math.is_finite(loss))

            self.log_test("CTC Loss computation", True,
                          f"Loss shape: {loss.shape}, values finite: {tf.reduce_all(tf.math.is_finite(loss))}")

        except Exception as e:
            self.log_test("CTC Loss computation", False, str(e))

    def test_data_augmentation(self):
        """Test data augmentation functions"""
        print("\n=== Testing Data Augmentation ===")

        try:
            batch_size = 2
            time_steps = 100
            features = 512

            x = tf.random.normal((batch_size, time_steps, features))
            n_time_steps = tf.constant([time_steps, time_steps], dtype=tf.int32)

            # Test Gaussian smoothing
            smoothed = DataAugmentationTF.gauss_smooth(x, smooth_kernel_std=2.0)
            assert smoothed.shape == x.shape

            self.log_test("Gaussian smoothing", True,
                          f"Input: {x.shape}, Output: {smoothed.shape}")

            # Test full transform pipeline
            transform_args = self.config['dataset']['data_transforms']

            transformed_x, transformed_steps = DataAugmentationTF.transform_data(
                x, n_time_steps, transform_args, training=True
            )

            # Check that the shapes are reasonable (random_cut may shorten the
            # time dimension, so only batch and feature dims are fixed)
            assert transformed_x.shape[0] == batch_size
            assert transformed_x.shape[2] == features
            assert len(transformed_steps) == batch_size

            self.log_test("Data augmentation pipeline", True,
                          f"Original: {x.shape}, Transformed: {transformed_x.shape}")

        except Exception as e:
            self.log_test("Data augmentation", False, str(e))

    def test_gradient_reversal(self):
        """Test gradient reversal layer"""
        print("\n=== Testing Gradient Reversal ===")

        try:
            from rnn_model_tf import gradient_reverse

            x = tf.random.normal((2, 10, 64))

            # Test forward pass (should be identity)
            y = gradient_reverse(x, lambd=0.5)
            assert tf.reduce_all(tf.equal(x, y))

            # Test gradient reversal in context
            with tf.GradientTape() as tape:
                tape.watch(x)
                y = gradient_reverse(x, lambd=0.5)
                loss = tf.reduce_sum(y)

            grad = tape.gradient(loss, x)
            expected_grad = -0.5 * tf.ones_like(x)
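            # For loss = sum(y), the upstream gradient dL/dy is all ones; the
            # reversal layer should scale it by -lambd on the backward pass,
            # giving dL/dx = -0.5 everywhere.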

            # Check that gradients are reversed and scaled
            assert tf.reduce_all(tf.abs(grad - expected_grad) < 1e-6)

            self.log_test("Gradient reversal layer", True,
                          "Forward pass identity, gradients properly reversed")

        except Exception as e:
            self.log_test("Gradient reversal layer", False, str(e))

    def test_mixed_precision(self):
        """Test mixed precision configuration"""
        print("\n=== Testing Mixed Precision ===")

        try:
            # Configure mixed precision
            configure_mixed_precision()
            policy = tf.keras.mixed_precision.global_policy()

            assert policy.name == 'mixed_bfloat16'

            # Test a model under mixed precision
            with self.strategy.scope():
                model = TripleGRUDecoder(
                    neural_dim=512, n_units=32, n_days=2, n_classes=41
                )

                x = tf.random.normal((1, 10, 512))
                day_idx = tf.constant([0], dtype=tf.int32)

                logits = model(x, day_idx, mode='inference', training=False)

                # Check that compute dtype is bfloat16 but variables are float32
                assert policy.compute_dtype == 'bfloat16'
                assert policy.variable_dtype == 'float32'
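                # Variables stay float32 for optimizer stability while compute
                # runs in bfloat16; since bfloat16 keeps float32's exponent
                # range, no loss scaling is required (unlike float16).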

                self.log_test("Mixed precision configuration", True,
                              f"Policy: {policy.name}")

        except Exception as e:
            self.log_test("Mixed precision configuration", False, str(e))

    def test_training_step(self):
        """Test a complete training step"""
        print("\n=== Testing Training Step ===")

        try:
            with self.strategy.scope():
                # Create model
                model = TripleGRUDecoder(
                    neural_dim=512,
                    n_units=32,
                    n_days=2,
                    n_classes=41,
                    patch_size=4,
                    patch_stride=2
                )

                # Create optimizer and loss
                optimizer = tf.keras.optimizers.AdamW(learning_rate=0.001)
                ctc_loss = CTCLoss(blank_index=0, reduction='none')

                # Create dummy batch
                batch_size = 2
                time_steps = 20

                batch = {
                    'input_features': tf.random.normal((batch_size, time_steps, 512)),
                    'seq_class_ids': tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]], dtype=tf.int32),
                    'n_time_steps': tf.constant([time_steps, time_steps], dtype=tf.int32),
                    'phone_seq_lens': tf.constant([3, 2], dtype=tf.int32),
                    'day_indices': tf.constant([0, 1], dtype=tf.int32)
                }

                # Training step
                with tf.GradientTape() as tape:
                    # Apply minimal transforms
                    features = batch['input_features']
                    n_time_steps = batch['n_time_steps']

                    # Calculate adjusted lengths
                    adjusted_lens = tf.cast(
                        (tf.cast(n_time_steps, tf.float32) - 4) / 2 + 1, tf.int32
                    )
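                    # CTC input_lengths must match the post-patching time
                    # dimension, so this mirrors the (T - patch_size) /
                    # patch_stride + 1 arithmetic used by the model's patching.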

                    # Forward pass
                    clean_logits = model(features, batch['day_indices'],
                                         mode='inference', training=True)

                    # Loss
                    loss_input = {
                        'labels': batch['seq_class_ids'],
                        'input_lengths': adjusted_lens,
                        'label_lengths': batch['phone_seq_lens']
                    }
                    loss = ctc_loss(loss_input, clean_logits)
                    loss = tf.reduce_mean(loss)

                # Gradients
                gradients = tape.gradient(loss, model.trainable_variables)

                # Check gradients exist and are finite
                grad_finite = all(tf.reduce_all(tf.math.is_finite(g)) for g in gradients if g is not None)

                # Apply gradients
                optimizer.apply_gradients(zip(gradients, model.trainable_variables))

                self.log_test("Training step", grad_finite and bool(tf.math.is_finite(loss)),
                              f"Loss: {float(loss):.4f}, Gradients finite: {grad_finite}")

        except Exception as e:
            self.log_test("Training step", False, str(e))

    def test_full_training_loop(self):
        """Test a minimal training loop"""
        print("\n=== Testing Full Training Loop ===")

        if not getattr(self, '_full_test', False):
            self.log_test("Full training loop", True, "Skipped (use --full_test to enable)")
            return

        try:
            # Create temporary directories
            temp_output = tempfile.mkdtemp()
            temp_checkpoint = tempfile.mkdtemp()

            # Minimal config for a quick test
            config = self.config.copy()
            config['output_dir'] = temp_output
            config['checkpoint_dir'] = temp_checkpoint
            config['num_training_batches'] = 5
            config['batches_per_val_step'] = 3

            # Would need actual data files for this test;
            # for now, just test trainer initialization
            # trainer = BrainToTextDecoderTrainerTF(config)

            self.log_test("Full training loop", True, "Trainer initialization successful")

            # Cleanup
            shutil.rmtree(temp_output, ignore_errors=True)
            shutil.rmtree(temp_checkpoint, ignore_errors=True)

        except Exception as e:
            self.log_test("Full training loop", False, str(e))

    def run_all_tests(self, full_test: bool = False):
        """Run all tests"""
        self._full_test = full_test

        print("TensorFlow Brain-to-Text Implementation Test Suite")
        print("=" * 60)

        if self.use_tpu:
            print("Running tests on TPU")
        else:
            print("Running tests on CPU/GPU")

        # Run tests
        self.test_model_architecture()
        self.test_triple_gru_decoder()
        self.test_ctc_loss()
        self.test_data_augmentation()
        self.test_gradient_reversal()
        self.test_mixed_precision()
        self.test_training_step()
        self.test_full_training_loop()

        # Summary
        print("\n" + "=" * 60)
        print(f"TEST SUMMARY: {self.passed_tests}/{self.total_tests} tests passed")

        if self.passed_tests == self.total_tests:
            print("🎉 All tests passed! TensorFlow implementation is ready.")
            return True
        else:
            print("❌ Some tests failed. Please review the implementation.")
            return False


def main():
    """Main test function"""
    parser = argparse.ArgumentParser(description='Test TensorFlow Brain-to-Text Implementation')
    parser.add_argument('--use_tpu', action='store_true', help='Test on TPU if available')
    parser.add_argument('--full_test', action='store_true', help='Run full training loop test')
    parser.add_argument('--quiet', action='store_true', help='Reduce output verbosity')

    args = parser.parse_args()

    # Set TensorFlow logging level
    if args.quiet:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        tf.get_logger().setLevel('ERROR')
    else:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    # Run tests
    tester = TensorFlowImplementationTester(
        use_tpu=args.use_tpu,
        verbose=not args.quiet
    )

    success = tester.run_all_tests(full_test=args.full_test)

    # Clean up the temporary directories created by the test config
    shutil.rmtree(tester.config['output_dir'], ignore_errors=True)
    shutil.rmtree(tester.config['checkpoint_dir'], ignore_errors=True)

    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()