#!/usr/bin/env python3
"""
Convert phoneme-segmented data into a phoneme classification dataset.
"""

import pickle
import numpy as np
import torch
from pathlib import Path
from collections import defaultdict
import os
import sys

# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))

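# Expected layout (inferred from the paths used in this script): the repository
# root is one level up and contains model_training/ and data/hdf5_data_final/
# (with <session>/data_train.hdf5 files and data/b2txt_dataset_info.csv), while
# phoneme_segmented_data/ is resolved relative to the current working directory.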
def load_neural_data_for_trial(session, trial_metadata):
    """Load neural features for a specific trial"""
    try:
        from model_training.evaluate_model_helpers import load_h5py_file
        import pandas as pd

        # Try to load the session data
        data_dir = Path(__file__).parent.parent / "data" / "hdf5_data_final"
        train_file = data_dir / session / "data_train.hdf5"

        if not train_file.exists():
            return None

        # Load CSV for metadata
        csv_path = data_dir.parent / "b2txt_dataset_info.csv"
        if csv_path.exists():
            b2txt_csv_df = pd.read_csv(csv_path)
        else:
            b2txt_csv_df = None

        data = load_h5py_file(str(train_file), b2txt_csv_df)

        # Find the matching trial
        trial_idx = trial_metadata.get('trial_idx')
        if trial_idx is not None and trial_idx < len(data['neural_features']):
            return data['neural_features'][trial_idx]

    except Exception as e:
        print(f"Warning: Could not load neural data for {session}, trial {trial_metadata.get('trial_idx', 'unknown')}: {e}")

    return None

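# Data structures assumed by the validator below (inferred from the fields it
# accesses; listed here only as documentation):
#   segment  : dict with 'session', 'trial_idx', 'phoneme', 'start_time',
#              'end_time', and optionally 'confidence', 'corpus', 'trial_num'
#   ctc_data : dict keyed by (session, trial_idx); each value may provide
#              'original_sequence', 'predicted_phonemes', 'alignment_info'
#              (a list of (phoneme, start, end, confidence) tuples) and
#              optionally 'neural_features'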
def validate_phoneme_against_ground_truth(segment, ctc_data):
    """
    Validate a phoneme segment against ground truth sequence labels
    Returns: (is_valid, error_reason, ground_truth_phoneme)
    """
    try:
        session = segment['session']
        trial_idx = segment.get('trial_idx')
        trial_key = (session, trial_idx)

        if trial_key not in ctc_data:
            return False, "no_trial_data", None

        trial_data = ctc_data[trial_key]
        original_sequence = trial_data.get('original_sequence')

        if original_sequence is None:
            return False, "no_ground_truth", None

        # Convert sequence IDs to phonemes using LOGIT_TO_PHONEME mapping
        try:
            from model_training.evaluate_model_helpers import LOGIT_TO_PHONEME
        except ImportError:
            # Fallback phoneme mapping if the import fails
            LOGIT_TO_PHONEME = {
                0: 'BLANK', 1: 'AA', 2: 'AE', 3: 'AH', 4: 'AO', 5: 'AW', 6: 'AY',
                7: 'B', 8: 'CH', 9: 'D', 10: 'DH', 11: 'EH', 12: 'ER', 13: 'EY', 14: 'F',
                15: 'G', 16: 'HH', 17: 'IH', 18: 'IY', 19: 'JH', 20: 'K', 21: 'L', 22: 'M',
                23: 'N', 24: 'NG', 25: 'OW', 26: 'OY', 27: 'P', 28: 'R', 29: 'S', 30: 'SH',
                31: 'T', 32: 'TH', 33: 'UH', 34: 'UW', 35: 'V', 36: 'W', 37: 'Y', 38: 'Z',
                39: 'ZH', 40: ' | '
            }

        # Convert ground truth sequence to phonemes (filter out zeros/padding)
        ground_truth_phonemes = []
        for seq_id in original_sequence:
            if seq_id > 0 and seq_id in LOGIT_TO_PHONEME:  # Skip padding/blank
                ground_truth_phonemes.append(LOGIT_TO_PHONEME[seq_id])

        # Find the position of this segment in the predicted sequence
        predicted_sequence = trial_data.get('predicted_phonemes', [])
        alignment_info = trial_data.get('alignment_info', [])

        # Find this segment in the alignment info
        segment_phoneme = segment['phoneme']
        segment_start = segment['start_time']
        segment_end = segment['end_time']

        # Find matching alignment segment
        segment_position = None
        for i, (phoneme, start, end, conf) in enumerate(alignment_info):
            if (phoneme == segment_phoneme and
                start == segment_start and
                end == segment_end):
                segment_position = i
                break

        if segment_position is None:
            return False, "segment_not_found_in_alignment", None

        # Check if the phoneme at this position matches ground truth
        if segment_position < len(ground_truth_phonemes):
            expected_phoneme = ground_truth_phonemes[segment_position]
            if segment_phoneme == expected_phoneme:
                return True, "valid", expected_phoneme
            else:
                return False, "phoneme_mismatch", expected_phoneme
        else:
            return False, "position_out_of_range", None

    except Exception as e:
        return False, f"validation_error: {str(e)}", None

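# The builder below reads the newest phoneme_dataset_*.pkl (and, when present,
# the matching ctc_results_*.pkl) from phoneme_segmented_data/ and, on success,
# writes three pickles alongside them:
#   phoneme_classification_dataset_validated_<timestamp>.pkl
#   phoneme_validation_errors_<timestamp>.pkl
#   phoneme_classification_split_<timestamp>.pkl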
def create_phoneme_classification_dataset():
    """Create a phoneme classification dataset from segmented data with validation"""

    # Load the latest phoneme dataset
    data_dir = Path("phoneme_segmented_data")
    dataset_files = list(data_dir.glob("phoneme_dataset_*.pkl"))

    if not dataset_files:
        print("No phoneme dataset files found!")
        return

    latest_dataset = max(dataset_files, key=lambda x: x.stat().st_mtime)
    print(f"Loading dataset: {latest_dataset.name}")

    with open(latest_dataset, 'rb') as f:
        phoneme_data = pickle.load(f)

    # Also load the corresponding CTC results for ground truth validation
    ctc_file = latest_dataset.parent / latest_dataset.name.replace("phoneme_dataset_", "ctc_results_")
    ctc_data = {}

    if ctc_file.exists():
        with open(ctc_file, 'rb') as f:
            ctc_results = pickle.load(f)
            # Create lookup dictionary for validation
            for result in ctc_results:
                key = (result['session'], result['trial_idx'])
                ctc_data[key] = result

    print(f"Loaded {len(phoneme_data)} phoneme types")
    print(f"Associated CTC results: {len(ctc_data)} trials")

    # Create classification dataset
    classification_data = {
        'features': [],      # Neural features for each segment
        'labels': [],        # Phoneme labels
        'phoneme_to_id': {}, # Phoneme to numeric ID mapping
        'id_to_phoneme': {}, # Numeric ID to phoneme mapping
        'metadata': []       # Additional metadata for each sample
    }

    # Create error tracking
    error_data = {
        'incorrect_segments': [],  # Incorrect phoneme segments
        'validation_stats': {}     # Validation statistics
    }

    # Create phoneme to ID mapping
    unique_phonemes = sorted(phoneme_data.keys())
    for i, phoneme in enumerate(unique_phonemes):
        classification_data['phoneme_to_id'][phoneme] = i
        classification_data['id_to_phoneme'][i] = phoneme

    print(f"\nPhoneme mapping created for {len(unique_phonemes)} phonemes:")
    for i, phoneme in enumerate(unique_phonemes[:10]):  # Show first 10
        print(f"  {i:2d}: '{phoneme}'")
    if len(unique_phonemes) > 10:
        print(f"  ... and {len(unique_phonemes) - 10} more")

    # Validation and extraction statistics
    validation_stats = {
        'total_segments': 0,
        'valid_segments': 0,
        'invalid_segments': 0,
        'discarded_neighbors': 0,
        'successful_extractions': 0,
        'error_reasons': defaultdict(int)
    }

    print(f"\nValidating phoneme segments against ground truth...")

    # First pass: validate all segments and mark invalid ones
    segment_validity = {}  # Maps (phoneme, segment_idx) -> (is_valid, error_reason, ground_truth)

    for phoneme, segments in phoneme_data.items():
        print(f"Validating '{phoneme}' ({len(segments)} segments)...")

        for segment_idx, segment in enumerate(segments):
            validation_stats['total_segments'] += 1

            # Validate against ground truth
            is_valid, error_reason, ground_truth_phoneme = validate_phoneme_against_ground_truth(segment, ctc_data)
            segment_validity[(phoneme, segment_idx)] = (is_valid, error_reason, ground_truth_phoneme)

            if is_valid:
                validation_stats['valid_segments'] += 1
            else:
                validation_stats['invalid_segments'] += 1
                validation_stats['error_reasons'][error_reason] += 1

                # Save error information
                error_data['incorrect_segments'].append({
                    'phoneme': phoneme,
                    'segment_idx': segment_idx,
                    'segment': segment,
                    'predicted_phoneme': phoneme,
                    'ground_truth_phoneme': ground_truth_phoneme,
                    'error_reason': error_reason
                })

    print(f"\nValidation completed:")
    print(f"  Total segments: {validation_stats['total_segments']}")
    print(f"  Valid segments: {validation_stats['valid_segments']}")
    print(f"  Invalid segments: {validation_stats['invalid_segments']}")
    print(f"  Validation accuracy: {validation_stats['valid_segments']/validation_stats['total_segments']*100:.1f}%")

    print(f"\nError breakdown:")
    for error_reason, count in validation_stats['error_reasons'].items():
        print(f"  {error_reason}: {count}")

    # Second pass: extract features for valid segments (excluding neighbors of invalid ones)
    print(f"\nExtracting neural features for validated segments...")

    processed_segments = 0  # Counter used for progress reporting in this pass

    for phoneme, segments in phoneme_data.items():
        phoneme_id = classification_data['phoneme_to_id'][phoneme]
        print(f"Processing '{phoneme}' ({len(segments)} segments)...")

        for segment_idx, segment in enumerate(segments):
            processed_segments += 1

            # Check if this segment is valid
            is_valid, error_reason, ground_truth_phoneme = segment_validity[(phoneme, segment_idx)]

            # Check if neighboring segments are invalid (discard neighbors)
            prev_invalid = (segment_idx > 0 and
                            not segment_validity[(phoneme, segment_idx - 1)][0])
            next_invalid = (segment_idx < len(segments) - 1 and
                            not segment_validity[(phoneme, segment_idx + 1)][0])

            if not is_valid:
                continue  # Skip invalid segments

            if prev_invalid or next_invalid:
                validation_stats['discarded_neighbors'] += 1
                continue  # Skip neighbors of invalid segments

            # Get trial information
            session = segment['session']
            trial_key = (session, segment.get('trial_idx'))

            # Try to get neural data for this trial
            neural_features = None
            if trial_key in ctc_data:
                # We have the trial data, now extract the segment
                trial_metadata = ctc_data[trial_key]
                if 'neural_features' in trial_metadata:
                    neural_features = trial_metadata['neural_features']
                else:
                    # Try to load from HDF5 files
                    neural_features = load_neural_data_for_trial(session, segment)

            if neural_features is not None:
                # Extract the specific time segment
                start_time = int(segment['start_time'])
                end_time = int(segment['end_time'])

                # Ensure valid time range
                if start_time <= end_time and end_time < len(neural_features):
                    # Extract neural features for this time segment
                    segment_features = neural_features[start_time:end_time+1]  # Include end_time

                    # Convert to numpy array and handle different cases
                    if isinstance(segment_features, torch.Tensor):
                        segment_features = segment_features.detach().cpu().numpy()
                    elif isinstance(segment_features, list):
                        segment_features = np.array(segment_features)

                    # For classification, we need a fixed-size feature vector
                    # Option 1: Use mean across time steps
                    if len(segment_features.shape) == 2:  # (time, features)
                        feature_vector = np.mean(segment_features, axis=0)
                    elif len(segment_features.shape) == 1:  # Already 1D
                        feature_vector = segment_features
                    else:
                        print(f"Unexpected feature shape: {segment_features.shape}")
                        continue
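
                    # NOTE: mean-pooling over time (above) is just one way to obtain a
                    # fixed-size vector per segment; alternatives such as max-pooling or
                    # resampling each segment to a fixed number of frames would slot in
                    # here, but are not implemented in this script.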

                    # Add to dataset
                    classification_data['features'].append(feature_vector)
                    classification_data['labels'].append(phoneme_id)
                    classification_data['metadata'].append({
                        'phoneme': phoneme,
                        'session': session,
                        'trial_num': segment.get('trial_num', -1),
                        'trial_idx': segment.get('trial_idx', -1),
                        'start_time': start_time,
                        'end_time': end_time,
                        'duration': end_time - start_time + 1,
                        'confidence': segment.get('confidence', 0.0),
                        'corpus': segment.get('corpus', 'unknown'),
                        'validated': True
                    })

                    validation_stats['successful_extractions'] += 1

            # Progress update
            if processed_segments % 1000 == 0:
                print(f"  Processed {processed_segments} segments, extracted {validation_stats['successful_extractions']} features")

    print(f"\nDataset creation completed!")
    print(f"Total segments processed: {validation_stats['total_segments']}")
    kept_segments = validation_stats['valid_segments'] - validation_stats['discarded_neighbors']
    print(f"Valid segments (excluding neighbors): {kept_segments}")
    print(f"Discarded neighbor segments: {validation_stats['discarded_neighbors']}")
    print(f"Successful feature extractions: {validation_stats['successful_extractions']}")
    if kept_segments > 0:
        print(f"Extraction success rate: {validation_stats['successful_extractions']/kept_segments*100:.1f}%")

    if validation_stats['successful_extractions'] == 0:
        print("No features were extracted. Check neural data availability.")
        return

    # Convert to numpy arrays
    classification_data['features'] = np.array(classification_data['features'])
    classification_data['labels'] = np.array(classification_data['labels'])

    print(f"\nFinal validated dataset shape:")
    print(f"Features: {classification_data['features'].shape}")
    print(f"Labels: {classification_data['labels'].shape}")

    # Show class distribution
    print(f"\nClass distribution:")
    unique_labels, counts = np.unique(classification_data['labels'], return_counts=True)
    for label_id, count in zip(unique_labels, counts):
        phoneme = classification_data['id_to_phoneme'][label_id]
        print(f"  {label_id:2d} ('{phoneme}'): {count:4d} samples")

    # Save the classification dataset
    timestamp = latest_dataset.name.split('_')[-1].replace('.pkl', '')
    output_file = f"phoneme_classification_dataset_validated_{timestamp}.pkl"
    output_path = data_dir / output_file

    # Add validation stats to the dataset
    classification_data['validation_stats'] = validation_stats

    with open(output_path, 'wb') as f:
        pickle.dump(classification_data, f)

    print(f"\nValidated classification dataset saved to: {output_file}")

    # Save error data separately
    error_data['validation_stats'] = validation_stats
    error_file = f"phoneme_validation_errors_{timestamp}.pkl"
    error_path = data_dir / error_file

    with open(error_path, 'wb') as f:
        pickle.dump(error_data, f)

    print(f"Validation errors saved to: {error_file}")

    # Create a simple train/test split example
    create_train_test_split(classification_data, data_dir, timestamp)

    return classification_data

def create_train_test_split(data, data_dir, timestamp):
    """Create train/test split for the classification dataset"""

    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler

    print(f"\nCreating train/test split...")

    X = data['features']
    y = data['labels']
    metadata = data['metadata']

    # Split by session to avoid data leakage
    sessions = [meta['session'] for meta in metadata]
    unique_sessions = sorted(set(sessions))  # Sorted for a reproducible split

    print(f"Available sessions: {len(unique_sessions)}")

    if len(unique_sessions) >= 4:
        # Use session-based split
        train_sessions = unique_sessions[:int(len(unique_sessions) * 0.8)]
        test_sessions = unique_sessions[int(len(unique_sessions) * 0.8):]

        train_indices = [i for i, meta in enumerate(metadata) if meta['session'] in train_sessions]
        test_indices = [i for i, meta in enumerate(metadata) if meta['session'] in test_sessions]

        X_train, X_test = X[train_indices], X[test_indices]
        y_train, y_test = y[train_indices], y[test_indices]

        print(f"Session-based split:")
        print(f"  Train sessions: {train_sessions}")
        print(f"  Test sessions: {test_sessions}")
    else:
        # Use random split
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )
        print(f"Random split (stratified):")

    print(f"  Train samples: {len(X_train)}")
    print(f"  Test samples: {len(X_test)}")

    # Standardize features
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Save split data
    split_data = {
        'X_train': X_train_scaled,
        'X_test': X_test_scaled,
        'y_train': y_train,
        'y_test': y_test,
        'scaler': scaler,
        'phoneme_to_id': data['phoneme_to_id'],
        'id_to_phoneme': data['id_to_phoneme']
    }

    split_file = f"phoneme_classification_split_{timestamp}.pkl"
    split_path = data_dir / split_file

    with open(split_path, 'wb') as f:
        pickle.dump(split_data, f)

    print(f"Train/test split saved to: {split_file}")

if __name__ == "__main__":
    try:
        classification_data = create_phoneme_classification_dataset()
    except Exception as e:
        print(f"Error creating classification dataset: {e}")
        import traceback
        traceback.print_exc()
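
# ---------------------------------------------------------------------------
# Example downstream usage (illustrative sketch only; the file name, timestamp
# and classifier choice are placeholders, not part of this script):
#
#   import pickle
#   from sklearn.linear_model import LogisticRegression
#
#   with open("phoneme_segmented_data/phoneme_classification_split_<timestamp>.pkl", "rb") as f:
#       split = pickle.load(f)
#
#   clf = LogisticRegression(max_iter=1000)
#   clf.fit(split['X_train'], split['y_train'])
#   print("Test accuracy:", clf.score(split['X_test'], split['y_test']))
# ---------------------------------------------------------------------------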