TPU

model_training_nnn_tpu/evaluate_model_tf.py (new file, 480 lines)
@@ -0,0 +1,480 @@
#!/usr/bin/env python3
"""
TensorFlow Evaluation Script for Brain-to-Text RNN Model
Optimized for TPU v5e-8

This script evaluates the TripleGRUDecoder model using TensorFlow and provides
detailed metrics and analysis of model performance on test data.

Usage:
    python evaluate_model_tf.py --model_path path/to/model --data_dir path/to/data

Requirements:
- TensorFlow >= 2.15.0
- TPU v5e-8 environment
- Trained model checkpoint
- Access to brain-to-text HDF5 dataset
"""
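
# Example invocation using the flags defined in main() below (paths are placeholders):
#   python evaluate_model_tf.py \
#       --model_path trained_models/best_checkpoint \
#       --config_path rnn_args.yaml \
#       --eval_type val \
#       --save_predictions \
#       --output_dir ./eval_output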

import argparse
import os
import sys
import json
import pickle
import numpy as np
import tensorflow as tf
from typing import Dict, Any, List, Tuple
from omegaconf import OmegaConf

# Add the current directory to Python path for imports
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from trainer_tf import BrainToTextDecoderTrainerTF
from dataset_tf import BrainToTextDatasetTF, train_test_split_indices, create_input_fn
from rnn_model_tf import create_tpu_strategy, configure_mixed_precision


class BrainToTextEvaluatorTF:
    """
    TensorFlow evaluator for brain-to-text model performance analysis
    """

    def __init__(self, model_path: str, config: Dict[str, Any], eval_type: str = 'test'):
        """
        Initialize evaluator

        Args:
            model_path: Path to trained model checkpoint
            config: Configuration dictionary
            eval_type: 'test' or 'val' evaluation type
        """
        self.model_path = model_path
        self.config = config
        self.eval_type = eval_type

        # Initialize TPU strategy
        self.strategy = create_tpu_strategy()
        print(f"Evaluation using {self.strategy.num_replicas_in_sync} TPU cores")

        # Configure mixed precision
        if config.get('use_amp', True):
            configure_mixed_precision()

        # Load model
        with self.strategy.scope():
            self.trainer = BrainToTextDecoderTrainerTF(config)
            self.trainer.load_checkpoint(model_path)

        print(f"Model loaded from: {model_path}")

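    # For context: create_tpu_strategy() and configure_mixed_precision() come from
    # rnn_model_tf and are not shown here. They are assumed to wrap the standard TF TPU
    # bring-up and Keras mixed-precision policy, roughly:
    #     resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    #     tf.config.experimental_connect_to_cluster(resolver)
    #     tf.tpu.experimental.initialize_tpu_system(resolver)
    #     strategy = tf.distribute.TPUStrategy(resolver)
    #     tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
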
    def evaluate_dataset(self, save_results: bool = True,
                         return_predictions: bool = False) -> Dict[str, Any]:
        """
        Evaluate model on specified dataset

        Args:
            save_results: Whether to save detailed results to file
            return_predictions: Whether to return individual predictions

        Returns:
            Dictionary containing evaluation metrics and optionally predictions
        """
        print(f"Starting {self.eval_type} evaluation...")

        # Create evaluation dataset
        if self.eval_type == 'test':
            dataset_tf = self.trainer.val_dataset_tf  # Using validation data as test
        else:
            dataset_tf = self.trainer.val_dataset_tf

        eval_dataset = create_input_fn(
            dataset_tf,
            self.config['dataset']['data_transforms'],
            training=False
        )

        # Distribute dataset
        eval_dist_dataset = self.strategy.experimental_distribute_dataset(eval_dataset)

        # Run evaluation
        results = self._run_evaluation(eval_dist_dataset, return_predictions)

        # Calculate summary metrics
        summary_metrics = self._calculate_summary_metrics(results)

        print("Evaluation completed!")
        print(f"Overall PER: {summary_metrics['overall_per']:.4f}")
        print(f"Overall Loss: {summary_metrics['overall_loss']:.4f}")
        print(f"Total trials evaluated: {summary_metrics['total_trials']}")

        # Save results if requested
        if save_results:
            self._save_results(results, summary_metrics)

        return {
            'summary_metrics': summary_metrics,
            'detailed_results': results if return_predictions else None
        }

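    # Typical use from another script (a sketch with hypothetical paths, not a tested recipe):
    #     evaluator = BrainToTextEvaluatorTF('trained_models/best_checkpoint', config, eval_type='val')
    #     out = evaluator.evaluate_dataset(save_results=True, return_predictions=False)
    #     print(out['summary_metrics']['overall_per'])
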
    def _run_evaluation(self, eval_dataset, return_predictions: bool) -> List[Dict[str, Any]]:
        """Run evaluation on distributed dataset"""
        all_results = []
        batch_idx = 0

        for batch in eval_dataset:
            batch_results = self.strategy.run(self._evaluation_step, args=(batch, return_predictions))

            # Gather results from all replicas
            gathered_results = {}
            for key in batch_results.keys():
                if key in ['logits', 'features'] and not return_predictions:
                    continue  # Skip large tensors if not needed

                values = self.strategy.experimental_local_results(batch_results[key])
                if key in ['loss', 'edit_distance', 'seq_length']:
                    # Scalar metrics - just take the values
                    gathered_results[key] = [float(v.numpy()) for v in values]
                else:
                    # Tensor data - concatenate across replicas
                    gathered_results[key] = [v.numpy() for v in values]

            all_results.append(gathered_results)
            batch_idx += 1

            if batch_idx % 10 == 0:
                print(f"Processed {batch_idx} batches...")

        return all_results

    @tf.function
    def _evaluation_step(self, batch, return_predictions: bool):
        """Single evaluation step"""
        # NOTE: the per-sample decoding below iterates over a tensor batch size, uses Python
        # truth-value checks on tensors, and appends to Python lists, none of which trace
        # cleanly under @tf.function/XLA. It works when functions are run eagerly
        # (tf.config.run_functions_eagerly(True)); a production TPU path would return logits
        # and adjusted_lens from the compiled step and do the decode/edit distance on the host.
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indices = batch['day_indices']

        # Apply data transformations (no augmentation)
        from dataset_tf import DataAugmentationTF
        features_transformed, n_time_steps_transformed = DataAugmentationTF.transform_data(
            features, n_time_steps, self.config['dataset']['data_transforms'], training=False
        )

        # Calculate adjusted lengths for CTC
        adjusted_lens = tf.cast(
            (tf.cast(n_time_steps_transformed, tf.float32) - self.config['model']['patch_size']) /
            self.config['model']['patch_stride'] + 1,
            tf.int32
        )
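        # Worked example (hypothetical config values): with patch_size=14 and patch_stride=4,
        # a trial with 512 transformed time steps yields int((512 - 14) / 4 + 1) = 125 CTC frames.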

        # Forward pass
        logits = self.trainer.model(
            features_transformed, day_indices, None, False, 'inference', training=False
        )

        # Calculate loss
        loss_input = {
            'labels': labels,
            'input_lengths': adjusted_lens,
            'label_lengths': phone_seq_lens
        }
        loss = self.trainer.ctc_loss(loss_input, logits)
        loss = tf.reduce_mean(loss)

        # Calculate edit distance for PER
        predicted_ids = tf.argmax(logits, axis=-1)
        batch_size = tf.shape(logits)[0]

        # Initialize metrics
        total_edit_distance = 0
        total_seq_length = tf.reduce_sum(phone_seq_lens)

        # Decode predictions and calculate edit distance
        predictions = []
        targets = []

        for i in range(batch_size):
            # Get prediction for this sample
            pred_seq = predicted_ids[i, :adjusted_lens[i]]

            # Remove consecutive duplicates using tf.py_function for simplicity
            pred_seq_unique = tf.py_function(
                func=self._remove_consecutive_duplicates,
                inp=[pred_seq],
                Tout=tf.int64
            )

            # Remove blanks (assuming blank_index=0)
            pred_seq_clean = tf.boolean_mask(pred_seq_unique, pred_seq_unique != 0)

            # Get true sequence
            true_seq = labels[i, :phone_seq_lens[i]]

            # Calculate edit distance for this pair
            if tf.size(pred_seq_clean) > 0 and tf.size(true_seq) > 0:
                pred_sparse = tf.SparseTensor(
                    indices=tf.expand_dims(tf.range(tf.size(pred_seq_clean), dtype=tf.int64), 1),
                    values=tf.cast(pred_seq_clean, tf.int64),
                    dense_shape=[tf.size(pred_seq_clean, out_type=tf.int64)]
                )

                true_sparse = tf.SparseTensor(
                    indices=tf.expand_dims(tf.range(tf.size(true_seq), dtype=tf.int64), 1),
                    values=tf.cast(true_seq, tf.int64),
                    dense_shape=[tf.size(true_seq, out_type=tf.int64)]
                )

                edit_dist = tf.edit_distance(pred_sparse, true_sparse, normalize=False)
                total_edit_distance += edit_dist

            if return_predictions:
                predictions.append(pred_seq_clean)
                targets.append(true_seq)
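
        # A batched alternative to the loop above (a sketch, assuming blank index 0 and a TF
        # version where tf.nn.ctc_greedy_decoder accepts blank_index):
        #     decoded, _ = tf.nn.ctc_greedy_decoder(
        #         tf.transpose(logits, [1, 0, 2]),  # decoder expects time-major [T, B, C]
        #         sequence_length=adjusted_lens,
        #         blank_index=0,
        #     )
        #     # decoded[0] is a SparseTensor of collapsed, blank-free label ids per sample.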

        result = {
            'loss': loss,
            'edit_distance': total_edit_distance,
            'seq_length': total_seq_length,
            'day_indices': day_indices,
            'n_time_steps': n_time_steps,
            'phone_seq_lens': phone_seq_lens
        }

        if return_predictions:
            result.update({
                'logits': logits,
                'predictions': predictions,
                'targets': targets,
                'features': features
            })

        return result

    def _remove_consecutive_duplicates(self, seq):
        """Remove consecutive duplicate elements from sequence"""
        seq_np = seq.numpy()
        if len(seq_np) == 0:
            return tf.constant([], dtype=tf.int64)

        unique_seq = [seq_np[0]]
        for i in range(1, len(seq_np)):
            if seq_np[i] != seq_np[i-1]:
                unique_seq.append(seq_np[i])

        return tf.constant(unique_seq, dtype=tf.int64)
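
    # Example: a greedy CTC path [0, 7, 7, 0, 0, 12, 12, 12] collapses here to [0, 7, 0, 12];
    # the blanks (index 0) are then dropped in _evaluation_step, leaving [7, 12].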

    def _calculate_summary_metrics(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Calculate summary metrics from evaluation results"""
        total_loss = 0.0
        total_edit_distance = 0
        total_seq_length = 0
        total_trials = 0
        num_batches = len(results)

        # Day-specific metrics
        day_metrics = {}

        for batch_results in results:
            # Sum losses across replicas
            batch_loss = sum(batch_results['loss'])
            total_loss += batch_loss

            # Sum edit distances and sequence lengths
            batch_edit_dist = sum(batch_results['edit_distance'])
            batch_seq_len = sum(batch_results['seq_length'])

            total_edit_distance += batch_edit_dist
            total_seq_length += batch_seq_len

            # Count trials
            for day_indices_replica in batch_results['day_indices']:
                total_trials += len(day_indices_replica)

                # Track per-day metrics
                # NOTE: only trial counts are accumulated per day; the evaluation step does not
                # return per-sample edit distances, so the per-day 'edit_distance'/'seq_length'
                # fields stay at 0 and the per-day PER below is not yet meaningful.
                for i, day_idx in enumerate(day_indices_replica):
                    day_idx = int(day_idx)
                    if day_idx not in day_metrics:
                        day_metrics[day_idx] = {'edit_distance': 0, 'seq_length': 0, 'trials': 0}

                    day_metrics[day_idx]['trials'] += 1

        # Calculate averages
        avg_loss = total_loss / max(num_batches, 1)
        overall_per = total_edit_distance / max(total_seq_length, 1e-6)

        # Calculate per-day PERs
        day_pers = {}
        for day_idx, metrics in day_metrics.items():
            day_per = metrics['edit_distance'] / max(metrics['seq_length'], 1e-6)
            day_pers[day_idx] = {
                'per': day_per,
                'edit_distance': metrics['edit_distance'],
                'seq_length': metrics['seq_length'],
                'trials': metrics['trials']
            }

        return {
            'overall_per': float(overall_per),
            'overall_loss': float(avg_loss),
            'total_edit_distance': int(total_edit_distance),
            'total_seq_length': int(total_seq_length),
            'total_trials': total_trials,
            'num_batches': num_batches,
            'day_metrics': day_pers
        }
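
    # PER here is corpus-level: PER = (sum of edit distances) / (sum of reference phone lengths).
    # For example (hypothetical numbers), 150 total edits over 1000 reference phones gives PER = 0.15.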

    def _save_results(self, detailed_results: List[Dict[str, Any]],
                      summary_metrics: Dict[str, Any]):
        """Save evaluation results to files"""
        output_dir = self.config.get('output_dir', './eval_output')
        os.makedirs(output_dir, exist_ok=True)

        # Save summary metrics
        summary_path = os.path.join(output_dir, f'{self.eval_type}_summary_metrics.json')
        with open(summary_path, 'w') as f:
            json.dump(summary_metrics, f, indent=2)
        print(f"Summary metrics saved to: {summary_path}")

        # Save detailed results
        detailed_path = os.path.join(output_dir, f'{self.eval_type}_detailed_results.pkl')
        with open(detailed_path, 'wb') as f:
            pickle.dump(detailed_results, f)
        print(f"Detailed results saved to: {detailed_path}")

        # Save per-day breakdown
        if 'day_metrics' in summary_metrics:
            day_breakdown_path = os.path.join(output_dir, f'{self.eval_type}_day_breakdown.json')
            with open(day_breakdown_path, 'w') as f:
                json.dump(summary_metrics['day_metrics'], f, indent=2)
            print(f"Per-day breakdown saved to: {day_breakdown_path}")


def main():
    """Main evaluation function"""
    parser = argparse.ArgumentParser(
        description='Evaluate Brain-to-Text RNN Model with TensorFlow on TPU v5e-8',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        '--model_path',
        required=True,
        help='Path to trained model checkpoint (without extension)'
    )

    parser.add_argument(
        '--config_path',
        default='rnn_args.yaml',
        help='Path to model configuration file'
    )

    parser.add_argument(
        '--data_dir',
        default=None,
        help='Override data directory from config'
    )

    parser.add_argument(
        '--eval_type',
        choices=['test', 'val'],
        default='test',
        help='Type of evaluation to run'
    )

    parser.add_argument(
        '--output_dir',
        default='./eval_output',
        help='Directory to save evaluation results'
    )

    parser.add_argument(
        '--save_predictions',
        action='store_true',
        help='Save individual predictions and targets'
    )

    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='Override batch size for evaluation'
    )

    parser.add_argument(
        '--sessions',
        nargs='+',
        default=None,
        help='Specific sessions to evaluate (overrides config)'
    )

    args = parser.parse_args()

    # Setup TPU environment
    os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '2')

    # Load configuration
    if not os.path.exists(args.config_path):
        raise FileNotFoundError(f"Configuration file not found: {args.config_path}")

    config = OmegaConf.load(args.config_path)

    # Apply overrides
    if args.data_dir:
        config.dataset.dataset_dir = args.data_dir
    if args.batch_size:
        config.dataset.batch_size = args.batch_size
    if args.sessions:
        config.dataset.sessions = args.sessions
    if args.output_dir:
        config.output_dir = args.output_dir
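
    # The YAML config is assumed to provide at least the keys read elsewhere in this script
    # (values below are illustrative only, not taken from the repository):
    #     dataset:
    #       dataset_dir: /path/to/hdf5_data
    #       sessions: [session_01, session_02]
    #       batch_size: 64
    #       data_transforms: {...}
    #     model:
    #       patch_size: 14
    #       patch_stride: 4
    #     use_amp: true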

    # Validate model checkpoint exists
    if not os.path.exists(args.model_path + '.weights.h5'):
        raise FileNotFoundError(f"Model checkpoint not found: {args.model_path}")

    try:
        # Initialize evaluator
        evaluator = BrainToTextEvaluatorTF(
            model_path=args.model_path,
            config=config,
            eval_type=args.eval_type
        )

        # Run evaluation
        results = evaluator.evaluate_dataset(
            save_results=True,
            return_predictions=args.save_predictions
        )

        # Print results
        metrics = results['summary_metrics']
        print("\n" + "="*60)
        print("EVALUATION RESULTS")
        print("="*60)
        print(f"Overall PER: {metrics['overall_per']:.6f}")
        print(f"Overall Loss: {metrics['overall_loss']:.6f}")
        print(f"Total Edit Distance: {metrics['total_edit_distance']}")
        print(f"Total Sequence Length: {metrics['total_seq_length']}")
        print(f"Total Trials: {metrics['total_trials']}")
        print(f"Batches Processed: {metrics['num_batches']}")

        # Print per-day results if available
        if 'day_metrics' in metrics and metrics['day_metrics']:
            print("\nPER-DAY RESULTS:")
            print("-" * 40)
            for day_idx, day_metrics in metrics['day_metrics'].items():
                session_name = config.dataset.sessions[day_idx] if day_idx < len(config.dataset.sessions) else f"Day_{day_idx}"
                print(f"{session_name}: PER={day_metrics['per']:.6f}, Trials={day_metrics['trials']}")

        print("\nEvaluation completed successfully!")

    except Exception as e:
        print(f"Evaluation failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
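
# The script writes <eval_type>_summary_metrics.json, <eval_type>_detailed_results.pkl and
# <eval_type>_day_breakdown.json into --output_dir. A minimal way to inspect them afterwards
# (paths assume the default --output_dir and --eval_type):
#     import json, pickle
#     with open('eval_output/test_summary_metrics.json') as f:
#         print(json.load(f)['overall_per'])
#     with open('eval_output/test_detailed_results.pkl', 'rb') as f:
#         batches = pickle.load(f)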