#!/usr/bin/env python3
"""
Convert phoneme-segmented data into a phoneme classification dataset.

Loads the most recent phoneme_dataset_*.pkl (and its matching ctc_results_*.pkl)
from phoneme_segmented_data/, validates each segment against the ground truth
phoneme sequence, extracts a fixed-size neural feature vector per valid segment,
and saves the classification dataset, the validation errors, and a train/test split.
"""

import pickle
import numpy as np
import torch
from pathlib import Path
from collections import defaultdict
import os
import sys

# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))

def load_neural_data_for_trial(session, trial_metadata):
    """Load neural features for a specific trial"""
    try:
        from model_training.evaluate_model_helpers import load_h5py_file
        import pandas as pd

        # Try to load the session data
        data_dir = Path(__file__).parent.parent / "data" / "hdf5_data_final"
        train_file = data_dir / session / "data_train.hdf5"

        if not train_file.exists():
            return None

        # Load CSV for metadata
        csv_path = data_dir.parent / "b2txt_dataset_info.csv"
        if csv_path.exists():
            b2txt_csv_df = pd.read_csv(csv_path)
        else:
            b2txt_csv_df = None

        data = load_h5py_file(str(train_file), b2txt_csv_df)

        # Find the matching trial
        trial_idx = trial_metadata.get('trial_idx')
        if trial_idx is not None and trial_idx < len(data['neural_features']):
            return data['neural_features'][trial_idx]

    except Exception as e:
        print(f"Warning: Could not load neural data for {session}, trial {trial_metadata.get('trial_idx', 'unknown')}: {e}")

    return None

def validate_phoneme_against_ground_truth(segment, ctc_data):
    """
    Validate a phoneme segment against ground truth sequence labels.

    Returns: (is_valid, error_reason, ground_truth_phoneme)
    """
    try:
        session = segment['session']
        trial_idx = segment.get('trial_idx')
        trial_key = (session, trial_idx)

        if trial_key not in ctc_data:
            return False, "no_trial_data", None

        trial_data = ctc_data[trial_key]
        original_sequence = trial_data.get('original_sequence')

        if original_sequence is None:
            return False, "no_ground_truth", None

        # Convert sequence IDs to phonemes using LOGIT_TO_PHONEME mapping
        try:
            from model_training.evaluate_model_helpers import LOGIT_TO_PHONEME
        except ImportError:
            # Fallback phoneme mapping if import fails
            LOGIT_TO_PHONEME = {
                0: 'BLANK', 1: 'AA', 2: 'AE', 3: 'AH', 4: 'AO', 5: 'AW', 6: 'AY',
                7: 'B', 8: 'CH', 9: 'D', 10: 'DH', 11: 'EH', 12: 'ER', 13: 'EY', 14: 'F',
                15: 'G', 16: 'HH', 17: 'IH', 18: 'IY', 19: 'JH', 20: 'K', 21: 'L', 22: 'M',
                23: 'N', 24: 'NG', 25: 'OW', 26: 'OY', 27: 'P', 28: 'R', 29: 'S', 30: 'SH',
                31: 'T', 32: 'TH', 33: 'UH', 34: 'UW', 35: 'V', 36: 'W', 37: 'Y', 38: 'Z',
                39: 'ZH', 40: ' | '
            }

        # Convert ground truth sequence to phonemes (filter out zeros/padding)
        ground_truth_phonemes = []
        for seq_id in original_sequence:
            if seq_id > 0 and seq_id in LOGIT_TO_PHONEME:  # Skip padding/blank
                ground_truth_phonemes.append(LOGIT_TO_PHONEME[seq_id])

        # Find the position of this segment in the predicted sequence
        predicted_sequence = trial_data.get('predicted_phonemes', [])
        alignment_info = trial_data.get('alignment_info', [])

        # Find this segment in the alignment info
        segment_phoneme = segment['phoneme']
        segment_start = segment['start_time']
        segment_end = segment['end_time']

        # Find matching alignment segment
        segment_position = None
        for i, (phoneme, start, end, conf) in enumerate(alignment_info):
            if (phoneme == segment_phoneme and
                start == segment_start and
                end == segment_end):
                segment_position = i
                break

        if segment_position is None:
            return False, "segment_not_found_in_alignment", None

        # Check if the phoneme at this position matches ground truth
        if segment_position < len(ground_truth_phonemes):
            expected_phoneme = ground_truth_phonemes[segment_position]
            if segment_phoneme == expected_phoneme:
                return True, "valid", expected_phoneme
            else:
                return False, "phoneme_mismatch", expected_phoneme
        else:
            return False, "position_out_of_range", None

    except Exception as e:
        return False, f"validation_error: {str(e)}", None

def create_phoneme_classification_dataset():
    """Create a phoneme classification dataset from segmented data with validation"""

    # Load the latest phoneme dataset
    data_dir = Path("phoneme_segmented_data")
    dataset_files = list(data_dir.glob("phoneme_dataset_*.pkl"))

    if not dataset_files:
        print("No phoneme dataset files found!")
        return

    latest_dataset = max(dataset_files, key=lambda x: x.stat().st_mtime)
    print(f"Loading dataset: {latest_dataset.name}")

    with open(latest_dataset, 'rb') as f:
        phoneme_data = pickle.load(f)

    # Also load the corresponding CTC results for ground truth validation
    ctc_file = latest_dataset.parent / latest_dataset.name.replace("phoneme_dataset_", "ctc_results_")
    ctc_data = {}

    if ctc_file.exists():
        with open(ctc_file, 'rb') as f:
            ctc_results = pickle.load(f)
            # Create lookup dictionary for validation
            for result in ctc_results:
                key = (result['session'], result['trial_idx'])
                ctc_data[key] = result

    print(f"Loaded {len(phoneme_data)} phoneme types")
    print(f"Associated CTC results: {len(ctc_data)} trials")

    # Create classification dataset
    classification_data = {
        'features': [],      # Neural features for each segment
        'labels': [],        # Phoneme labels
        'phoneme_to_id': {}, # Phoneme to numeric ID mapping
        'id_to_phoneme': {}, # Numeric ID to phoneme mapping
        'metadata': []       # Additional metadata for each sample
    }

    # Create error tracking
    error_data = {
        'incorrect_segments': [],  # Incorrect phoneme segments
        'validation_stats': {}     # Validation statistics
    }

    # Create phoneme to ID mapping
    unique_phonemes = sorted(phoneme_data.keys())
    for i, phoneme in enumerate(unique_phonemes):
        classification_data['phoneme_to_id'][phoneme] = i
        classification_data['id_to_phoneme'][i] = phoneme

    print(f"\nPhoneme mapping created for {len(unique_phonemes)} phonemes:")
    for i, phoneme in enumerate(unique_phonemes[:10]):  # Show first 10
        print(f"  {i:2d}: '{phoneme}'")
    if len(unique_phonemes) > 10:
        print(f"  ... and {len(unique_phonemes) - 10} more")

    # Validation and extraction statistics
    validation_stats = {
        'total_segments': 0,
        'valid_segments': 0,
        'invalid_segments': 0,
        'discarded_neighbors': 0,
        'successful_extractions': 0,
        'error_reasons': defaultdict(int)
    }

    print(f"\nValidating phoneme segments against ground truth...")

    # First pass: validate all segments and mark invalid ones
    segment_validity = {}  # Maps (phoneme, segment_idx) -> (is_valid, error_reason, ground_truth)

    for phoneme, segments in phoneme_data.items():
        print(f"Validating '{phoneme}' ({len(segments)} segments)...")

        for segment_idx, segment in enumerate(segments):
            validation_stats['total_segments'] += 1

            # Validate against ground truth
            is_valid, error_reason, ground_truth_phoneme = validate_phoneme_against_ground_truth(segment, ctc_data)
            segment_validity[(phoneme, segment_idx)] = (is_valid, error_reason, ground_truth_phoneme)

            if is_valid:
                validation_stats['valid_segments'] += 1
            else:
                validation_stats['invalid_segments'] += 1
                validation_stats['error_reasons'][error_reason] += 1

                # Save error information
                error_data['incorrect_segments'].append({
                    'phoneme': phoneme,
                    'segment_idx': segment_idx,
                    'segment': segment,
                    'predicted_phoneme': phoneme,
                    'ground_truth_phoneme': ground_truth_phoneme,
                    'error_reason': error_reason
                })

    print(f"\nValidation completed:")
    print(f"  Total segments: {validation_stats['total_segments']}")
    print(f"  Valid segments: {validation_stats['valid_segments']}")
    print(f"  Invalid segments: {validation_stats['invalid_segments']}")
    print(f"  Validation accuracy: {validation_stats['valid_segments']/validation_stats['total_segments']*100:.1f}%")

    print(f"\nError breakdown:")
    for error_reason, count in validation_stats['error_reasons'].items():
        print(f"  {error_reason}: {count}")

    # Second pass: extract features for valid segments (excluding neighbors of invalid ones)
    print(f"\nExtracting neural features for validated segments...")
    processed_count = 0  # Counts second-pass segments so progress updates fire periodically

    for phoneme, segments in phoneme_data.items():
        phoneme_id = classification_data['phoneme_to_id'][phoneme]
        print(f"Processing '{phoneme}' ({len(segments)} segments)...")

        for segment_idx, segment in enumerate(segments):
            processed_count += 1

            # Check if this segment is valid
            is_valid, error_reason, ground_truth_phoneme = segment_validity[(phoneme, segment_idx)]

            # Check if neighboring segments are invalid (discard neighbors)
            prev_invalid = (segment_idx > 0 and
                          not segment_validity[(phoneme, segment_idx - 1)][0])
            next_invalid = (segment_idx < len(segments) - 1 and
                          not segment_validity[(phoneme, segment_idx + 1)][0])

            if not is_valid:
                continue  # Skip invalid segments

            if prev_invalid or next_invalid:
                validation_stats['discarded_neighbors'] += 1
                continue  # Skip neighbors of invalid segments

            # Get trial information
            session = segment['session']
            trial_key = (session, segment.get('trial_idx'))

            # Try to get neural data for this trial
            neural_features = None
            if trial_key in ctc_data:
                # We have the trial data, now extract the segment
                trial_metadata = ctc_data[trial_key]
                if 'neural_features' in trial_metadata:
                    neural_features = trial_metadata['neural_features']
                else:
                    # Try to load from HDF5 files
                    neural_features = load_neural_data_for_trial(session, segment)

            if neural_features is not None:
                # Extract the specific time segment
                start_time = int(segment['start_time'])
                end_time = int(segment['end_time'])

                # Ensure valid time range
                if start_time <= end_time and end_time < len(neural_features):
                    # Extract neural features for this time segment
                    segment_features = neural_features[start_time:end_time+1]  # Include end_time

                    # Convert to numpy array and handle different cases
                    if isinstance(segment_features, torch.Tensor):
                        segment_features = segment_features.numpy()
                    elif isinstance(segment_features, list):
                        segment_features = np.array(segment_features)

                    # For classification, we need a fixed-size feature vector
                    # Option 1: Use mean across time steps
                    if len(segment_features.shape) == 2:  # (time, features)
                        feature_vector = np.mean(segment_features, axis=0)
                    elif len(segment_features.shape) == 1:  # Already 1D
                        feature_vector = segment_features
                    else:
                        print(f"Unexpected feature shape: {segment_features.shape}")
                        continue

                    # Add to dataset
                    classification_data['features'].append(feature_vector)
                    classification_data['labels'].append(phoneme_id)
                    classification_data['metadata'].append({
                        'phoneme': phoneme,
                        'session': session,
                        'trial_num': segment.get('trial_num', -1),
                        'trial_idx': segment.get('trial_idx', -1),
                        'start_time': start_time,
                        'end_time': end_time,
                        'duration': end_time - start_time + 1,
                        'confidence': segment.get('confidence', 0.0),
                        'corpus': segment.get('corpus', 'unknown'),
                        'validated': True
                    })

                    validation_stats['successful_extractions'] += 1

            # Progress update
            if processed_count % 1000 == 0:
                print(f"  Processed {processed_count} segments, extracted {validation_stats['successful_extractions']} features")

    print(f"\nDataset creation completed!")
    print(f"Total segments processed: {validation_stats['total_segments']}")
    print(f"Valid segments (excluding neighbors): {validation_stats['valid_segments'] - validation_stats['discarded_neighbors']}")
    print(f"Discarded neighbor segments: {validation_stats['discarded_neighbors']}")
    print(f"Successful feature extractions: {validation_stats['successful_extractions']}")
    usable_segments = validation_stats['valid_segments'] - validation_stats['discarded_neighbors']
    if usable_segments > 0:
        print(f"Extraction success rate: {validation_stats['successful_extractions']/usable_segments*100:.1f}%")

    if validation_stats['successful_extractions'] == 0:
        print("No features were extracted. Check neural data availability.")
        return

    # Convert to numpy arrays
    classification_data['features'] = np.array(classification_data['features'])
    classification_data['labels'] = np.array(classification_data['labels'])

    print(f"\nFinal validated dataset shape:")
    print(f"Features: {classification_data['features'].shape}")
    print(f"Labels: {classification_data['labels'].shape}")

    # Show class distribution
    print(f"\nClass distribution:")
    unique_labels, counts = np.unique(classification_data['labels'], return_counts=True)
    for label_id, count in zip(unique_labels, counts):
        phoneme = classification_data['id_to_phoneme'][label_id]
        print(f"  {label_id:2d} ('{phoneme}'): {count:4d} samples")

    # Save the classification dataset
    timestamp = latest_dataset.name.split('_')[-1].replace('.pkl', '')
    output_file = f"phoneme_classification_dataset_validated_{timestamp}.pkl"
    output_path = data_dir / output_file

    # Add validation stats to the dataset
    classification_data['validation_stats'] = validation_stats

    with open(output_path, 'wb') as f:
        pickle.dump(classification_data, f)

    print(f"\nValidated classification dataset saved to: {output_file}")

    # Save error data separately
    error_data['validation_stats'] = validation_stats
    error_file = f"phoneme_validation_errors_{timestamp}.pkl"
    error_path = data_dir / error_file

    with open(error_path, 'wb') as f:
        pickle.dump(error_data, f)

    print(f"Validation errors saved to: {error_file}")

    # Create a simple train/test split example
    create_train_test_split(classification_data, data_dir, timestamp)

    return classification_data

def create_train_test_split(data, data_dir, timestamp):
    """Create train/test split for the classification dataset"""

    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler

    print(f"\nCreating train/test split...")

    X = data['features']
    y = data['labels']
    metadata = data['metadata']

    # Split by session to avoid data leakage
    sessions = [meta['session'] for meta in metadata]
    unique_sessions = sorted(set(sessions))  # Sorted for a deterministic session-based split

    print(f"Available sessions: {len(unique_sessions)}")

    if len(unique_sessions) >= 4:
        # Use session-based split
        train_sessions = unique_sessions[:int(len(unique_sessions) * 0.8)]
        test_sessions = unique_sessions[int(len(unique_sessions) * 0.8):]

        train_indices = [i for i, meta in enumerate(metadata) if meta['session'] in train_sessions]
        test_indices = [i for i, meta in enumerate(metadata) if meta['session'] in test_sessions]

        X_train, X_test = X[train_indices], X[test_indices]
        y_train, y_test = y[train_indices], y[test_indices]

        print(f"Session-based split:")
        print(f"  Train sessions: {train_sessions}")
        print(f"  Test sessions: {test_sessions}")
    else:
        # Use random split
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )
        print(f"Random split (stratified):")

    print(f"  Train samples: {len(X_train)}")
    print(f"  Test samples: {len(X_test)}")

    # Standardize features
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Save split data
    split_data = {
        'X_train': X_train_scaled,
        'X_test': X_test_scaled,
        'y_train': y_train,
        'y_test': y_test,
        'scaler': scaler,
        'phoneme_to_id': data['phoneme_to_id'],
        'id_to_phoneme': data['id_to_phoneme']
    }

    split_file = f"phoneme_classification_split_{timestamp}.pkl"
    split_path = data_dir / split_file

    with open(split_path, 'wb') as f:
        pickle.dump(split_data, f)

    print(f"Train/test split saved to: {split_file}")

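# Example usage (illustrative sketch, not part of this script): a downstream
# consumer could load the most recent saved split and fit a quick baseline
# classifier on the standardized features, e.g.:
#
#   import pickle
#   from pathlib import Path
#   from sklearn.linear_model import LogisticRegression
#
#   split_path = max(Path("phoneme_segmented_data").glob("phoneme_classification_split_*.pkl"),
#                    key=lambda p: p.stat().st_mtime)
#   with open(split_path, "rb") as f:
#       split = pickle.load(f)
#   clf = LogisticRegression(max_iter=1000)
#   clf.fit(split['X_train'], split['y_train'])
#   print("Baseline test accuracy:", clf.score(split['X_test'], split['y_test']))
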
if __name__ == "__main__":
    try:
        classification_data = create_phoneme_classification_dataset()
    except Exception as e:
        print(f"Error creating classification dataset: {e}")
        import traceback
        traceback.print_exc()
