#!/usr/bin/env python3
"""
TensorFlow Evaluation Script for Brain-to-Text RNN Model
Optimized for TPU v5e-8

This script evaluates the TripleGRUDecoder model using TensorFlow and provides
detailed metrics and analysis of model performance on test data.

Usage:
    python evaluate_model_tf.py --model_path path/to/model --data_dir path/to/data

Requirements:
    - TensorFlow >= 2.15.0
    - TPU v5e-8 environment
    - Trained model checkpoint
    - Access to brain-to-text HDF5 dataset
"""

import argparse
import os
import sys
import json
import pickle
import numpy as np
import tensorflow as tf
from typing import Dict, Any, List, Tuple
from omegaconf import OmegaConf

# Add the current directory to the Python path for local imports
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from trainer_tf import BrainToTextDecoderTrainerTF
from dataset_tf import BrainToTextDatasetTF, train_test_split_indices, create_input_fn
from rnn_model_tf import create_tpu_strategy, configure_mixed_precision


class BrainToTextEvaluatorTF:
    """
    TensorFlow evaluator for brain-to-text model performance analysis
    """

    def __init__(self, model_path: str, config: Dict[str, Any], eval_type: str = 'test'):
        """
        Initialize evaluator

        Args:
            model_path: Path to trained model checkpoint
            config: Configuration dictionary
            eval_type: 'test' or 'val' evaluation type
        """
        self.model_path = model_path
        self.config = config
        self.eval_type = eval_type

        # Initialize TPU strategy
        self.strategy = create_tpu_strategy()
        print(f"Evaluation using {self.strategy.num_replicas_in_sync} TPU cores")

        # Configure mixed precision
        if config.get('use_amp', True):
            configure_mixed_precision()

        # Load model
        with self.strategy.scope():
            self.trainer = BrainToTextDecoderTrainerTF(config)
            self.trainer.load_checkpoint(model_path)

        print(f"Model loaded from: {model_path}")

    def evaluate_dataset(self, save_results: bool = True,
                         return_predictions: bool = False) -> Dict[str, Any]:
        """
        Evaluate model on specified dataset

        Args:
            save_results: Whether to save detailed results to file
            return_predictions: Whether to return individual predictions

        Returns:
            Dictionary containing evaluation metrics and optionally predictions
        """
        print(f"Starting {self.eval_type} evaluation...")

        # Create evaluation dataset. The trainer does not expose a separate test
        # split here, so both 'test' and 'val' evaluation currently fall back to
        # the validation dataset.
        if self.eval_type == 'test':
            dataset_tf = self.trainer.val_dataset_tf
        else:
            dataset_tf = self.trainer.val_dataset_tf

        eval_dataset = create_input_fn(
            dataset_tf,
            self.config['dataset']['data_transforms'],
            training=False
        )

        # Distribute dataset
        eval_dist_dataset = self.strategy.experimental_distribute_dataset(eval_dataset)

        # Run evaluation
        results = self._run_evaluation(eval_dist_dataset, return_predictions)

        # Calculate summary metrics
        summary_metrics = self._calculate_summary_metrics(results)

        print("Evaluation completed!")
        print(f"Overall PER: {summary_metrics['overall_per']:.4f}")
        print(f"Overall Loss: {summary_metrics['overall_loss']:.4f}")
        print(f"Total trials evaluated: {summary_metrics['total_trials']}")

        # Save results if requested
        if save_results:
            self._save_results(results, summary_metrics)

        return {
            'summary_metrics': summary_metrics,
            'detailed_results': results if return_predictions else None
        }

    def _run_evaluation(self, eval_dataset, return_predictions: bool) -> List[Dict[str, Any]]:
        """Run evaluation on distributed dataset"""
        all_results = []
        batch_idx = 0

        for batch in eval_dataset:
            batch_results = self.strategy.run(self._evaluation_step, args=(batch, return_predictions))

            # Gather results from all replicas
            gathered_results = {}
            for key in batch_results.keys():
                if key in ['logits', 'features'] and not return_predictions:
                    continue  # Skip large tensors if not needed

                values = self.strategy.experimental_local_results(batch_results[key])
                if key in ['loss', 'edit_distance', 'seq_length']:
                    # Scalar metrics - just take the values
                    gathered_results[key] = [float(v.numpy()) for v in values]
                else:
                    # Tensor data - concatenate across replicas
                    gathered_results[key] = [v.numpy() for v in values]

            all_results.append(gathered_results)
            batch_idx += 1

            if batch_idx % 10 == 0:
                print(f"Processed {batch_idx} batches...")

        return all_results

    # Note: this step is executed eagerly (no @tf.function) because it relies on
    # Python-level loops over a tensor batch size, tf.py_function, and list
    # building, none of which trace cleanly into a TensorFlow graph.
    def _evaluation_step(self, batch, return_predictions: bool):
        """Single evaluation step"""
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indices = batch['day_indices']

        # Apply data transformations (no augmentation)
        from dataset_tf import DataAugmentationTF
        features_transformed, n_time_steps_transformed = DataAugmentationTF.transform_data(
            features, n_time_steps, self.config['dataset']['data_transforms'], training=False
        )

        # Calculate adjusted lengths for CTC
        adjusted_lens = tf.cast(
            (tf.cast(n_time_steps_transformed, tf.float32) - self.config['model']['patch_size']) /
            self.config['model']['patch_stride'] + 1,
            tf.int32
        )
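        # Worked example (illustrative numbers, not taken from the config): with
        # patch_size=14 and patch_stride=4, an input of 512 time steps yields
        # floor((512 - 14) / 4) + 1 = 125 frames presented to the CTC loss.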

        # Forward pass
        logits = self.trainer.model(
            features_transformed, day_indices, None, False, 'inference', training=False
        )

        # Calculate loss
        loss_input = {
            'labels': labels,
            'input_lengths': adjusted_lens,
            'label_lengths': phone_seq_lens
        }
        loss = self.trainer.ctc_loss(loss_input, logits)
        loss = tf.reduce_mean(loss)

        # Calculate edit distance for PER
        predicted_ids = tf.argmax(logits, axis=-1)
        batch_size = tf.shape(logits)[0]

        # Initialize metrics
        total_edit_distance = 0
        total_seq_length = tf.reduce_sum(phone_seq_lens)

        # Decode predictions and calculate edit distance
        predictions = []
        targets = []

        for i in range(batch_size):
            # Get prediction for this sample
            pred_seq = predicted_ids[i, :adjusted_lens[i]]

            # Remove consecutive duplicates using tf.py_function for simplicity
            pred_seq_unique = tf.py_function(
                func=self._remove_consecutive_duplicates,
                inp=[pred_seq],
                Tout=tf.int64
            )

            # Remove blanks (assuming blank_index=0)
            pred_seq_clean = tf.boolean_mask(pred_seq_unique, pred_seq_unique != 0)
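            # e.g. an argmax path [0, 3, 3, 0, 5, 5] collapses to [0, 3, 0, 5]
            # and, after dropping the blanks (index 0), decodes to [3, 5].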

            # Get true sequence
            true_seq = labels[i, :phone_seq_lens[i]]

            # Calculate edit distance for this pair
            if tf.size(pred_seq_clean) > 0 and tf.size(true_seq) > 0:
                pred_sparse = tf.SparseTensor(
                    indices=tf.expand_dims(tf.range(tf.size(pred_seq_clean), dtype=tf.int64), 1),
                    values=tf.cast(pred_seq_clean, tf.int64),
                    dense_shape=[tf.size(pred_seq_clean, out_type=tf.int64)]
                )

                true_sparse = tf.SparseTensor(
                    indices=tf.expand_dims(tf.range(tf.size(true_seq), dtype=tf.int64), 1),
                    values=tf.cast(true_seq, tf.int64),
                    dense_shape=[tf.size(true_seq, out_type=tf.int64)]
                )

                edit_dist = tf.edit_distance(pred_sparse, true_sparse, normalize=False)
                total_edit_distance += edit_dist

            if return_predictions:
                predictions.append(pred_seq_clean)
                targets.append(true_seq)

        result = {
            'loss': loss,
            'edit_distance': total_edit_distance,
            'seq_length': total_seq_length,
            'day_indices': day_indices,
            'n_time_steps': n_time_steps,
            'phone_seq_lens': phone_seq_lens
        }

        if return_predictions:
            result.update({
                'logits': logits,
                'predictions': predictions,
                'targets': targets,
                'features': features
            })

        return result

    def _remove_consecutive_duplicates(self, seq):
        """Remove consecutive duplicate elements from sequence"""
        seq_np = seq.numpy()
        if len(seq_np) == 0:
            return tf.constant([], dtype=tf.int64)

        unique_seq = [seq_np[0]]
        for i in range(1, len(seq_np)):
            if seq_np[i] != seq_np[i-1]:
                unique_seq.append(seq_np[i])

        return tf.constant(unique_seq, dtype=tf.int64)
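
    # A graph-friendly alternative (sketch only; not wired into the evaluation
    # step above): collapse consecutive repeats with pure tensor ops instead of
    # tf.py_function, avoiding a host round-trip per sample.
    @staticmethod
    def _remove_consecutive_duplicates_vectorized(seq):
        """Collapse consecutive duplicates using boolean masking (integer dtype assumed)."""
        seq = tf.cast(seq, tf.int64)
        first = seq[:1]  # empty when seq is empty
        rest = tf.boolean_mask(seq[1:], tf.not_equal(seq[1:], seq[:-1]))
        return tf.concat([first, rest], axis=0)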

    def _calculate_summary_metrics(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Calculate summary metrics from evaluation results"""
        total_loss = 0.0
        total_edit_distance = 0
        total_seq_length = 0
        total_trials = 0
        num_batches = len(results)

        # Day-specific metrics
        day_metrics = {}

        for batch_results in results:
            # Sum losses across replicas
            batch_loss = sum(batch_results['loss'])
            total_loss += batch_loss

            # Sum edit distances and sequence lengths
            batch_edit_dist = sum(batch_results['edit_distance'])
            batch_seq_len = sum(batch_results['seq_length'])

            total_edit_distance += batch_edit_dist
            total_seq_length += batch_seq_len

            # Count trials
            for day_indices_replica in batch_results['day_indices']:
                total_trials += len(day_indices_replica)

                # Track per-day trial counts. Note: per-sample edit distances are
                # not gathered back from the replicas, so the per-day
                # 'edit_distance' and 'seq_length' fields below remain zero; the
                # per-day breakdown is currently informative for trial counts only.
                for i, day_idx in enumerate(day_indices_replica):
                    day_idx = int(day_idx)
                    if day_idx not in day_metrics:
                        day_metrics[day_idx] = {'edit_distance': 0, 'seq_length': 0, 'trials': 0}

                    day_metrics[day_idx]['trials'] += 1

        # Calculate averages
        avg_loss = total_loss / max(num_batches, 1)
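        # PER = summed Levenshtein edit distance across all trials divided by the
        # total reference phone count.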
        overall_per = total_edit_distance / max(total_seq_length, 1e-6)

        # Calculate per-day PERs
        day_pers = {}
        for day_idx, metrics in day_metrics.items():
            day_per = metrics['edit_distance'] / max(metrics['seq_length'], 1e-6)
            day_pers[day_idx] = {
                'per': day_per,
                'edit_distance': metrics['edit_distance'],
                'seq_length': metrics['seq_length'],
                'trials': metrics['trials']
            }

        return {
            'overall_per': float(overall_per),
            'overall_loss': float(avg_loss),
            'total_edit_distance': int(total_edit_distance),
            'total_seq_length': int(total_seq_length),
            'total_trials': total_trials,
            'num_batches': num_batches,
            'day_metrics': day_pers
        }

    def _save_results(self, detailed_results: List[Dict[str, Any]],
                      summary_metrics: Dict[str, Any]):
        """Save evaluation results to files"""
        output_dir = self.config.get('output_dir', './eval_output')
        os.makedirs(output_dir, exist_ok=True)

        # Save summary metrics
        summary_path = os.path.join(output_dir, f'{self.eval_type}_summary_metrics.json')
        with open(summary_path, 'w') as f:
            json.dump(summary_metrics, f, indent=2)
        print(f"Summary metrics saved to: {summary_path}")

        # Save detailed results
        detailed_path = os.path.join(output_dir, f'{self.eval_type}_detailed_results.pkl')
        with open(detailed_path, 'wb') as f:
            pickle.dump(detailed_results, f)
        print(f"Detailed results saved to: {detailed_path}")

        # Save per-day breakdown
        if 'day_metrics' in summary_metrics:
            day_breakdown_path = os.path.join(output_dir, f'{self.eval_type}_day_breakdown.json')
            with open(day_breakdown_path, 'w') as f:
                json.dump(summary_metrics['day_metrics'], f, indent=2)
            print(f"Per-day breakdown saved to: {day_breakdown_path}")


def main():
    """Main evaluation function"""
    parser = argparse.ArgumentParser(
        description='Evaluate Brain-to-Text RNN Model with TensorFlow on TPU v5e-8',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        '--model_path',
        required=True,
        help='Path to trained model checkpoint (without extension)'
    )

    parser.add_argument(
        '--config_path',
        default='rnn_args.yaml',
        help='Path to model configuration file'
    )

    parser.add_argument(
        '--data_dir',
        default=None,
        help='Override data directory from config'
    )

    parser.add_argument(
        '--eval_type',
        choices=['test', 'val'],
        default='test',
        help='Type of evaluation to run'
    )

    parser.add_argument(
        '--output_dir',
        default='./eval_output',
        help='Directory to save evaluation results'
    )

    parser.add_argument(
        '--save_predictions',
        action='store_true',
        help='Save individual predictions and targets'
    )

    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='Override batch size for evaluation'
    )

    parser.add_argument(
        '--sessions',
        nargs='+',
        default=None,
        help='Specific sessions to evaluate (overrides config)'
    )
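
    # Example invocation (paths are illustrative):
    #   python evaluate_model_tf.py --model_path ./trained_models/best_checkpoint \
    #       --config_path rnn_args.yaml --eval_type val --save_predictions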

    args = parser.parse_args()

    # Setup TPU environment
    os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '2')

    # Load configuration
    if not os.path.exists(args.config_path):
        raise FileNotFoundError(f"Configuration file not found: {args.config_path}")

    config = OmegaConf.load(args.config_path)

    # Apply overrides
    if args.data_dir:
        config.dataset.dataset_dir = args.data_dir
    if args.batch_size:
        config.dataset.batch_size = args.batch_size
    if args.sessions:
        config.dataset.sessions = args.sessions
    if args.output_dir:
        config.output_dir = args.output_dir

    # Validate model checkpoint exists
    if not os.path.exists(args.model_path + '.weights.h5'):
        raise FileNotFoundError(f"Model checkpoint not found: {args.model_path}")
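    # The trainer is assumed to save weights using the Keras save_weights()
    # convention, i.e. '<model_path>.weights.h5'; adjust the check above if the
    # checkpoint format differs.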

    try:
        # Initialize evaluator
        evaluator = BrainToTextEvaluatorTF(
            model_path=args.model_path,
            config=config,
            eval_type=args.eval_type
        )

        # Run evaluation
        results = evaluator.evaluate_dataset(
            save_results=True,
            return_predictions=args.save_predictions
        )

        # Print results
        metrics = results['summary_metrics']
        print("\n" + "="*60)
        print("EVALUATION RESULTS")
        print("="*60)
        print(f"Overall PER: {metrics['overall_per']:.6f}")
        print(f"Overall Loss: {metrics['overall_loss']:.6f}")
        print(f"Total Edit Distance: {metrics['total_edit_distance']}")
        print(f"Total Sequence Length: {metrics['total_seq_length']}")
        print(f"Total Trials: {metrics['total_trials']}")
        print(f"Batches Processed: {metrics['num_batches']}")

        # Print per-day results if available
        if 'day_metrics' in metrics and metrics['day_metrics']:
            print("\nPER-DAY RESULTS:")
            print("-" * 40)
            for day_idx, day_metrics in metrics['day_metrics'].items():
                session_name = config.dataset.sessions[day_idx] if day_idx < len(config.dataset.sessions) else f"Day_{day_idx}"
                print(f"{session_name}: PER={day_metrics['per']:.6f}, Trials={day_metrics['trials']}")

        print("\nEvaluation completed successfully!")

    except Exception as e:
        print(f"Evaluation failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()