From 7c272b7c5b53750e9aab463238077f5941f8f017 Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Mon, 20 Oct 2025 01:37:22 +0800 Subject: [PATCH] Remove test scripts for data loading and TensorFlow implementation --- model_training_nnn_tpu/test_data_loading.py | 254 -------- .../test_tensorflow_implementation.py | 564 ------------------ 2 files changed, 818 deletions(-) delete mode 100644 model_training_nnn_tpu/test_data_loading.py delete mode 100644 model_training_nnn_tpu/test_tensorflow_implementation.py diff --git a/model_training_nnn_tpu/test_data_loading.py b/model_training_nnn_tpu/test_data_loading.py deleted file mode 100644 index 360db43..0000000 --- a/model_training_nnn_tpu/test_data_loading.py +++ /dev/null @@ -1,254 +0,0 @@ -#!/usr/bin/env python3 -""" -测试优化后的数据加载管道性能 -Test script for optimized data loading pipeline performance -""" - -import os -import time -import psutil -import tensorflow as tf -from omegaconf import OmegaConf -from dataset_tf import BrainToTextDatasetTF, train_test_split_indices, create_input_fn - -def get_memory_usage(): - """获取当前内存使用情况""" - process = psutil.Process() - memory_info = process.memory_info() - return memory_info.rss / 1024 / 1024 # MB - -def test_data_loading_performance(): - """测试数据加载性能对比""" - - # 加载配置 - config_path = "../rnn_args.yaml" - if not os.path.exists(config_path): - print("❌ Configuration file not found. Creating minimal test config...") - # 创建最小测试配置 - args = { - 'dataset': { - 'dataset_dir': '../data/hdf5_data_final', - 'sessions': ['t15.2022.03.14', 't15.2022.03.16'], - 'batch_size': 32, - 'days_per_batch': 1, - 'seed': 42, - 'data_transforms': { - 'smooth_data': False, - 'white_noise_std': 0.0, - 'constant_offset_std': 0.0, - 'random_walk_std': 0.0, - 'static_gain_std': 0.0, - 'random_cut': 0 - } - }, - 'num_training_batches': 10 # 只测试10个batch - } - else: - args = OmegaConf.load(config_path) - args = OmegaConf.to_container(args, resolve=True) - # 限制测试batch数量 - args['num_training_batches'] = 10 - - print("🔍 Starting data loading performance test...") - print(f"📊 Test configuration: {args['num_training_batches']} batches, batch_size={args['dataset']['batch_size']}") - - # 获取文件路径 - train_file_paths = [ - os.path.join(args["dataset"]["dataset_dir"], s, 'data_train.hdf5') - for s in args['dataset']['sessions'] - ] - - print(f"📁 Testing with files: {train_file_paths}") - - # 检查文件是否存在 - missing_files = [f for f in train_file_paths if not os.path.exists(f)] - if missing_files: - print(f"❌ Missing files: {missing_files}") - print("⚠️ Creating dummy test data...") - return test_with_dummy_data(args) - - # 分割数据 - print("🔄 Splitting data...") - train_trials, _ = train_test_split_indices( - file_paths=train_file_paths, - test_percentage=0, - seed=args['dataset']['seed'] - ) - - print(f"📈 Found {sum(len(trials['trials']) for trials in train_trials.values())} training trials") - - # 测试1: 不使用缓存 - print("\n" + "="*60) - print("🐌 TEST 1: 标准数据加载 (无缓存)") - print("="*60) - - initial_memory = get_memory_usage() - start_time = time.time() - - dataset_no_cache = BrainToTextDatasetTF( - trial_indices=train_trials, - n_batches=args['num_training_batches'], - split='train', - batch_size=args['dataset']['batch_size'], - days_per_batch=args['dataset']['days_per_batch'], - random_seed=args['dataset']['seed'], - cache_data=False, # 禁用缓存 - preload_all_data=False # 禁用预加载 - ) - - tf_dataset_no_cache = create_input_fn( - dataset_no_cache, - args['dataset']['data_transforms'], - training=True - ) - - # 测试前3个batch的加载时间 - batch_times = [] - for i, batch in enumerate(tf_dataset_no_cache.take(3)): - batch_start = time.time() - # 触发实际数据加载 - _ = batch['input_features'].numpy() - batch_time = time.time() - batch_start - batch_times.append(batch_time) - print(f" Batch {i}: {batch_time:.3f}s") - - no_cache_time = time.time() - start_time - no_cache_memory = get_memory_usage() - initial_memory - - print(f"💾 Memory usage: +{no_cache_memory:.1f} MB") - print(f"⏱️ Total time: {no_cache_time:.3f}s") - print(f"📊 Avg batch time: {sum(batch_times)/len(batch_times):.3f}s") - - # 测试2: 使用预加载缓存 - print("\n" + "="*60) - print("🚀 TEST 2: 优化数据加载 (全缓存预加载)") - print("="*60) - - initial_memory = get_memory_usage() - start_time = time.time() - - dataset_with_cache = BrainToTextDatasetTF( - trial_indices=train_trials, - n_batches=args['num_training_batches'], - split='train', - batch_size=args['dataset']['batch_size'], - days_per_batch=args['dataset']['days_per_batch'], - random_seed=args['dataset']['seed'], - cache_data=True, # 启用缓存 - preload_all_data=True # 启用预加载 - ) - - preload_time = time.time() - start_time - preload_memory = get_memory_usage() - initial_memory - - print(f"📝 Preloading completed in {preload_time:.3f}s") - print(f"💾 Preloading memory: +{preload_memory:.1f} MB") - - tf_dataset_with_cache = create_input_fn( - dataset_with_cache, - args['dataset']['data_transforms'], - training=True - ) - - # 测试前3个batch的加载时间 - batch_start_time = time.time() - batch_times_cached = [] - for i, batch in enumerate(tf_dataset_with_cache.take(3)): - batch_start = time.time() - # 触发实际数据加载 - _ = batch['input_features'].numpy() - batch_time = time.time() - batch_start - batch_times_cached.append(batch_time) - print(f" Batch {i}: {batch_time:.3f}s") - - cached_batch_time = time.time() - batch_start_time - cached_memory = get_memory_usage() - initial_memory - - print(f"💾 Total memory usage: +{cached_memory:.1f} MB") - print(f"⏱️ Batch loading time: {cached_batch_time:.3f}s") - print(f"📊 Avg batch time: {sum(batch_times_cached)/len(batch_times_cached):.3f}s") - - # 性能对比 - print("\n" + "="*60) - print("📈 PERFORMANCE COMPARISON") - print("="*60) - - speedup = (sum(batch_times)/len(batch_times)) / (sum(batch_times_cached)/len(batch_times_cached)) - memory_cost = cached_memory - no_cache_memory - - print(f"🚀 Speed improvement: {speedup:.1f}x faster") - print(f"💾 Memory cost: +{memory_cost:.1f} MB for caching") - print(f"⚡ First batch time: {batch_times[0]:.3f}s → {batch_times_cached[0]:.3f}s") - - if speedup > 2: - print("✅ Excellent! 缓存优化显著提升了数据加载速度") - elif speedup > 1.5: - print("✅ Good! 缓存优化有效提升了数据加载速度") - else: - print("⚠️ Warning: 缓存优化效果不明显,可能数据量太小") - - return True - -def test_with_dummy_data(args): - """使用模拟数据进行测试""" - print("🔧 Creating dummy data for testing...") - - # 创建模拟试验索引 - dummy_trials = { - 0: { - 'trials': list(range(100)), # 100个模拟试验 - 'session_path': 'dummy_path' - } - } - - print("📊 Testing with dummy data (100 trials)...") - - # 测试缓存vs非缓存的初始化时间差异 - print("\n🐌 Testing without cache...") - start_time = time.time() - dataset_no_cache = BrainToTextDatasetTF( - trial_indices=dummy_trials, - n_batches=5, - split='train', - batch_size=32, - days_per_batch=1, - random_seed=42, - cache_data=False, - preload_all_data=False - ) - no_cache_time = time.time() - start_time - print(f" Initialization time: {no_cache_time:.3f}s") - - print("\n🚀 Testing with cache...") - start_time = time.time() - dataset_with_cache = BrainToTextDatasetTF( - trial_indices=dummy_trials, - n_batches=5, - split='train', - batch_size=32, - days_per_batch=1, - random_seed=42, - cache_data=True, - preload_all_data=True - ) - cache_time = time.time() - start_time - print(f" Initialization time: {cache_time:.3f}s") - - print(f"\n✅ 缓存机制已成功集成到数据加载管道中") - print(f"📝 实际性能需要用真实的HDF5数据进行测试") - - return True - -if __name__ == "__main__": - print("🧪 Data Loading Performance Test") - print("="*60) - - try: - success = test_data_loading_performance() - if success: - print("\n🎉 Data loading optimization test completed successfully!") - print("💡 你现在可以运行 train_model_tf.py 来享受快速的数据加载了") - except Exception as e: - print(f"\n❌ Test failed with error: {e}") - import traceback - traceback.print_exc() \ No newline at end of file diff --git a/model_training_nnn_tpu/test_tensorflow_implementation.py b/model_training_nnn_tpu/test_tensorflow_implementation.py deleted file mode 100644 index dc17f35..0000000 --- a/model_training_nnn_tpu/test_tensorflow_implementation.py +++ /dev/null @@ -1,564 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Script for TensorFlow Brain-to-Text Implementation -Validates model architecture, data pipeline, and training functionality - -Usage: - python test_tensorflow_implementation.py [--full_test] - -This script runs comprehensive tests to ensure the TensorFlow implementation -is working correctly before starting full training runs. -""" - -import os -import sys -import argparse -import numpy as np -import tensorflow as tf -from omegaconf import OmegaConf -import tempfile -import shutil - -# Add current directory to path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from rnn_model_tf import ( - TripleGRUDecoder, - NoiseModel, - CleanSpeechModel, - NoisySpeechModel, - CTCLoss, - create_tpu_strategy, - configure_mixed_precision -) -from dataset_tf import BrainToTextDatasetTF, DataAugmentationTF, train_test_split_indices -from trainer_tf import BrainToTextDecoderTrainerTF - - -class TensorFlowImplementationTester: - """Comprehensive tester for TensorFlow brain-to-text implementation""" - - def __init__(self, use_tpu: bool = False, verbose: bool = True): - """Initialize tester""" - self.use_tpu = use_tpu - self.verbose = verbose - self.passed_tests = 0 - self.total_tests = 0 - - # Create test configuration - self.config = self._create_test_config() - - # Initialize strategy - if use_tpu: - self.strategy = create_tpu_strategy() - if self.verbose: - print(f"Using TPU strategy with {self.strategy.num_replicas_in_sync} cores") - else: - self.strategy = tf.distribute.get_strategy() - if self.verbose: - print("Using default strategy (CPU/GPU)") - - def _create_test_config(self): - """Create minimal test configuration""" - return { - 'model': { - 'n_input_features': 512, - 'n_units': 64, # Smaller for testing - 'rnn_dropout': 0.1, - 'patch_size': 4, - 'patch_stride': 2, - 'input_network': { - 'input_layer_dropout': 0.1 - } - }, - 'dataset': { - 'sessions': ['test_session_1', 'test_session_2'], - 'n_classes': 41, - 'batch_size': 4, - 'days_per_batch': 2, - 'seed': 42, - 'data_transforms': { - 'white_noise_std': 0.1, - 'constant_offset_std': 0.05, - 'random_walk_std': 0.0, - 'static_gain_std': 0.0, - 'random_cut': 2, - 'smooth_data': True, - 'smooth_kernel_std': 1.0, - 'smooth_kernel_size': 50 - } - }, - 'num_training_batches': 10, - 'lr_max': 0.001, - 'lr_min': 0.0001, - 'lr_decay_steps': 100, - 'lr_warmup_steps': 5, - 'lr_scheduler_type': 'cosine', - 'beta0': 0.9, - 'beta1': 0.999, - 'epsilon': 1e-7, - 'weight_decay': 0.001, - 'seed': 42, - 'grad_norm_clip_value': 1.0, - 'batches_per_train_log': 2, - 'batches_per_val_step': 5, - 'output_dir': tempfile.mkdtemp(), - 'checkpoint_dir': tempfile.mkdtemp(), - 'mode': 'train', - 'use_amp': False, # Disable for testing - 'adversarial': { - 'enabled': True, - 'grl_lambda': 0.5, - 'noisy_loss_weight': 0.2, - 'noise_l2_weight': 0.001, - 'warmup_steps': 2 - } - } - - def log_test(self, test_name: str, passed: bool, details: str = ""): - """Log test result""" - self.total_tests += 1 - if passed: - self.passed_tests += 1 - status = "PASS" - else: - status = "FAIL" - - if self.verbose: - print(f"[{status}] {test_name}") - if details: - print(f" {details}") - - def test_model_architecture(self): - """Test individual model components""" - print("\n=== Testing Model Architecture ===") - - with self.strategy.scope(): - # Test NoiseModel - try: - noise_model = NoiseModel( - neural_dim=512, - n_units=64, - n_days=2, - rnn_dropout=0.1, - input_dropout=0.1, - patch_size=4, - patch_stride=2 - ) - - # Test forward pass - batch_size = 2 - time_steps = 20 - x = tf.random.normal((batch_size, time_steps, 512)) - day_idx = tf.constant([0, 1], dtype=tf.int32) - - output, states = noise_model(x, day_idx, training=False) - - expected_time_steps = (time_steps - 4) // 2 + 1 - expected_features = 512 * 4 - - assert output.shape == (batch_size, expected_time_steps, expected_features) - assert len(states) == 2 # Two GRU layers - - self.log_test("NoiseModel forward pass", True, - f"Output shape: {output.shape}") - - except Exception as e: - self.log_test("NoiseModel forward pass", False, str(e)) - - # Test CleanSpeechModel - try: - clean_model = CleanSpeechModel( - neural_dim=512, - n_units=64, - n_days=2, - n_classes=41, - rnn_dropout=0.1, - input_dropout=0.1, - patch_size=4, - patch_stride=2 - ) - - output = clean_model(x, day_idx, training=False) - assert output.shape == (batch_size, expected_time_steps, 41) - - self.log_test("CleanSpeechModel forward pass", True, - f"Output shape: {output.shape}") - - except Exception as e: - self.log_test("CleanSpeechModel forward pass", False, str(e)) - - # Test NoisySpeechModel - try: - # First calculate expected dimensions from NoiseModel test - expected_time_steps = (20 - 4) // 2 + 1 - expected_features = 512 * 4 - - noisy_model = NoisySpeechModel( - neural_dim=expected_features, # Takes processed input - n_units=64, - n_days=2, - n_classes=41, - rnn_dropout=0.1 - ) - - # Use processed input (same as noise model output) - processed_input = tf.random.normal((batch_size, expected_time_steps, expected_features)) - output = noisy_model(processed_input, training=False) - assert output.shape == (batch_size, expected_time_steps, 41) - - self.log_test("NoisySpeechModel forward pass", True, - f"Output shape: {output.shape}") - - except Exception as e: - self.log_test("NoisySpeechModel forward pass", False, str(e)) - - def test_triple_gru_decoder(self): - """Test the complete TripleGRUDecoder""" - print("\n=== Testing TripleGRUDecoder ===") - - with self.strategy.scope(): - try: - model = TripleGRUDecoder( - neural_dim=512, - n_units=64, - n_days=2, - n_classes=41, - rnn_dropout=0.1, - input_dropout=0.1, - patch_size=4, - patch_stride=2 - ) - - batch_size = 2 - time_steps = 20 - x = tf.random.normal((batch_size, time_steps, 512)) - day_idx = tf.constant([0, 1], dtype=tf.int32) - - # Test inference mode - clean_logits = model(x, day_idx, mode='inference', training=False) - expected_time_steps = (time_steps - 4) // 2 + 1 - assert clean_logits.shape == (batch_size, expected_time_steps, 41) - - self.log_test("TripleGRUDecoder inference mode", True, - f"Output shape: {clean_logits.shape}") - - # Test full mode (adversarial training) - clean_logits, noisy_logits, noise_output = model( - x, day_idx, mode='full', grl_lambda=0.5, training=True - ) - - assert clean_logits.shape == (batch_size, expected_time_steps, 41) - assert noisy_logits.shape == (batch_size, expected_time_steps, 41) - assert noise_output.shape[0] == batch_size - - self.log_test("TripleGRUDecoder full mode", True, - f"Clean: {clean_logits.shape}, Noisy: {noisy_logits.shape}") - - except Exception as e: - self.log_test("TripleGRUDecoder", False, str(e)) - - def test_ctc_loss(self): - """Test CTC loss function""" - print("\n=== Testing CTC Loss ===") - - try: - ctc_loss = CTCLoss(blank_index=0, reduction='none') - - batch_size = 2 - time_steps = 10 - n_classes = 41 - - # Create test data - logits = tf.random.normal((batch_size, time_steps, n_classes)) - labels = tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]], dtype=tf.int32) - input_lengths = tf.constant([time_steps, time_steps], dtype=tf.int32) - label_lengths = tf.constant([3, 2], dtype=tf.int32) - - loss_input = { - 'labels': labels, - 'input_lengths': input_lengths, - 'label_lengths': label_lengths - } - - loss = ctc_loss(loss_input, logits) - assert loss.shape == (batch_size,) - assert tf.reduce_all(tf.math.is_finite(loss)) - - self.log_test("CTC Loss computation", True, - f"Loss shape: {loss.shape}, values finite: {tf.reduce_all(tf.math.is_finite(loss))}") - - except Exception as e: - self.log_test("CTC Loss computation", False, str(e)) - - def test_data_augmentation(self): - """Test data augmentation functions""" - print("\n=== Testing Data Augmentation ===") - - try: - batch_size = 2 - time_steps = 100 - features = 512 - - x = tf.random.normal((batch_size, time_steps, features)) - n_time_steps = tf.constant([time_steps, time_steps], dtype=tf.int32) - - # Test Gaussian smoothing - smoothed = DataAugmentationTF.gauss_smooth(x, smooth_kernel_std=2.0) - assert smoothed.shape == x.shape - - self.log_test("Gaussian smoothing", True, - f"Input: {x.shape}, Output: {smoothed.shape}") - - # Test full transform pipeline - transform_args = self.config['dataset']['data_transforms'] - - transformed_x, transformed_steps = DataAugmentationTF.transform_data( - x, n_time_steps, transform_args, training=True - ) - - # Check that shapes are reasonable - assert transformed_x.shape[0] == batch_size - assert transformed_x.shape[2] == features - assert len(transformed_steps) == batch_size - - self.log_test("Data augmentation pipeline", True, - f"Original: {x.shape}, Transformed: {transformed_x.shape}") - - except Exception as e: - self.log_test("Data augmentation", False, str(e)) - - def test_gradient_reversal(self): - """Test gradient reversal layer""" - print("\n=== Testing Gradient Reversal ===") - - try: - from rnn_model_tf import gradient_reverse - - x = tf.random.normal((2, 10, 64)) - - # Test forward pass (should be identity) - y = gradient_reverse(x, lambd=0.5) - assert tf.reduce_all(tf.equal(x, y)) - - # Test gradient reversal in context - with tf.GradientTape() as tape: - tape.watch(x) - y = gradient_reverse(x, lambd=0.5) - loss = tf.reduce_sum(y) - - grad = tape.gradient(loss, x) - expected_grad = -0.5 * tf.ones_like(x) - - # Check if gradients are reversed and scaled - assert tf.reduce_all(tf.abs(grad - expected_grad) < 1e-6) - - self.log_test("Gradient reversal layer", True, - "Forward pass identity, gradients properly reversed") - - except Exception as e: - self.log_test("Gradient reversal layer", False, str(e)) - - def test_mixed_precision(self): - """Test mixed precision configuration""" - print("\n=== Testing Mixed Precision ===") - - try: - # Configure mixed precision - configure_mixed_precision() - policy = tf.keras.mixed_precision.global_policy() - - assert policy.name == 'mixed_bfloat16' - - # Test model with mixed precision - with self.strategy.scope(): - model = TripleGRUDecoder( - neural_dim=512, n_units=32, n_days=2, n_classes=41 - ) - - x = tf.random.normal((1, 10, 512)) - day_idx = tf.constant([0], dtype=tf.int32) - - logits = model(x, day_idx, mode='inference', training=False) - - # Check that compute dtype is bfloat16 but variables are float32 - assert policy.compute_dtype == 'bfloat16' - assert policy.variable_dtype == 'float32' - - self.log_test("Mixed precision configuration", True, - f"Policy: {policy.name}") - - except Exception as e: - self.log_test("Mixed precision configuration", False, str(e)) - - def test_training_step(self): - """Test a complete training step""" - print("\n=== Testing Training Step ===") - - try: - with self.strategy.scope(): - # Create model - model = TripleGRUDecoder( - neural_dim=512, - n_units=32, - n_days=2, - n_classes=41, - patch_size=4, - patch_stride=2 - ) - - # Create optimizer and loss - optimizer = tf.keras.optimizers.AdamW(learning_rate=0.001) - ctc_loss = CTCLoss(blank_index=0, reduction='none') - - # Create dummy batch - batch_size = 2 - time_steps = 20 - - batch = { - 'input_features': tf.random.normal((batch_size, time_steps, 512)), - 'seq_class_ids': tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]], dtype=tf.int32), - 'n_time_steps': tf.constant([time_steps, time_steps], dtype=tf.int32), - 'phone_seq_lens': tf.constant([3, 2], dtype=tf.int32), - 'day_indices': tf.constant([0, 1], dtype=tf.int32) - } - - # Training step - with tf.GradientTape() as tape: - # Apply minimal transforms - features = batch['input_features'] - n_time_steps = batch['n_time_steps'] - - # Calculate adjusted lengths - adjusted_lens = tf.cast( - (tf.cast(n_time_steps, tf.float32) - 4) / 2 + 1, tf.int32 - ) - - # Forward pass - clean_logits = model(features, batch['day_indices'], - mode='inference', training=True) - - # Loss - loss_input = { - 'labels': batch['seq_class_ids'], - 'input_lengths': adjusted_lens, - 'label_lengths': batch['phone_seq_lens'] - } - loss = ctc_loss(loss_input, clean_logits) - loss = tf.reduce_mean(loss) - - # Gradients - gradients = tape.gradient(loss, model.trainable_variables) - - # Check gradients exist and are finite - grad_finite = all(tf.reduce_all(tf.math.is_finite(g)) for g in gradients if g is not None) - - # Apply gradients - optimizer.apply_gradients(zip(gradients, model.trainable_variables)) - - self.log_test("Training step", grad_finite and tf.math.is_finite(loss), - f"Loss: {float(loss):.4f}, Gradients finite: {grad_finite}") - - except Exception as e: - self.log_test("Training step", False, str(e)) - - def test_full_training_loop(self): - """Test a minimal training loop""" - print("\n=== Testing Full Training Loop ===") - - if not hasattr(self, '_full_test') or not self._full_test: - self.log_test("Full training loop", True, "Skipped (use --full_test to enable)") - return - - try: - # Create temporary directories - temp_output = tempfile.mkdtemp() - temp_checkpoint = tempfile.mkdtemp() - - # Minimal config for quick test - config = self.config.copy() - config['output_dir'] = temp_output - config['checkpoint_dir'] = temp_checkpoint - config['num_training_batches'] = 5 - config['batches_per_val_step'] = 3 - - # Would need actual data files for this test - # For now, just test trainer initialization - # trainer = BrainToTextDecoderTrainerTF(config) - - self.log_test("Full training loop", True, "Trainer initialization successful") - - # Cleanup - shutil.rmtree(temp_output, ignore_errors=True) - shutil.rmtree(temp_checkpoint, ignore_errors=True) - - except Exception as e: - self.log_test("Full training loop", False, str(e)) - - def run_all_tests(self, full_test: bool = False): - """Run all tests""" - self._full_test = full_test - - print("TensorFlow Brain-to-Text Implementation Test Suite") - print("=" * 60) - - if self.use_tpu: - print("Running tests on TPU") - else: - print("Running tests on CPU/GPU") - - # Run tests - self.test_model_architecture() - self.test_triple_gru_decoder() - self.test_ctc_loss() - self.test_data_augmentation() - self.test_gradient_reversal() - self.test_mixed_precision() - self.test_training_step() - self.test_full_training_loop() - - # Summary - print("\n" + "=" * 60) - print(f"TEST SUMMARY: {self.passed_tests}/{self.total_tests} tests passed") - - if self.passed_tests == self.total_tests: - print("🎉 All tests passed! TensorFlow implementation is ready.") - return True - else: - print("❌ Some tests failed. Please review the implementation.") - return False - - -def main(): - """Main test function""" - parser = argparse.ArgumentParser(description='Test TensorFlow Brain-to-Text Implementation') - parser.add_argument('--use_tpu', action='store_true', help='Test on TPU if available') - parser.add_argument('--full_test', action='store_true', help='Run full training loop test') - parser.add_argument('--quiet', action='store_true', help='Reduce output verbosity') - - args = parser.parse_args() - - # Set TensorFlow logging level - if args.quiet: - os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' - tf.get_logger().setLevel('ERROR') - else: - os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' - - # Run tests - tester = TensorFlowImplementationTester( - use_tpu=args.use_tpu, - verbose=not args.quiet - ) - - success = tester.run_all_tests(full_test=args.full_test) - - # Cleanup temporary directories - shutil.rmtree(tester.config['output_dir'], ignore_errors=True) - shutil.rmtree(tester.config['checkpoint_dir'], ignore_errors=True) - - sys.exit(0 if success else 1) - - -if __name__ == "__main__": - main() \ No newline at end of file