#!/usr/bin/env python3
"""
TensorFlow Evaluation Script for Brain-to-Text RNN Model
Optimized for TPU v5e-8

This script evaluates the TripleGRUDecoder model using TensorFlow and provides
detailed metrics and analysis of model performance on test data.

Usage:
    python evaluate_model_tf.py --model_path path/to/model --data_dir path/to/data

Requirements:
    - TensorFlow >= 2.15.0
    - TPU v5e-8 environment
    - Trained model checkpoint
    - Access to brain-to-text HDF5 dataset
"""

import argparse
import os
import sys
import json
import pickle

import numpy as np
import tensorflow as tf
from typing import Dict, Any, List, Tuple
from omegaconf import OmegaConf

# Add the current directory to the Python path for local imports
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from trainer_tf import BrainToTextDecoderTrainerTF
from dataset_tf import (
    BrainToTextDatasetTF,
    DataAugmentationTF,
    train_test_split_indices,
    create_input_fn,
)
from rnn_model_tf import create_tpu_strategy, configure_mixed_precision


class BrainToTextEvaluatorTF:
    """TensorFlow evaluator for brain-to-text model performance analysis"""

    def __init__(self, model_path: str, config: Dict[str, Any], eval_type: str = 'test'):
        """
        Initialize evaluator

        Args:
            model_path: Path to trained model checkpoint
            config: Configuration dictionary
            eval_type: 'test' or 'val' evaluation type
        """
        self.model_path = model_path
        self.config = config
        self.eval_type = eval_type

        # Initialize TPU strategy
        self.strategy = create_tpu_strategy()
        print(f"Evaluation using {self.strategy.num_replicas_in_sync} TPU cores")

        # Configure mixed precision
        if config.get('use_amp', True):
            configure_mixed_precision()

        # Load model under the distribution strategy scope
        with self.strategy.scope():
            self.trainer = BrainToTextDecoderTrainerTF(config)
            self.trainer.load_checkpoint(model_path)

        print(f"Model loaded from: {model_path}")

    def evaluate_dataset(self, save_results: bool = True,
                         return_predictions: bool = False) -> Dict[str, Any]:
        """
        Evaluate model on the specified dataset

        Args:
            save_results: Whether to save detailed results to file
            return_predictions: Whether to return individual predictions

        Returns:
            Dictionary containing evaluation metrics and optionally predictions
        """
        print(f"Starting {self.eval_type} evaluation...")

        # Create evaluation dataset. A dedicated test split is not wired up yet,
        # so the 'test' branch currently falls back to the validation data.
        if self.eval_type == 'test':
            dataset_tf = self.trainer.val_dataset_tf  # Using validation data as test
        else:
            dataset_tf = self.trainer.val_dataset_tf

        eval_dataset = create_input_fn(
            dataset_tf,
            self.config['dataset']['data_transforms'],
            training=False
        )

        # Distribute the dataset across replicas
        eval_dist_dataset = self.strategy.experimental_distribute_dataset(eval_dataset)

        # Run evaluation
        results = self._run_evaluation(eval_dist_dataset, return_predictions)

        # Calculate summary metrics
        summary_metrics = self._calculate_summary_metrics(results)

        print("Evaluation completed!")
        print(f"Overall PER: {summary_metrics['overall_per']:.4f}")
        print(f"Overall Loss: {summary_metrics['overall_loss']:.4f}")
        print(f"Total trials evaluated: {summary_metrics['total_trials']}")

        # Save results if requested
        if save_results:
            self._save_results(results, summary_metrics)

        return {
            'summary_metrics': summary_metrics,
            'detailed_results': results if return_predictions else None
        }

    def _run_evaluation(self, eval_dataset, return_predictions: bool) -> List[Dict[str, Any]]:
        """Run evaluation on the distributed dataset"""
        all_results = []
        batch_idx = 0

        for batch in eval_dataset:
            batch_results = self.strategy.run(
                self._evaluation_step, args=(batch, return_predictions)
            )

            # Gather results from all replicas
            gathered_results = {}
            for key in batch_results.keys():
                if key in ['logits', 'features'] and not return_predictions:
                    continue  # Skip large tensors if not needed

                values = self.strategy.experimental_local_results(batch_results[key])
                if key in ['loss', 'edit_distance', 'seq_length']:
                    # Scalar metrics - one float per replica
                    gathered_results[key] = [float(v.numpy()) for v in values]
                else:
                    # Tensor data - keep one array per replica
                    gathered_results[key] = [v.numpy() for v in values]

            all_results.append(gathered_results)

            batch_idx += 1
            if batch_idx % 10 == 0:
                print(f"Processed {batch_idx} batches...")

        return all_results
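
    # Shape of the per-batch results gathered above (and consumed later by
    # _calculate_summary_metrics): one dict per batch, keyed by metric name,
    # holding one entry per replica. Values below are illustrative only,
    # assuming 8 replicas:
    #   {'loss':          [0.41, 0.38, ...],   # 8 floats
    #    'edit_distance': [12.0,  9.0, ...],   # 8 floats
    #    'seq_length':    [340.0, 298.0, ...], # 8 floats
    #    'day_indices':   [array([...]), ...], # 8 per-replica arrays
    #    'n_time_steps':  [...], 'phone_seq_lens': [...]}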

    def _evaluation_step(self, batch, return_predictions: bool):
        """Single evaluation step.

        Note: this step relies on Python-level control flow and tf.py_function
        for per-sample decoding, so it is intentionally not wrapped in
        @tf.function and is expected to execute eagerly.
        """
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indices = batch['day_indices']

        # Apply data transformations (no augmentation at evaluation time)
        features_transformed, n_time_steps_transformed = DataAugmentationTF.transform_data(
            features,
            n_time_steps,
            self.config['dataset']['data_transforms'],
            training=False
        )

        # Calculate adjusted sequence lengths for CTC after patching
        adjusted_lens = tf.cast(
            (tf.cast(n_time_steps_transformed, tf.float32) - self.config['model']['patch_size'])
            / self.config['model']['patch_stride'] + 1,
            tf.int32
        )

        # Forward pass in inference mode
        logits = self.trainer.model(
            features_transformed, day_indices, None, False, 'inference', training=False
        )

        # Calculate CTC loss
        loss_input = {
            'labels': labels,
            'input_lengths': adjusted_lens,
            'label_lengths': phone_seq_lens
        }
        loss = self.trainer.ctc_loss(loss_input, logits)
        loss = tf.reduce_mean(loss)

        # Greedy decoding for the phoneme error rate (PER)
        predicted_ids = tf.argmax(logits, axis=-1)
        batch_size = int(tf.shape(logits)[0])

        # Initialize metrics
        total_edit_distance = tf.constant(0.0, dtype=tf.float32)
        total_seq_length = tf.reduce_sum(phone_seq_lens)

        # Decode predictions and accumulate edit distance
        predictions = []
        targets = []

        for i in range(batch_size):
            # Get the greedy prediction for this sample
            pred_seq = predicted_ids[i, :adjusted_lens[i]]

            # Collapse consecutive duplicates via tf.py_function for simplicity
            pred_seq_unique = tf.py_function(
                func=self._remove_consecutive_duplicates,
                inp=[pred_seq],
                Tout=tf.int64
            )

            # Remove blanks (assuming blank_index=0)
            pred_seq_clean = tf.boolean_mask(pred_seq_unique, pred_seq_unique != 0)

            # Get the true sequence
            true_seq = labels[i, :phone_seq_lens[i]]

            # Calculate edit distance for this prediction/target pair
            if tf.size(pred_seq_clean) > 0 and tf.size(true_seq) > 0:
                pred_sparse = tf.SparseTensor(
                    indices=tf.expand_dims(tf.range(tf.size(pred_seq_clean), dtype=tf.int64), 1),
                    values=tf.cast(pred_seq_clean, tf.int64),
                    dense_shape=[tf.size(pred_seq_clean, out_type=tf.int64)]
                )
                true_sparse = tf.SparseTensor(
                    indices=tf.expand_dims(tf.range(tf.size(true_seq), dtype=tf.int64), 1),
                    values=tf.cast(true_seq, tf.int64),
                    dense_shape=[tf.size(true_seq, out_type=tf.int64)]
                )
                edit_dist = tf.edit_distance(pred_sparse, true_sparse, normalize=False)
                total_edit_distance += edit_dist

            if return_predictions:
                predictions.append(pred_seq_clean)
                targets.append(true_seq)

        result = {
            'loss': loss,
            'edit_distance': total_edit_distance,
            'seq_length': total_seq_length,
            'day_indices': day_indices,
            'n_time_steps': n_time_steps,
            'phone_seq_lens': phone_seq_lens
        }

        if return_predictions:
            result.update({
                'logits': logits,
                'predictions': predictions,
                'targets': targets,
                'features': features
            })

        return result
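
    # Worked example of the greedy CTC post-processing above (assuming
    # blank_index=0): argmax frames [0, 3, 3, 0, 5, 5, 5, 0]
    #   -> collapse consecutive duplicates: [0, 3, 0, 5, 0]
    #   -> drop blanks:                     [3, 5]
    # Collapsing before dropping blanks is what lets a repeated phoneme that
    # is separated by a blank (e.g. [3, 0, 3]) survive as [3, 3].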

    def _remove_consecutive_duplicates(self, seq):
        """Remove consecutive duplicate elements from a sequence"""
        seq_np = seq.numpy()
        if len(seq_np) == 0:
            return tf.constant([], dtype=tf.int64)

        unique_seq = [seq_np[0]]
        for i in range(1, len(seq_np)):
            if seq_np[i] != seq_np[i - 1]:
                unique_seq.append(seq_np[i])

        return tf.constant(unique_seq, dtype=tf.int64)

    def _calculate_summary_metrics(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Calculate summary metrics from evaluation results"""
        total_loss = 0.0
        total_edit_distance = 0
        total_seq_length = 0
        total_trials = 0
        num_batches = len(results)

        # Day-specific metrics
        day_metrics = {}

        for batch_results in results:
            # Sum losses across replicas
            batch_loss = sum(batch_results['loss'])
            total_loss += batch_loss

            # Sum edit distances and sequence lengths
            batch_edit_dist = sum(batch_results['edit_distance'])
            batch_seq_len = sum(batch_results['seq_length'])

            total_edit_distance += batch_edit_dist
            total_seq_length += batch_seq_len

            # Count trials. Edit distance and sequence length are only tracked
            # at the batch level, so the per-day breakdown below currently
            # reflects trial counts only (per-day PER stays at 0 until
            # per-sample metrics are collected).
            for day_indices_replica in batch_results['day_indices']:
                total_trials += len(day_indices_replica)

                # Track per-day metrics
                for day_idx in day_indices_replica:
                    day_idx = int(day_idx)
                    if day_idx not in day_metrics:
                        day_metrics[day_idx] = {'edit_distance': 0, 'seq_length': 0, 'trials': 0}
                    day_metrics[day_idx]['trials'] += 1

        # Calculate averages
        avg_loss = total_loss / max(num_batches, 1)
        overall_per = total_edit_distance / max(total_seq_length, 1e-6)

        # Calculate per-day PERs
        day_pers = {}
        for day_idx, metrics in day_metrics.items():
            day_per = metrics['edit_distance'] / max(metrics['seq_length'], 1e-6)
            day_pers[day_idx] = {
                'per': day_per,
                'edit_distance': metrics['edit_distance'],
                'seq_length': metrics['seq_length'],
                'trials': metrics['trials']
            }

        return {
            'overall_per': float(overall_per),
            'overall_loss': float(avg_loss),
            'total_edit_distance': int(total_edit_distance),
            'total_seq_length': int(total_seq_length),
            'total_trials': total_trials,
            'num_batches': num_batches,
            'day_metrics': day_pers
        }

    def _save_results(self, detailed_results: List[Dict[str, Any]],
                      summary_metrics: Dict[str, Any]):
        """Save evaluation results to files"""
        output_dir = self.config.get('output_dir', './eval_output')
        os.makedirs(output_dir, exist_ok=True)

        # Save summary metrics as JSON
        summary_path = os.path.join(output_dir, f'{self.eval_type}_summary_metrics.json')
        with open(summary_path, 'w') as f:
            json.dump(summary_metrics, f, indent=2)
        print(f"Summary metrics saved to: {summary_path}")

        # Save detailed results as a pickle
        detailed_path = os.path.join(output_dir, f'{self.eval_type}_detailed_results.pkl')
        with open(detailed_path, 'wb') as f:
            pickle.dump(detailed_results, f)
        print(f"Detailed results saved to: {detailed_path}")

        # Save per-day breakdown
        if 'day_metrics' in summary_metrics:
            day_breakdown_path = os.path.join(output_dir, f'{self.eval_type}_day_breakdown.json')
            with open(day_breakdown_path, 'w') as f:
                json.dump(summary_metrics['day_metrics'], f, indent=2)
            print(f"Per-day breakdown saved to: {day_breakdown_path}")
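

# Note on aggregation: _calculate_summary_metrics reports a micro-averaged PER,
# i.e. total edit distance divided by total reference length across all trials,
# not the mean of per-trial PERs. For example (illustrative numbers), two
# utterances with (edits, length) = (2, 10) and (0, 30) give
# PER = (2 + 0) / (10 + 30) = 0.05, whereas averaging per-trial PERs would
# give (0.20 + 0.00) / 2 = 0.10.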


def main():
    """Main evaluation function"""
    parser = argparse.ArgumentParser(
        description='Evaluate Brain-to-Text RNN Model with TensorFlow on TPU v5e-8',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '--model_path',
        required=True,
        help='Path to trained model checkpoint (without extension)'
    )
    parser.add_argument(
        '--config_path',
        default='rnn_args.yaml',
        help='Path to model configuration file'
    )
    parser.add_argument(
        '--data_dir',
        default=None,
        help='Override data directory from config'
    )
    parser.add_argument(
        '--eval_type',
        choices=['test', 'val'],
        default='test',
        help='Type of evaluation to run'
    )
    parser.add_argument(
        '--output_dir',
        default='./eval_output',
        help='Directory to save evaluation results'
    )
    parser.add_argument(
        '--save_predictions',
        action='store_true',
        help='Save individual predictions and targets'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='Override batch size for evaluation'
    )
    parser.add_argument(
        '--sessions',
        nargs='+',
        default=None,
        help='Specific sessions to evaluate (overrides config)'
    )

    args = parser.parse_args()

    # Reduce TensorFlow log noise on the TPU host
    os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '2')

    # Load configuration
    if not os.path.exists(args.config_path):
        raise FileNotFoundError(f"Configuration file not found: {args.config_path}")

    config = OmegaConf.load(args.config_path)

    # Apply command-line overrides
    if args.data_dir:
        config.dataset.dataset_dir = args.data_dir
    if args.batch_size:
        config.dataset.batch_size = args.batch_size
    if args.sessions:
        config.dataset.sessions = args.sessions
    if args.output_dir:
        config.output_dir = args.output_dir

    # Validate that the model checkpoint exists
    if not os.path.exists(args.model_path + '.weights.h5'):
        raise FileNotFoundError(f"Model checkpoint not found: {args.model_path}")

    try:
        # Initialize evaluator
        evaluator = BrainToTextEvaluatorTF(
            model_path=args.model_path,
            config=config,
            eval_type=args.eval_type
        )

        # Run evaluation
        results = evaluator.evaluate_dataset(
            save_results=True,
            return_predictions=args.save_predictions
        )

        # Print results
        metrics = results['summary_metrics']
        print("\n" + "=" * 60)
        print("EVALUATION RESULTS")
        print("=" * 60)
        print(f"Overall PER: {metrics['overall_per']:.6f}")
        print(f"Overall Loss: {metrics['overall_loss']:.6f}")
        print(f"Total Edit Distance: {metrics['total_edit_distance']}")
        print(f"Total Sequence Length: {metrics['total_seq_length']}")
        print(f"Total Trials: {metrics['total_trials']}")
        print(f"Batches Processed: {metrics['num_batches']}")

        # Print per-day results if available
        if 'day_metrics' in metrics and metrics['day_metrics']:
            print("\nPER-DAY RESULTS:")
            print("-" * 40)
            for day_idx, day_metrics in metrics['day_metrics'].items():
                session_name = (config.dataset.sessions[day_idx]
                                if day_idx < len(config.dataset.sessions)
                                else f"Day_{day_idx}")
                print(f"{session_name}: PER={day_metrics['per']:.6f}, "
                      f"Trials={day_metrics['trials']}")

        print("\nEvaluation completed successfully!")

    except Exception as e:
        print(f"Evaluation failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
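
# ---------------------------------------------------------------------------
# Example invocations (illustrative sketch only; the paths and checkpoint
# basename below are placeholders, not files shipped with this repo). The
# checkpoint is expected at <model_path>.weights.h5, matching the check in
# main().
#
#   # Command line: evaluate the validation split and save per-day breakdowns
#   #   python evaluate_model_tf.py \
#   #       --model_path ./trained_models/best_checkpoint \
#   #       --config_path rnn_args.yaml \
#   #       --eval_type val \
#   #       --output_dir ./eval_output
#
#   # Programmatic use from another script:
#   #   from omegaconf import OmegaConf
#   #   from evaluate_model_tf import BrainToTextEvaluatorTF
#   #
#   #   cfg = OmegaConf.load('rnn_args.yaml')
#   #   evaluator = BrainToTextEvaluatorTF('./trained_models/best_checkpoint',
#   #                                      cfg, eval_type='val')
#   #   results = evaluator.evaluate_dataset(save_results=False)
#   #   print(results['summary_metrics']['overall_per'])
# ---------------------------------------------------------------------------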