Files
b2txt25/data_analyse/restore_phoneme_data.py

167 lines
6.0 KiB
Python
Raw Normal View History

2025-10-12 09:11:32 +08:00
#!/usr/bin/env python3
"""
根据转换后的时间戳从原始数据集中还原音素对应的神经信号数据
"""
import pickle
import numpy as np
import h5py
import os
from pathlib import Path
def load_converted_phoneme_dataset(pkl_path):
    """Load the converted phoneme dataset from a pickle file.

    Args:
        pkl_path: path to the ``.pkl`` file to read.

    Returns:
        The unpickled object, unchanged. The caller in this script iterates it
        as a mapping of phoneme -> list of segment dicts.

    Note:
        ``pickle.load`` can execute arbitrary code; only use on trusted files.
    """
    with open(pkl_path, 'rb') as handle:
        return pickle.load(handle)
def load_h5_session_data(session_path, expected_num_features=512):
    """Load the neural-feature matrix from an HDF5 session file.

    Top-level entries named ``neuralFeatures``, ``neural_features`` and
    ``features`` are tried in that order. If none exists, every top-level
    entry's shape is printed and the first 2-D entry with
    ``expected_num_features`` columns is used.

    Args:
        session_path: path to the HDF5 file.
        expected_num_features: column count required by the fallback search
            (default 512, the value previously hard-coded). ``None`` accepts
            any 2-D dataset.

    Returns:
        The selected dataset read fully into memory (``dataset[:]``).

    Raises:
        ValueError: if no named entry exists and no top-level entry matches
            the fallback shape criterion.
    """
    with h5py.File(session_path, 'r') as f:
        print(f"HDF5文件键: {list(f.keys())}")
        features = None

        # Well-known names first, in priority order.
        for name in ('neuralFeatures', 'neural_features', 'features'):
            if name in f:
                features = f[name][:]
                break

        if features is None:
            # Fallback: inspect each top-level entry and take the first 2-D one
            # with the expected column count.
            for key in f.keys():
                try:
                    shape = f[key].shape
                except (AttributeError, KeyError, OSError):
                    # e.g. groups have no .shape; unresolvable links raise on access.
                    # Narrow on purpose: a bare except would also hide Ctrl-C.
                    print(f" {key}: 无法获取形状")
                    continue
                print(f" {key}: shape {shape}")
                if len(shape) == 2 and (
                    expected_num_features is None or shape[1] == expected_num_features
                ):
                    # Read outside the try so real I/O errors are not reported
                    # as "cannot get shape".
                    features = f[key][:]
                    print(f"使用 {key} 作为神经特征数据")
                    break
            else:
                # Loop finished without a break: nothing suitable was found.
                raise ValueError("未找到神经特征数据")

    print(f"神经特征形状: {features.shape}")
    return features
def find_session_file(session_name, data_dir):
    """Locate the HDF5 file for one recording session inside ``data_dir``.

    Patterns are tried in priority order:
      1. exact ``<session>.h5``
      2. exact ``<session>.hdf5``
      3. any name containing the session, with either extension

    The first pattern with at least one matching regular file wins. Matches
    within a pattern are sorted, so the choice is deterministic rather than
    dependent on directory-listing order.

    Args:
        session_name: session identifier (the segment's ``'session'`` field).
        data_dir: directory to search (non-recursive).

    Returns:
        The matching ``Path``, or ``None`` when nothing matches. In the
        ``None`` case, the ``.h5``/``.hdf5`` files present in ``data_dir`` are
        printed to help diagnose naming mismatches.
    """
    data_path = Path(data_dir)
    possible_patterns = [
        f"{session_name}.h5",
        f"{session_name}.hdf5",
        f"*{session_name}*.h5",
        f"*{session_name}*.hdf5",
    ]
    for pattern in possible_patterns:
        matches = sorted(p for p in data_path.glob(pattern) if p.is_file())
        if matches:
            return matches[0]

    print(f"未找到会话 {session_name} 的文件")
    print(f"数据目录 {data_dir} 中的文件:")
    # "*.h5*" alone never matches "*.hdf5", although the search above accepts it.
    available = {*data_path.glob("*.h5*"), *data_path.glob("*.hdf5")}
    for f in sorted(available):
        print(f" {f.name}")
    return None
def _print_segment_summary(segment):
    """Print a segment's fields and its start/end under each timestamp mapping."""
    print("Segment信息:")
    for key, value in segment.items():
        if key != 'original_timestamps':
            print(f" {key}: {value}")
    original_ts = segment['original_timestamps']
    print("原始时间戳转换:")
    print(f" 输出时间戳: {segment['start_time']}-{segment['end_time']}")
    for label, mapping in (('简单映射', 'simple'),
                           ('保守映射', 'conservative'),
                           ('可能映射', 'likely')):
        print(f" {label}: {original_ts[mapping]['start']}-{original_ts[mapping]['end']}")


def _report_extracted_features(extracted, start_idx, end_idx, frame_ms):
    """Print shape, duration, summary statistics and a small preview of a window."""
    num_steps = end_idx - start_idx + 1
    print("\n成功提取神经信号:")
    print(f" 时间范围: {start_idx}-{end_idx} ({num_steps} 个时间步)")
    print(f" 提取数据形状: {extracted.shape}")
    print(f" 时间长度: {num_steps * frame_ms}ms")
    print(" 特征统计:")
    print(f" 均值: {np.mean(extracted):.4f}")
    print(f" 标准差: {np.std(extracted):.4f}")
    print(f" 最小值: {np.min(extracted):.4f}")
    print(f" 最大值: {np.max(extracted):.4f}")
    print(" 前5个时间步的前10个特征:")
    for t, row in enumerate(extracted[:5]):
        print(f" t+{t}: {row[:10]}")


def restore_phoneme_samples(pkl_path, data_dir, num_samples=5, frame_ms=20):
    """Restore and print neural-signal windows for a few phoneme segments.

    For each phoneme in the converted dataset:
      * skip the silence token ``'|'`` and phonemes with no segments;
      * take the first segment and locate its session's HDF5 file in
        ``data_dir``;
      * slice rows ``original_timestamps['simple']['start']`` through
        ``['end']`` (inclusive) out of the session's feature matrix.

    Stops after ``num_samples`` successful extractions.

    Args:
        pkl_path: path to the converted phoneme dataset pickle.
        data_dir: directory containing the session HDF5 files.
        num_samples: number of successful extractions to report.
        frame_ms: milliseconds per time step; used only for the printed
            duration (previously hard-coded to 20).
    """
    print("=== 加载转换后的音素数据集 ===")
    phoneme_data = load_converted_phoneme_dataset(pkl_path)
    print(f"音素数量: {len(phoneme_data)}")
    # These are the first five phoneme names, not keys of one phoneme.
    print(f"前5个音素: {list(phoneme_data.keys())[:5]}")

    # Keep only the most recently loaded session: consecutive segments from the
    # same session skip re-reading the file, while memory stays bounded to one
    # session's feature matrix.
    cached_file = None
    cached_features = None

    sample_count = 0
    for phoneme, segments in phoneme_data.items():
        if sample_count >= num_samples:
            break
        if phoneme == '|':  # skip silence
            continue
        print(f"\n=== 处理音素 '{phoneme}' ===")
        print(f"该音素有 {len(segments)} 个segments")
        if not segments:
            # segments[0] would raise IndexError and abort the whole run.
            continue

        segment = segments[0]
        _print_segment_summary(segment)
        original_ts = segment['original_timestamps']

        session_name = segment['session']
        print(f"\n寻找会话文件: {session_name}")
        session_file = find_session_file(session_name, data_dir)
        if session_file is None:
            print(f"未找到会话 {session_name} 的数据文件")
            continue
        print(f"找到会话文件: {session_file}")

        try:
            if session_file != cached_file:
                # Assign both names only after a successful load.
                cached_features = load_h5_session_data(session_file)
                cached_file = session_file
            neural_features = cached_features

            # Use the "simple" mapping to pick the matching rows.
            start_idx = original_ts['simple']['start']
            end_idx = original_ts['simple']['end']
            if end_idx >= neural_features.shape[0]:
                print(f"时间戳超出数据范围: {end_idx} >= {neural_features.shape[0]}")
                continue
            if start_idx < 0 or start_idx > end_idx:
                # A negative start would silently slice from the array's end.
                print(f"无效的时间范围: {start_idx}-{end_idx}")
                continue

            extracted_features = neural_features[start_idx:end_idx + 1, :]
            _report_extracted_features(extracted_features, start_idx, end_idx, frame_ms)
        except Exception as e:  # boundary: one bad session must not stop the scan
            print(f"处理会话 {session_name} 时出错: {e}")
            continue
        sample_count += 1
if __name__ == "__main__":
    import argparse

    # Defaults reproduce the previously hard-coded values, so running with no
    # arguments behaves as before. Relative paths resolve against the CWD.
    parser = argparse.ArgumentParser(
        description="Restore neural-signal windows for a few phoneme segments."
    )
    parser.add_argument(
        "--pkl-path",
        default="../phoneme_segmented_data/phoneme_dataset_20251009_202457_with_original_timestamps.pkl",
        help="converted phoneme dataset (.pkl) with original timestamps",
    )
    parser.add_argument(
        "--data-dir",
        default="../data/hdf5_data_final",
        help="directory containing the session HDF5 files",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=3,
        help="number of phoneme samples to restore",
    )
    args = parser.parse_args()

    print("=== 音素数据还原测试 ===")
    print(f"音素数据集: {args.pkl_path}")
    print(f"神经数据目录: {args.data_dir}")
    restore_phoneme_samples(args.pkl_path, args.data_dir, num_samples=args.num_samples)
    print("\n=== 还原测试完成 ===")