Files
b2txt25/data_analyse/create_phoneme_classification_dataset.py
2025-10-12 09:11:32 +08:00

446 lines
18 KiB
Python

#!/usr/bin/env python3
"""
Convert phoneme-segmented data into a phoneme classification dataset.
"""
import pickle
import numpy as np
import torch
from pathlib import Path
from collections import defaultdict
import os
import sys
# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))
def load_neural_data_for_trial(session, trial_metadata):
    """Load the neural feature array for one trial of a session (best effort).

    Args:
        session: Name of the session sub-directory under
            ``<parent of this file's directory>/data/hdf5_data_final``.
        trial_metadata: Mapping that provides the trial position under the
            ``'trial_idx'`` key.

    Returns:
        ``data['neural_features'][trial_idx]`` from the session's
        ``data_train.hdf5`` file, or ``None`` when the file is missing, the
        index is absent or out of range, or anything fails while loading
        (a warning is printed in the failure case).
    """
    try:
        # Project/third-party imports stay inside the ``try`` so that a missing
        # dependency degrades to "no data" instead of aborting the whole run.
        from model_training.evaluate_model_helpers import load_h5py_file
        import pandas as pd

        data_dir = Path(__file__).parent.parent / "data" / "hdf5_data_final"
        train_file = data_dir / session / "data_train.hdf5"
        if not train_file.exists():
            return None

        # The dataset-info CSV is optional; the loader receives None without it.
        csv_path = data_dir.parent / "b2txt_dataset_info.csv"
        b2txt_csv_df = pd.read_csv(csv_path) if csv_path.exists() else None

        data = load_h5py_file(str(train_file), b2txt_csv_df)

        trial_idx = trial_metadata.get('trial_idx')
        trials = data['neural_features']
        # The lower bound matters: a negative index would otherwise pass the
        # length check and silently select a trial counted from the end.
        if trial_idx is not None and 0 <= trial_idx < len(trials):
            return trials[trial_idx]
    except Exception as e:
        print(f"Warning: Could not load neural data for {session}, trial {trial_metadata.get('trial_idx', 'unknown')}: {e}")
    return None
# Label-ID -> phoneme table used only when the project's own mapping cannot be
# imported. ID 0 is the blank/padding label.
_FALLBACK_LOGIT_TO_PHONEME = {
    0: 'BLANK', 1: 'AA', 2: 'AE', 3: 'AH', 4: 'AO', 5: 'AW', 6: 'AY',
    7: 'B', 8: 'CH', 9: 'D', 10: 'DH', 11: 'EH', 12: 'ER', 13: 'EY', 14: 'F',
    15: 'G', 16: 'HH', 17: 'IH', 18: 'IY', 19: 'JH', 20: 'K', 21: 'L', 22: 'M',
    23: 'N', 24: 'NG', 25: 'OW', 26: 'OY', 27: 'P', 28: 'R', 29: 'S', 30: 'SH',
    31: 'T', 32: 'TH', 33: 'UH', 34: 'UW', 35: 'V', 36: 'W', 37: 'Y', 38: 'Z',
    39: 'ZH', 40: ' | ',
}


def _ground_truth_phonemes(original_sequence, logit_to_phoneme):
    """Translate label IDs to phoneme strings, skipping IDs that are not > 0."""
    # NOTE(review): the imported LOGIT_TO_PHONEME may be an index-addressed
    # sequence rather than a dict (not visible from this file). For a sequence,
    # ``seq_id in mapping`` would test membership among the phoneme strings, so
    # the two container kinds are handled separately.
    keyed_by_id = isinstance(logit_to_phoneme, dict)
    phonemes = []
    for seq_id in original_sequence:
        if not seq_id > 0:  # blank / padding
            continue
        if keyed_by_id:
            if seq_id in logit_to_phoneme:
                phonemes.append(logit_to_phoneme[seq_id])
        elif seq_id < len(logit_to_phoneme):
            phonemes.append(logit_to_phoneme[int(seq_id)])
    return phonemes


def validate_phoneme_against_ground_truth(segment, ctc_data):
    """Check one predicted phoneme segment against the trial's ground truth.

    The segment is located in the trial's ``'alignment_info'`` list by exact
    match on phoneme, start and end; its index in that list is then used as
    the index into the ground-truth phoneme sequence (``'original_sequence'``
    with non-positive IDs removed).

    Args:
        segment: Mapping with ``'session'``, ``'phoneme'``, ``'start_time'``,
            ``'end_time'`` and optionally ``'trial_idx'``.
        ctc_data: Mapping from ``(session, trial_idx)`` to a trial record.

    Returns:
        A tuple ``(is_valid, error_reason, ground_truth_phoneme)``.
        ``error_reason`` is ``"valid"`` on success, otherwise one of
        ``"no_trial_data"``, ``"no_ground_truth"``,
        ``"segment_not_found_in_alignment"``, ``"phoneme_mismatch"``,
        ``"position_out_of_range"`` or ``"validation_error: <message>"``.
        ``ground_truth_phoneme`` is set only for ``"valid"`` and
        ``"phoneme_mismatch"``.
    """
    try:
        trial_key = (segment['session'], segment.get('trial_idx'))
        if trial_key not in ctc_data:
            return False, "no_trial_data", None

        trial_data = ctc_data[trial_key]
        original_sequence = trial_data.get('original_sequence')
        if original_sequence is None:
            return False, "no_ground_truth", None

        # Prefer the project's mapping; any failure importing it falls back to
        # the built-in table (but no longer swallows KeyboardInterrupt).
        try:
            from model_training.evaluate_model_helpers import LOGIT_TO_PHONEME
        except Exception:
            LOGIT_TO_PHONEME = _FALLBACK_LOGIT_TO_PHONEME

        ground_truth_phonemes = _ground_truth_phonemes(original_sequence, LOGIT_TO_PHONEME)

        segment_phoneme = segment['phoneme']
        segment_start = segment['start_time']
        segment_end = segment['end_time']

        # Index of the first alignment entry identical to this segment.
        segment_position = None
        for i, (phoneme, start, end, _conf) in enumerate(trial_data.get('alignment_info', [])):
            if phoneme == segment_phoneme and start == segment_start and end == segment_end:
                segment_position = i
                break
        if segment_position is None:
            return False, "segment_not_found_in_alignment", None

        if segment_position >= len(ground_truth_phonemes):
            return False, "position_out_of_range", None

        expected_phoneme = ground_truth_phonemes[segment_position]
        if segment_phoneme == expected_phoneme:
            return True, "valid", expected_phoneme
        return False, "phoneme_mismatch", expected_phoneme
    except Exception as e:
        # Malformed records are reported as invalid rather than aborting the run.
        return False, f"validation_error: {str(e)}", None
def _percentage(part, whole):
    """Return ``part`` as a percentage of ``whole``; 0.0 when ``whole`` is not positive."""
    return part / whole * 100 if whole > 0 else 0.0


def _lookup_neural_features(segment, ctc_data):
    """Return the neural features of the trial that ``segment`` belongs to.

    Features stored on the trial's CTC record are preferred; otherwise the
    HDF5 loader is tried. Returns ``None`` when the trial is unknown or no
    features can be obtained.
    """
    session = segment['session']
    trial_key = (session, segment.get('trial_idx'))
    if trial_key not in ctc_data:
        return None
    trial_record = ctc_data[trial_key]
    if 'neural_features' in trial_record:
        return trial_record['neural_features']
    return load_neural_data_for_trial(session, segment)


def _segment_feature_vector(neural_features, start_time, end_time):
    """Reduce frames ``start_time..end_time`` (inclusive) to one feature vector.

    A 2-D window is averaged over its first axis; a 1-D window is returned
    as is. Returns ``None`` when the window is negative, empty or reaches past
    the end of ``neural_features``, or when the window has another rank.
    """
    if not (0 <= start_time <= end_time < len(neural_features)):
        return None
    window = neural_features[start_time:end_time + 1]
    if isinstance(window, torch.Tensor):
        # detach()/cpu() keep the conversion working for tensors that track
        # gradients or live on an accelerator.
        window = window.detach().cpu().numpy()
    else:
        window = np.asarray(window)
    if window.ndim == 2:
        return np.mean(window, axis=0)
    if window.ndim == 1:
        return window
    print(f"Unexpected feature shape: {window.shape}")
    return None


def _build_sample(phoneme, segment, ctc_data):
    """Return ``(feature_vector, metadata)`` for a segment, or ``None`` if unavailable."""
    neural_features = _lookup_neural_features(segment, ctc_data)
    if neural_features is None:
        return None
    start_time = int(segment['start_time'])
    end_time = int(segment['end_time'])
    feature_vector = _segment_feature_vector(neural_features, start_time, end_time)
    if feature_vector is None:
        return None
    sample_metadata = {
        'phoneme': phoneme,
        'session': segment['session'],
        'trial_num': segment.get('trial_num', -1),
        'trial_idx': segment.get('trial_idx', -1),
        'start_time': start_time,
        'end_time': end_time,
        'duration': end_time - start_time + 1,
        'confidence': segment.get('confidence', 0.0),
        'corpus': segment.get('corpus', 'unknown'),
        'validated': True,
    }
    return feature_vector, sample_metadata


def create_phoneme_classification_dataset():
    """Build a validated phoneme classification dataset from segmented data.

    Steps:
      1. Load the most recently modified ``phoneme_dataset_*.pkl`` from
         ``phoneme_segmented_data/`` and, if present, the matching
         ``ctc_results_*.pkl``.
      2. Validate every segment against ground truth.
      3. For each valid segment whose adjacent entries in the same phoneme's
         segment list are also valid, extract one feature vector.
      4. Save the dataset and the validation errors as pickles next to the
         input and create a train/test split.

    Returns:
        The classification dataset dict (``features``, ``labels``,
        ``phoneme_to_id``, ``id_to_phoneme``, ``metadata``,
        ``validation_stats``), or ``None`` when no input file exists or no
        features could be extracted.
    """
    data_dir = Path("phoneme_segmented_data")
    dataset_files = list(data_dir.glob("phoneme_dataset_*.pkl"))
    if not dataset_files:
        print("No phoneme dataset files found!")
        return

    latest_dataset = max(dataset_files, key=lambda x: x.stat().st_mtime)
    print(f"Loading dataset: {latest_dataset.name}")
    # SECURITY NOTE: unpickling can execute arbitrary code; only run this on
    # files produced by a trusted pipeline.
    with open(latest_dataset, 'rb') as f:
        phoneme_data = pickle.load(f)

    # CTC results carry the ground truth used for validation.
    ctc_file = latest_dataset.parent / latest_dataset.name.replace("phoneme_dataset_", "ctc_results_")
    ctc_data = {}
    if ctc_file.exists():
        with open(ctc_file, 'rb') as f:
            ctc_results = pickle.load(f)
        for result in ctc_results:
            ctc_data[(result['session'], result['trial_idx'])] = result
    else:
        print(f"Warning: CTC results file not found: {ctc_file.name}; no segment can be validated")
    print(f"Loaded {len(phoneme_data)} phoneme types")
    print(f"Associated CTC results: {len(ctc_data)} trials")

    classification_data = {
        'features': [],       # One feature vector per extracted segment
        'labels': [],         # Numeric phoneme ID per sample
        'phoneme_to_id': {},  # Phoneme -> numeric ID
        'id_to_phoneme': {},  # Numeric ID -> phoneme
        'metadata': [],       # Per-sample provenance
    }
    error_data = {
        'incorrect_segments': [],  # Segments that failed validation
        'validation_stats': {},    # Filled in before saving
    }

    # IDs follow the sorted phoneme order so the mapping is deterministic.
    unique_phonemes = sorted(phoneme_data)
    for i, phoneme in enumerate(unique_phonemes):
        classification_data['phoneme_to_id'][phoneme] = i
        classification_data['id_to_phoneme'][i] = phoneme
    print(f"\nPhoneme mapping created for {len(unique_phonemes)} phonemes:")
    for i, phoneme in enumerate(unique_phonemes[:10]):
        print(f" {i:2d}: '{phoneme}'")
    if len(unique_phonemes) > 10:
        print(f" ... and {len(unique_phonemes) - 10} more")

    validation_stats = {
        'total_segments': 0,
        'valid_segments': 0,
        'invalid_segments': 0,
        'discarded_neighbors': 0,
        'successful_extractions': 0,
        'error_reasons': defaultdict(int),
    }

    # ---- First pass: validate every segment -------------------------------
    print(f"\nValidating phoneme segments against ground truth...")
    # (phoneme, segment_idx) -> (is_valid, error_reason, ground_truth_phoneme)
    segment_validity = {}
    for phoneme, segments in phoneme_data.items():
        print(f"Validating '{phoneme}' ({len(segments)} segments)...")
        for segment_idx, segment in enumerate(segments):
            validation_stats['total_segments'] += 1
            verdict = validate_phoneme_against_ground_truth(segment, ctc_data)
            segment_validity[(phoneme, segment_idx)] = verdict
            is_valid, error_reason, ground_truth_phoneme = verdict
            if is_valid:
                validation_stats['valid_segments'] += 1
                continue
            validation_stats['invalid_segments'] += 1
            validation_stats['error_reasons'][error_reason] += 1
            error_data['incorrect_segments'].append({
                'phoneme': phoneme,
                'segment_idx': segment_idx,
                'segment': segment,
                'predicted_phoneme': phoneme,
                'ground_truth_phoneme': ground_truth_phoneme,
                'error_reason': error_reason,
            })

    print(f"\nValidation completed:")
    print(f" Total segments: {validation_stats['total_segments']}")
    print(f" Valid segments: {validation_stats['valid_segments']}")
    print(f" Invalid segments: {validation_stats['invalid_segments']}")
    # Guarded: an empty dataset must not raise ZeroDivisionError here.
    accuracy = _percentage(validation_stats['valid_segments'], validation_stats['total_segments'])
    print(f" Validation accuracy: {accuracy:.1f}%")
    print(f"\nError breakdown:")
    for error_reason, count in validation_stats['error_reasons'].items():
        print(f" {error_reason}: {count}")

    # ---- Second pass: extract features ------------------------------------
    print(f"\nExtracting neural features for validated segments...")
    processed_segments = 0  # second-pass counter driving the progress output
    for phoneme, segments in phoneme_data.items():
        phoneme_id = classification_data['phoneme_to_id'][phoneme]
        print(f"Processing '{phoneme}' ({len(segments)} segments)...")
        last_idx = len(segments) - 1
        for segment_idx, segment in enumerate(segments):
            is_valid = segment_validity[(phoneme, segment_idx)][0]
            # "Neighbours" are the adjacent entries in this phoneme's list.
            prev_invalid = segment_idx > 0 and not segment_validity[(phoneme, segment_idx - 1)][0]
            next_invalid = segment_idx < last_idx and not segment_validity[(phoneme, segment_idx + 1)][0]

            sample = None
            if is_valid and (prev_invalid or next_invalid):
                validation_stats['discarded_neighbors'] += 1
            elif is_valid:
                sample = _build_sample(phoneme, segment, ctc_data)

            if sample is not None:
                feature_vector, sample_metadata = sample
                classification_data['features'].append(feature_vector)
                classification_data['labels'].append(phoneme_id)
                classification_data['metadata'].append(sample_metadata)
                validation_stats['successful_extractions'] += 1

            processed_segments += 1
            if processed_segments % 1000 == 0:
                print(f" Processed {processed_segments} segments, extracted {validation_stats['successful_extractions']} features")

    eligible_segments = validation_stats['valid_segments'] - validation_stats['discarded_neighbors']
    print(f"\nDataset creation completed!")
    print(f"Total segments processed: {validation_stats['total_segments']}")
    print(f"Valid segments (excluding neighbors): {eligible_segments}")
    print(f"Discarded neighbor segments: {validation_stats['discarded_neighbors']}")
    print(f"Successful feature extractions: {validation_stats['successful_extractions']}")
    # Guarded: with no eligible segments the rate is reported as 0.0 so the
    # explanatory message below is reached instead of a ZeroDivisionError.
    success_rate = _percentage(validation_stats['successful_extractions'], eligible_segments)
    print(f"Extraction success rate: {success_rate:.1f}%")
    if validation_stats['successful_extractions'] == 0:
        print("No features were extracted. Check neural data availability.")
        return

    classification_data['features'] = np.array(classification_data['features'])
    classification_data['labels'] = np.array(classification_data['labels'])
    print(f"\nFinal validated dataset shape:")
    print(f"Features: {classification_data['features'].shape}")
    print(f"Labels: {classification_data['labels'].shape}")

    print(f"\nClass distribution:")
    unique_labels, counts = np.unique(classification_data['labels'], return_counts=True)
    for label_id, count in zip(unique_labels, counts):
        phoneme = classification_data['id_to_phoneme'][label_id]
        print(f" {label_id:2d} ('{phoneme}'): {count:4d} samples")

    # Output names reuse the last "_"-separated token of the input file name.
    timestamp = latest_dataset.name.split('_')[-1].replace('.pkl', '')
    output_file = f"phoneme_classification_dataset_validated_{timestamp}.pkl"
    classification_data['validation_stats'] = validation_stats
    with open(data_dir / output_file, 'wb') as f:
        pickle.dump(classification_data, f)
    print(f"\nValidated classification dataset saved to: {output_file}")

    error_data['validation_stats'] = validation_stats
    error_file = f"phoneme_validation_errors_{timestamp}.pkl"
    with open(data_dir / error_file, 'wb') as f:
        pickle.dump(error_data, f)
    print(f"Validation errors saved to: {error_file}")

    create_train_test_split(classification_data, data_dir, timestamp)
    return classification_data
def create_train_test_split(data, data_dir, timestamp):
    """Split the classification dataset, standardise it and save it as a pickle.

    With four or more distinct sessions, the first 80% of the sorted session
    names form the training set and the rest the test set, so that no session
    contributes to both. With fewer sessions a seeded random 80/20 split is
    used, stratified by label when possible.

    Args:
        data: Dataset dict with ``'features'`` and ``'labels'`` arrays,
            a ``'metadata'`` list whose entries have a ``'session'`` key, and
            the ``'phoneme_to_id'`` / ``'id_to_phoneme'`` mappings.
        data_dir: Directory (``Path``) the split file is written to.
        timestamp: Suffix used in the output file name
            ``phoneme_classification_split_<timestamp>.pkl``.

    Returns:
        None. The scaled train/test arrays, labels, fitted scaler and the
        phoneme mappings are written to disk.
    """
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler

    print(f"\nCreating train/test split...")
    X = data['features']
    y = data['labels']
    metadata = data['metadata']

    sessions = [meta['session'] for meta in metadata]
    # Sorted so the session partition is identical from run to run; the
    # iteration order of a set of strings varies with hash randomisation.
    unique_sessions = sorted(set(sessions))
    print(f"Available sessions: {len(unique_sessions)}")

    if len(unique_sessions) >= 4:
        # Split by session to avoid leakage between train and test.
        n_train_sessions = int(len(unique_sessions) * 0.8)
        train_sessions = unique_sessions[:n_train_sessions]
        test_sessions = unique_sessions[n_train_sessions:]
        train_lookup = set(train_sessions)
        train_indices = [i for i, name in enumerate(sessions) if name in train_lookup]
        test_indices = [i for i, name in enumerate(sessions) if name not in train_lookup]
        X_train, X_test = X[train_indices], X[test_indices]
        y_train, y_test = y[train_indices], y[test_indices]
        print(f"Session-based split:")
        print(f" Train sessions: {train_sessions}")
        print(f" Test sessions: {test_sessions}")
    else:
        try:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42, stratify=y
            )
            print(f"Random split (stratified):")
        except ValueError as err:
            # Stratification is rejected when a class is too small to appear
            # on both sides; fall back to a plain seeded split.
            print(f"Warning: stratified split not possible ({err}); using unstratified random split")
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42
            )
            print(f"Random split (unstratified):")

    print(f" Train samples: {len(X_train)}")
    print(f" Test samples: {len(X_test)}")

    # Scaler statistics come from the training set only.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    split_data = {
        'X_train': X_train_scaled,
        'X_test': X_test_scaled,
        'y_train': y_train,
        'y_test': y_test,
        'scaler': scaler,
        'phoneme_to_id': data['phoneme_to_id'],
        'id_to_phoneme': data['id_to_phoneme'],
    }
    split_file = f"phoneme_classification_split_{timestamp}.pkl"
    with open(data_dir / split_file, 'wb') as f:
        pickle.dump(split_data, f)
    print(f"Train/test split saved to: {split_file}")
if __name__ == "__main__":
    # Script entry point: build the dataset and report any failure.
    try:
        classification_data = create_phoneme_classification_dataset()
    except Exception as e:
        print(f"Error creating classification dataset: {e}")
        import traceback
        traceback.print_exc()
        # Exit non-zero so shells and pipelines can detect the failure.
        raise SystemExit(1)