#!/usr/bin/env python3
"""
TensorFlow Evaluation Script for Brain-to-Text RNN Model
Optimized for TPU v5e-8
This script evaluates the TripleGRUDecoder model using TensorFlow and provides
detailed metrics and analysis of model performance on test data.
Usage:
python evaluate_model_tf.py --model_path path/to/model --data_dir path/to/data
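Optional flags (paths are placeholders; see --help for the full list):
python evaluate_model_tf.py --model_path path/to/model --eval_type val --save_predictions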
Requirements:
- TensorFlow >= 2.15.0
- TPU v5e-8 environment
- Trained model checkpoint
- Access to brain-to-text HDF5 dataset
"""
import argparse
import os
import sys
import json
import pickle
import numpy as np
import tensorflow as tf
from typing import Dict, Any, List, Tuple
from omegaconf import OmegaConf
# Add the current directory to Python path for imports
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from trainer_tf import BrainToTextDecoderTrainerTF
from dataset_tf import BrainToTextDatasetTF, train_test_split_indices, create_input_fn
from rnn_model_tf import create_tpu_strategy, configure_mixed_precision
class BrainToTextEvaluatorTF:
"""
TensorFlow evaluator for brain-to-text model performance analysis
"""
def __init__(self, model_path: str, config: Dict[str, Any], eval_type: str = 'test'):
"""
Initialize evaluator
Args:
model_path: Path to trained model checkpoint
config: Configuration dictionary
eval_type: 'test' or 'val' evaluation type
"""
self.model_path = model_path
self.config = config
self.eval_type = eval_type
# Initialize TPU strategy
self.strategy = create_tpu_strategy()
print(f"Evaluation using {self.strategy.num_replicas_in_sync} TPU cores")
# Configure mixed precision
if config.get('use_amp', True):
configure_mixed_precision()
# Load model
with self.strategy.scope():
self.trainer = BrainToTextDecoderTrainerTF(config)
self.trainer.load_checkpoint(model_path)
print(f"Model loaded from: {model_path}")
def evaluate_dataset(self, save_results: bool = True,
return_predictions: bool = False) -> Dict[str, Any]:
"""
Evaluate model on specified dataset
Args:
save_results: Whether to save detailed results to file
return_predictions: Whether to return individual predictions
Returns:
Dictionary containing evaluation metrics and optionally predictions
"""
print(f"Starting {self.eval_type} evaluation...")
# Create evaluation dataset
# Only a validation split is wired up in the trainer, so both 'test' and 'val'
# evaluation types currently use the validation data.
dataset_tf = self.trainer.val_dataset_tf
eval_dataset = create_input_fn(
dataset_tf,
self.config['dataset']['data_transforms'],
training=False
)
# Distribute dataset
eval_dist_dataset = self.strategy.experimental_distribute_dataset(eval_dataset)
# Run evaluation
results = self._run_evaluation(eval_dist_dataset, return_predictions)
# Calculate summary metrics
summary_metrics = self._calculate_summary_metrics(results)
print(f"Evaluation completed!")
print(f"Overall PER: {summary_metrics['overall_per']:.4f}")
print(f"Overall Loss: {summary_metrics['overall_loss']:.4f}")
print(f"Total trials evaluated: {summary_metrics['total_trials']}")
# Save results if requested
if save_results:
self._save_results(results, summary_metrics)
return {
'summary_metrics': summary_metrics,
'detailed_results': results if return_predictions else None
}
def _run_evaluation(self, eval_dataset, return_predictions: bool) -> List[Dict[str, Any]]:
"""Run evaluation on distributed dataset"""
all_results = []
batch_idx = 0
for batch in eval_dataset:
batch_results = self.strategy.run(self._evaluation_step, args=(batch, return_predictions))
# Gather results from all replicas
gathered_results = {}
for key in batch_results.keys():
if key in ['logits', 'features'] and not return_predictions:
continue # Skip large tensors if not needed
values = self.strategy.experimental_local_results(batch_results[key])
if key in ['loss', 'edit_distance', 'seq_length']:
# Scalar metrics - just take the values
gathered_results[key] = [float(v.numpy()) for v in values]
else:
# Tensor data - concatenate across replicas
gathered_results[key] = [v.numpy() for v in values]
all_results.append(gathered_results)
batch_idx += 1
if batch_idx % 10 == 0:
print(f"Processed {batch_idx} batches...")
return all_results
@tf.function
def _evaluation_step(self, batch, return_predictions: bool):
"""Single evaluation step"""
features = batch['input_features']
labels = batch['seq_class_ids']
n_time_steps = batch['n_time_steps']
phone_seq_lens = batch['phone_seq_lens']
day_indices = batch['day_indices']
# Apply data transformations (no augmentation)
from dataset_tf import DataAugmentationTF
features_transformed, n_time_steps_transformed = DataAugmentationTF.transform_data(
features, n_time_steps, self.config['dataset']['data_transforms'], training=False
)
# Calculate adjusted lengths for CTC
adjusted_lens = tf.cast(
(tf.cast(n_time_steps_transformed, tf.float32) - self.config['model']['patch_size']) /
self.config['model']['patch_stride'] + 1,
tf.int32
)
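# Example (hypothetical values): patch_size=14, patch_stride=4 and 510 transformed
# time steps give (510 - 14) / 4 + 1 = 125 frames fed to the CTC loss.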
# Forward pass
logits = self.trainer.model(
features_transformed, day_indices, None, False, 'inference', training=False
)
# Calculate loss
loss_input = {
'labels': labels,
'input_lengths': adjusted_lens,
'label_lengths': phone_seq_lens
}
loss = self.trainer.ctc_loss(loss_input, logits)
loss = tf.reduce_mean(loss)
# Calculate edit distance for PER
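# Greedy (best-path) decoding: take the argmax phoneme per frame, collapse
# consecutive repeats, then drop CTC blanks; PER is total edit distance divided
# by the total reference length.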
predicted_ids = tf.argmax(logits, axis=-1)
batch_size = tf.shape(logits)[0]
# Initialize metrics
total_edit_distance = tf.constant(0.0, dtype=tf.float32)  # float32 so the accumulator dtype stays stable across graph-loop iterations
total_seq_length = tf.reduce_sum(phone_seq_lens)
# Decode predictions and calculate edit distance
predictions = []
targets = []
for i in tf.range(batch_size):
# Get prediction for this sample
pred_seq = predicted_ids[i, :adjusted_lens[i]]
# Remove consecutive duplicates (CTC collapse) using tf.py_function for simplicity.
# Note: tf.py_function runs on the host and is not supported in TPU-compiled graphs,
# so this per-sample decode may need to move out of the replica step on TPU.
pred_seq_unique = tf.py_function(
func=self._remove_consecutive_duplicates,
inp=[pred_seq],
Tout=tf.int64
)
# py_function output has unknown shape; declare rank 1 so boolean_mask can build a mask
pred_seq_unique.set_shape([None])
# Remove blanks (assuming blank_index=0)
pred_seq_clean = tf.boolean_mask(pred_seq_unique, pred_seq_unique != 0)
# Get true sequence
true_seq = labels[i, :phone_seq_lens[i]]
# Calculate edit distance for this pair
if tf.size(pred_seq_clean) > 0 and tf.size(true_seq) > 0:
pred_sparse = tf.SparseTensor(
indices=tf.expand_dims(tf.range(tf.size(pred_seq_clean), dtype=tf.int64), 1),
values=tf.cast(pred_seq_clean, tf.int64),
dense_shape=[tf.size(pred_seq_clean, out_type=tf.int64)]
)
true_sparse = tf.SparseTensor(
indices=tf.expand_dims(tf.range(tf.size(true_seq), dtype=tf.int64), 1),
values=tf.cast(true_seq, tf.int64),
dense_shape=[tf.size(true_seq, out_type=tf.int64)]
)
edit_dist = tf.edit_distance(pred_sparse, true_sparse, normalize=False)
total_edit_distance += edit_dist
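# Note: appending variable-length tensors to a Python list inside a tensor-dependent
# loop is not supported in graph mode; returning predictions may require decoding on
# the host from the logits instead.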
if return_predictions:
predictions.append(pred_seq_clean)
targets.append(true_seq)
result = {
'loss': loss,
'edit_distance': total_edit_distance,
'seq_length': total_seq_length,
'day_indices': day_indices,
'n_time_steps': n_time_steps,
'phone_seq_lens': phone_seq_lens
}
if return_predictions:
result.update({
'logits': logits,
'predictions': predictions,
'targets': targets,
'features': features
})
return result
def _remove_consecutive_duplicates(self, seq):
"""Remove consecutive duplicate elements from sequence"""
seq_np = seq.numpy()
if len(seq_np) == 0:
return tf.constant([], dtype=tf.int64)
unique_seq = [seq_np[0]]
for i in range(1, len(seq_np)):
if seq_np[i] != seq_np[i-1]:
unique_seq.append(seq_np[i])
return tf.constant(unique_seq, dtype=tf.int64)
def _calculate_summary_metrics(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Calculate summary metrics from evaluation results"""
total_loss = 0.0
total_edit_distance = 0
total_seq_length = 0
total_trials = 0
num_batches = len(results)
# Day-specific metrics
day_metrics = {}
for batch_results in results:
# Average the per-replica losses so the result does not scale with replica count
batch_loss = sum(batch_results['loss']) / max(len(batch_results['loss']), 1)
total_loss += batch_loss
# Sum edit distances and sequence lengths
batch_edit_dist = sum(batch_results['edit_distance'])
batch_seq_len = sum(batch_results['seq_length'])
total_edit_distance += batch_edit_dist
total_seq_length += batch_seq_len
# Count trials
for day_indices_replica in batch_results['day_indices']:
total_trials += len(day_indices_replica)
# Track per-day trial counts. Edit distances are only gathered as per-replica
# totals above, so the per-day 'edit_distance' and 'seq_length' fields remain 0
# unless per-trial metrics are added to the evaluation step.
for day_idx in day_indices_replica:
day_idx = int(day_idx)
if day_idx not in day_metrics:
day_metrics[day_idx] = {'edit_distance': 0, 'seq_length': 0, 'trials': 0}
day_metrics[day_idx]['trials'] += 1
# Calculate averages
avg_loss = total_loss / max(num_batches, 1)
overall_per = total_edit_distance / max(total_seq_length, 1e-6)
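# Overall PER is aggregated over the whole split: total edit distance divided by
# the total number of reference phonemes, not a mean of per-trial error rates.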
# Calculate per-day PERs
day_pers = {}
for day_idx, metrics in day_metrics.items():
day_per = metrics['edit_distance'] / max(metrics['seq_length'], 1e-6)
day_pers[day_idx] = {
'per': day_per,
'edit_distance': metrics['edit_distance'],
'seq_length': metrics['seq_length'],
'trials': metrics['trials']
}
return {
'overall_per': float(overall_per),
'overall_loss': float(avg_loss),
'total_edit_distance': int(total_edit_distance),
'total_seq_length': int(total_seq_length),
'total_trials': total_trials,
'num_batches': num_batches,
'day_metrics': day_pers
}
def _save_results(self, detailed_results: List[Dict[str, Any]],
summary_metrics: Dict[str, Any]):
"""Save evaluation results to files"""
output_dir = self.config.get('output_dir', './eval_output')
os.makedirs(output_dir, exist_ok=True)
# Save summary metrics
summary_path = os.path.join(output_dir, f'{self.eval_type}_summary_metrics.json')
with open(summary_path, 'w') as f:
json.dump(summary_metrics, f, indent=2)
print(f"Summary metrics saved to: {summary_path}")
# Save detailed results
detailed_path = os.path.join(output_dir, f'{self.eval_type}_detailed_results.pkl')
with open(detailed_path, 'wb') as f:
pickle.dump(detailed_results, f)
print(f"Detailed results saved to: {detailed_path}")
# Save per-day breakdown
if 'day_metrics' in summary_metrics:
day_breakdown_path = os.path.join(output_dir, f'{self.eval_type}_day_breakdown.json')
with open(day_breakdown_path, 'w') as f:
json.dump(summary_metrics['day_metrics'], f, indent=2)
print(f"Per-day breakdown saved to: {day_breakdown_path}")
def main():
"""Main evaluation function"""
parser = argparse.ArgumentParser(
description='Evaluate Brain-to-Text RNN Model with TensorFlow on TPU v5e-8',
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
'--model_path',
required=True,
help='Path to trained model checkpoint (without extension)'
)
parser.add_argument(
'--config_path',
default='rnn_args.yaml',
help='Path to model configuration file'
)
parser.add_argument(
'--data_dir',
default=None,
help='Override data directory from config'
)
parser.add_argument(
'--eval_type',
choices=['test', 'val'],
default='test',
help='Type of evaluation to run'
)
parser.add_argument(
'--output_dir',
default='./eval_output',
help='Directory to save evaluation results'
)
parser.add_argument(
'--save_predictions',
action='store_true',
help='Save individual predictions and targets'
)
parser.add_argument(
'--batch_size',
type=int,
default=None,
help='Override batch size for evaluation'
)
parser.add_argument(
'--sessions',
nargs='+',
default=None,
help='Specific sessions to evaluate (overrides config)'
)
args = parser.parse_args()
# Setup TPU environment
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '2')
# Load configuration
if not os.path.exists(args.config_path):
raise FileNotFoundError(f"Configuration file not found: {args.config_path}")
config = OmegaConf.load(args.config_path)
# Apply overrides
if args.data_dir:
config.dataset.dataset_dir = args.data_dir
if args.batch_size:
config.dataset.batch_size = args.batch_size
if args.sessions:
config.dataset.sessions = args.sessions
if args.output_dir:
config.output_dir = args.output_dir
# Validate model checkpoint exists
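# The checkpoint is expected as Keras weights at <model_path>.weights.h5, so
# --model_path should be passed without the extension.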
if not os.path.exists(args.model_path + '.weights.h5'):
raise FileNotFoundError(f"Model checkpoint not found: {args.model_path}")
try:
# Initialize evaluator
evaluator = BrainToTextEvaluatorTF(
model_path=args.model_path,
config=config,
eval_type=args.eval_type
)
# Run evaluation
results = evaluator.evaluate_dataset(
save_results=True,
return_predictions=args.save_predictions
)
# Print results
metrics = results['summary_metrics']
print("\n" + "="*60)
print("EVALUATION RESULTS")
print("="*60)
print(f"Overall PER: {metrics['overall_per']:.6f}")
print(f"Overall Loss: {metrics['overall_loss']:.6f}")
print(f"Total Edit Distance: {metrics['total_edit_distance']}")
print(f"Total Sequence Length: {metrics['total_seq_length']}")
print(f"Total Trials: {metrics['total_trials']}")
print(f"Batches Processed: {metrics['num_batches']}")
# Print per-day results if available
if 'day_metrics' in metrics and metrics['day_metrics']:
print("\nPER-DAY RESULTS:")
print("-" * 40)
for day_idx, day_stats in metrics['day_metrics'].items():
session_name = config.dataset.sessions[day_idx] if day_idx < len(config.dataset.sessions) else f"Day_{day_idx}"
print(f"{session_name}: PER={day_stats['per']:.6f}, Trials={day_stats['trials']}")
print("\nEvaluation completed successfully!")
except Exception as e:
print(f"Evaluation failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()