#!/usr/bin/env python3
"""
TensorFlow Evaluation Script for Brain-to-Text RNN Model
Optimized for TPU v5e-8

This script evaluates the TripleGRUDecoder model using TensorFlow and provides
detailed metrics and analysis of model performance on test data.

Usage:
    python evaluate_model_tf.py --model_path path/to/model --data_dir path/to/data
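
    A fuller invocation (paths are placeholders; every flag is defined in main() below):
    python evaluate_model_tf.py --model_path path/to/model --config_path rnn_args.yaml \
        --eval_type val --output_dir ./eval_output --save_predictions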

Requirements:
- TensorFlow >= 2.15.0
- TPU v5e-8 environment
- Trained model checkpoint
- Access to brain-to-text HDF5 dataset
"""

import argparse
import os
import sys
import json
import pickle
import numpy as np
import tensorflow as tf
from typing import Dict, Any, List, Tuple
from omegaconf import OmegaConf

# Add the current directory to Python path for imports
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from trainer_tf import BrainToTextDecoderTrainerTF
from dataset_tf import BrainToTextDatasetTF, train_test_split_indices, create_input_fn
from rnn_model_tf import create_tpu_strategy, configure_mixed_precision

class BrainToTextEvaluatorTF:
    """
    TensorFlow evaluator for brain-to-text model performance analysis
    """

    def __init__(self, model_path: str, config: Dict[str, Any], eval_type: str = 'test'):
        """
        Initialize evaluator

        Args:
            model_path: Path to trained model checkpoint
            config: Configuration dictionary
            eval_type: 'test' or 'val' evaluation type
        """
        self.model_path = model_path
        self.config = config
        self.eval_type = eval_type

        # Initialize TPU strategy
        self.strategy = create_tpu_strategy()
        print(f"Evaluation using {self.strategy.num_replicas_in_sync} TPU cores")

        # Configure mixed precision
        if config.get('use_amp', True):
            configure_mixed_precision()
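        # Note: on TPU, mixed precision normally means the Keras 'mixed_bfloat16'
        # policy (configure_mixed_precision is provided by rnn_model_tf); set
        # use_amp: false in the config to evaluate in full float32 instead.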

        # Load model
        with self.strategy.scope():
            self.trainer = BrainToTextDecoderTrainerTF(config)
            self.trainer.load_checkpoint(model_path)

        print(f"Model loaded from: {model_path}")

    def evaluate_dataset(self, save_results: bool = True,
                         return_predictions: bool = False) -> Dict[str, Any]:
        """
        Evaluate model on specified dataset

        Args:
            save_results: Whether to save detailed results to file
            return_predictions: Whether to return individual predictions

        Returns:
            Dictionary containing evaluation metrics and optionally predictions
        """
        print(f"Starting {self.eval_type} evaluation...")

        # The trainer currently only builds train/val splits, so both 'test' and
        # 'val' evaluation run against the validation dataset.
        dataset_tf = self.trainer.val_dataset_tf

        eval_dataset = create_input_fn(
            dataset_tf,
            self.config['dataset']['data_transforms'],
            training=False
        )

        # Distribute dataset across TPU replicas
        eval_dist_dataset = self.strategy.experimental_distribute_dataset(eval_dataset)
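        # experimental_distribute_dataset splits each global batch across the replicas,
        # so with a hypothetical batch size of 64 on a v5e-8 slice each of the 8 cores
        # sees 8 trials per step; per-replica outputs are gathered back on the host in
        # _run_evaluation via strategy.experimental_local_results.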

        # Run evaluation
        results = self._run_evaluation(eval_dist_dataset, return_predictions)

        # Calculate summary metrics
        summary_metrics = self._calculate_summary_metrics(results)

        print("Evaluation completed!")
        print(f"Overall PER: {summary_metrics['overall_per']:.4f}")
        print(f"Overall Loss: {summary_metrics['overall_loss']:.4f}")
        print(f"Total trials evaluated: {summary_metrics['total_trials']}")

        # Save results if requested
        if save_results:
            self._save_results(results, summary_metrics)

        return {
            'summary_metrics': summary_metrics,
            'detailed_results': results if return_predictions else None
        }

    def _run_evaluation(self, eval_dataset, return_predictions: bool) -> List[Dict[str, Any]]:
        """Run evaluation on distributed dataset"""
        all_results = []
        batch_idx = 0

        for batch in eval_dataset:
            batch_results = self.strategy.run(self._evaluation_step, args=(batch, return_predictions))

            # Gather per-replica results back to the host
            gathered_results = {}
            for key in batch_results.keys():
                if key in ['logits', 'features'] and not return_predictions:
                    continue  # Skip large tensors if not needed

                values = self.strategy.experimental_local_results(batch_results[key])
                if key in ['loss', 'edit_distance', 'seq_length']:
                    # Scalar metrics - one value per replica
                    gathered_results[key] = [float(v.numpy()) for v in values]
                else:
                    # Tensor data - keep one array per replica
                    gathered_results[key] = [v.numpy() for v in values]

            all_results.append(gathered_results)
            batch_idx += 1

            if batch_idx % 10 == 0:
                print(f"Processed {batch_idx} batches...")

        return all_results

    # NOTE: not wrapped in @tf.function: the per-trial decoding below uses
    # tf.py_function, Python lists and tf.edit_distance, none of which can be
    # compiled into a TPU graph, so this step runs eagerly on each replica
    # (slower, but avoids graph-compilation failures).
    def _evaluation_step(self, batch, return_predictions: bool):
        """Single evaluation step"""
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indices = batch['day_indices']

        # Apply data transformations (no augmentation at evaluation time)
        from dataset_tf import DataAugmentationTF
        features_transformed, n_time_steps_transformed = DataAugmentationTF.transform_data(
            features, n_time_steps, self.config['dataset']['data_transforms'], training=False
        )

        # Adjusted sequence lengths after patching, for CTC
        adjusted_lens = tf.cast(
            (tf.cast(n_time_steps_transformed, tf.float32) - self.config['model']['patch_size']) /
            self.config['model']['patch_stride'] + 1,
            tf.int32
        )
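        # Worked example of the length adjustment (hypothetical values, not taken
        # from this repo's config): with patch_size=14 and patch_stride=4, a trial
        # of 510 neural time steps maps to floor((510 - 14) / 4) + 1 = 125 output
        # frames, and that adjusted length is what the CTC loss and the greedy
        # decoding below treat as the usable input length.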

        # Forward pass in inference mode
        logits = self.trainer.model(
            features_transformed, day_indices, None, False, 'inference', training=False
        )

        # Calculate CTC loss
        loss_input = {
            'labels': labels,
            'input_lengths': adjusted_lens,
            'label_lengths': phone_seq_lens
        }
        loss = self.trainer.ctc_loss(loss_input, logits)
        loss = tf.reduce_mean(loss)

        # Greedy frame-wise predictions for PER
        predicted_ids = tf.argmax(logits, axis=-1)
        batch_size = int(tf.shape(logits)[0])  # plain int; the decode loop runs eagerly

        # Initialize metrics
        total_edit_distance = tf.constant(0.0, dtype=tf.float32)
        total_seq_length = tf.reduce_sum(phone_seq_lens)

        # Decode predictions and accumulate edit distance per trial
        predictions = []
        targets = []

        for i in range(batch_size):
            # Frame-level prediction for this trial, truncated to its valid length
            pred_seq = predicted_ids[i, :adjusted_lens[i]]

            # Merge repeated frames (CTC collapse) via tf.py_function for simplicity
            pred_seq_unique = tf.py_function(
                func=self._remove_consecutive_duplicates,
                inp=[pred_seq],
                Tout=tf.int64
            )

            # Remove blanks (assuming blank_index=0)
            pred_seq_clean = tf.boolean_mask(pred_seq_unique, pred_seq_unique != 0)

            # Reference phoneme sequence for this trial
            true_seq = labels[i, :phone_seq_lens[i]]

            # Levenshtein distance between prediction and reference
            if tf.size(pred_seq_clean) > 0 and tf.size(true_seq) > 0:
                # Represent each sequence as a rank-2 SparseTensor of shape [1, length]
                # so tf.edit_distance returns a single distance for the pair.
                pred_len = tf.size(pred_seq_clean, out_type=tf.int64)
                true_len = tf.size(true_seq, out_type=tf.int64)
                pred_pos = tf.range(pred_len, dtype=tf.int64)
                true_pos = tf.range(true_len, dtype=tf.int64)

                pred_sparse = tf.SparseTensor(
                    indices=tf.stack([tf.zeros_like(pred_pos), pred_pos], axis=1),
                    values=tf.cast(pred_seq_clean, tf.int64),
                    dense_shape=tf.stack([tf.constant(1, tf.int64), pred_len])
                )

                true_sparse = tf.SparseTensor(
                    indices=tf.stack([tf.zeros_like(true_pos), true_pos], axis=1),
                    values=tf.cast(true_seq, tf.int64),
                    dense_shape=tf.stack([tf.constant(1, tf.int64), true_len])
                )

                edit_dist = tf.edit_distance(pred_sparse, true_sparse, normalize=False)
                total_edit_distance += tf.reduce_sum(edit_dist)
            else:
                # If either sequence is empty, the edit distance is the length of the other
                total_edit_distance += tf.cast(
                    tf.maximum(tf.size(pred_seq_clean), tf.size(true_seq)), tf.float32
                )

            if return_predictions:
                predictions.append(pred_seq_clean)
                targets.append(true_seq)
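        # Example with illustrative IDs: frame-wise argmax [0, 7, 7, 0, 12, 12, 12, 0]
        # collapses to [0, 7, 0, 12, 0], the blanks (id 0) drop out to leave [7, 12],
        # and against a reference [7, 13] the edit distance is 1, i.e. one error over
        # two reference phonemes for this trial's contribution to the PER.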

        result = {
            'loss': loss,
            'edit_distance': total_edit_distance,
            'seq_length': total_seq_length,
            'day_indices': day_indices,
            'n_time_steps': n_time_steps,
            'phone_seq_lens': phone_seq_lens
        }

        if return_predictions:
            result.update({
                'logits': logits,
                'predictions': predictions,
                'targets': targets,
                'features': features
            })

        return result

    def _remove_consecutive_duplicates(self, seq):
        """Remove consecutive duplicate elements from a sequence (CTC collapse)"""
        seq_np = seq.numpy()
        if len(seq_np) == 0:
            return tf.constant([], dtype=tf.int64)

        unique_seq = [seq_np[0]]
        for i in range(1, len(seq_np)):
            if seq_np[i] != seq_np[i - 1]:
                unique_seq.append(seq_np[i])

        return tf.constant(unique_seq, dtype=tf.int64)
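    # Illustrative call: a frame sequence [5, 5, 5, 0, 5, 9, 9] collapses to
    # [5, 0, 5, 9]; the caller then strips the blank id 0 to obtain the final
    # phoneme hypothesis [5, 5, 9].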

    def _calculate_summary_metrics(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Calculate summary metrics from evaluation results"""
        total_loss = 0.0
        num_loss_terms = 0
        total_edit_distance = 0.0
        total_seq_length = 0.0
        total_trials = 0
        num_batches = len(results)

        # Day-specific metrics
        day_metrics = {}

        for batch_results in results:
            # Accumulate per-replica mean losses (one value per replica per batch)
            total_loss += sum(batch_results['loss'])
            num_loss_terms += len(batch_results['loss'])

            # Sum edit distances and sequence lengths across replicas
            total_edit_distance += sum(batch_results['edit_distance'])
            total_seq_length += sum(batch_results['seq_length'])

            # Count trials and note which day (session) each one came from
            for day_indices_replica in batch_results['day_indices']:
                total_trials += len(day_indices_replica)

                for day_idx in day_indices_replica:
                    day_idx = int(day_idx)
                    if day_idx not in day_metrics:
                        day_metrics[day_idx] = {'edit_distance': 0, 'seq_length': 0, 'trials': 0}

                    day_metrics[day_idx]['trials'] += 1

        # Averages. Edit distance and sequence length are only gathered per replica,
        # not per trial, so the per-day PER below stays at 0 until per-trial
        # distances are added to _evaluation_step; trial counts per day are exact.
        avg_loss = total_loss / max(num_loss_terms, 1)
        overall_per = total_edit_distance / max(total_seq_length, 1e-6)

        # Per-day breakdown
        day_pers = {}
        for day_idx, metrics in day_metrics.items():
            day_per = metrics['edit_distance'] / max(metrics['seq_length'], 1e-6)
            day_pers[day_idx] = {
                'per': day_per,
                'edit_distance': metrics['edit_distance'],
                'seq_length': metrics['seq_length'],
                'trials': metrics['trials']
            }

        return {
            'overall_per': float(overall_per),
            'overall_loss': float(avg_loss),
            'total_edit_distance': int(total_edit_distance),
            'total_seq_length': int(total_seq_length),
            'total_trials': total_trials,
            'num_batches': num_batches,
            'day_metrics': day_pers
        }
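    # PER is accumulated edit distance over accumulated reference length: with
    # illustrative totals of 1,234 edits against 10,000 reference phonemes the
    # overall PER reported above would be 0.1234.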

    def _save_results(self, detailed_results: List[Dict[str, Any]],
                      summary_metrics: Dict[str, Any]):
        """Save evaluation results to files"""
        output_dir = self.config.get('output_dir', './eval_output')
        os.makedirs(output_dir, exist_ok=True)

        # Save summary metrics
        summary_path = os.path.join(output_dir, f'{self.eval_type}_summary_metrics.json')
        with open(summary_path, 'w') as f:
            json.dump(summary_metrics, f, indent=2)
        print(f"Summary metrics saved to: {summary_path}")

        # Save detailed results
        detailed_path = os.path.join(output_dir, f'{self.eval_type}_detailed_results.pkl')
        with open(detailed_path, 'wb') as f:
            pickle.dump(detailed_results, f)
        print(f"Detailed results saved to: {detailed_path}")

        # Save per-day breakdown
        if 'day_metrics' in summary_metrics:
            day_breakdown_path = os.path.join(output_dir, f'{self.eval_type}_day_breakdown.json')
            with open(day_breakdown_path, 'w') as f:
                json.dump(summary_metrics['day_metrics'], f, indent=2)
            print(f"Per-day breakdown saved to: {day_breakdown_path}")


def main():
    """Main evaluation function"""
    parser = argparse.ArgumentParser(
        description='Evaluate Brain-to-Text RNN Model with TensorFlow on TPU v5e-8',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        '--model_path',
        required=True,
        help='Path to trained model checkpoint (without extension)'
    )

    parser.add_argument(
        '--config_path',
        default='rnn_args.yaml',
        help='Path to model configuration file'
    )

    parser.add_argument(
        '--data_dir',
        default=None,
        help='Override data directory from config'
    )

    parser.add_argument(
        '--eval_type',
        choices=['test', 'val'],
        default='test',
        help='Type of evaluation to run'
    )

    parser.add_argument(
        '--output_dir',
        default='./eval_output',
        help='Directory to save evaluation results'
    )

    parser.add_argument(
        '--save_predictions',
        action='store_true',
        help='Save individual predictions and targets'
    )

    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='Override batch size for evaluation'
    )

    parser.add_argument(
        '--sessions',
        nargs='+',
        default=None,
        help='Specific sessions to evaluate (overrides config)'
    )

    args = parser.parse_args()

    # Setup TPU environment
    os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '2')

    # Load configuration
    if not os.path.exists(args.config_path):
        raise FileNotFoundError(f"Configuration file not found: {args.config_path}")

    config = OmegaConf.load(args.config_path)

    # Apply command-line overrides
    if args.data_dir:
        config.dataset.dataset_dir = args.data_dir
    if args.batch_size:
        config.dataset.batch_size = args.batch_size
    if args.sessions:
        config.dataset.sessions = args.sessions
    if args.output_dir:
        config.output_dir = args.output_dir

    # Validate that the model checkpoint exists; --model_path is a prefix, and a
    # Keras weights file '<model_path>.weights.h5' is expected alongside it
    if not os.path.exists(args.model_path + '.weights.h5'):
        raise FileNotFoundError(f"Model checkpoint not found: {args.model_path}.weights.h5")

    try:
        # Initialize evaluator
        evaluator = BrainToTextEvaluatorTF(
            model_path=args.model_path,
            config=config,
            eval_type=args.eval_type
        )

        # Run evaluation
        results = evaluator.evaluate_dataset(
            save_results=True,
            return_predictions=args.save_predictions
        )

        # Print results
        metrics = results['summary_metrics']
        print("\n" + "=" * 60)
        print("EVALUATION RESULTS")
        print("=" * 60)
        print(f"Overall PER: {metrics['overall_per']:.6f}")
        print(f"Overall Loss: {metrics['overall_loss']:.6f}")
        print(f"Total Edit Distance: {metrics['total_edit_distance']}")
        print(f"Total Sequence Length: {metrics['total_seq_length']}")
        print(f"Total Trials: {metrics['total_trials']}")
        print(f"Batches Processed: {metrics['num_batches']}")

        # Print per-day results if available
        if 'day_metrics' in metrics and metrics['day_metrics']:
            print("\nPER-DAY RESULTS:")
            print("-" * 40)
            for day_idx, day_metrics in metrics['day_metrics'].items():
                session_name = config.dataset.sessions[day_idx] if day_idx < len(config.dataset.sessions) else f"Day_{day_idx}"
                print(f"{session_name}: PER={day_metrics['per']:.6f}, Trials={day_metrics['trials']}")

        print("\nEvaluation completed successfully!")

    except Exception as e:
        print(f"Evaluation failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()