328 lines
13 KiB
Python
328 lines
13 KiB
Python
![]() |
#!/usr/bin/env python3
|
|||
|
"""
|
|||
|
将PKL文件中的输出时间戳转换为原始数据时间戳
|
|||
|
支持三种映射方式:简单映射、保守映射、可能映射
|
|||
|
"""
|
|||
|
|
|||
|
import pickle
|
|||
|
import numpy as np
|
|||
|
from pathlib import Path
|
|||
|
from collections import defaultdict
|
|||
|
import argparse
|
|||
|
from datetime import datetime
|
|||
|
|
|||
|
# Sliding-window parameters (taken from rnn_args.yaml)
PATCH_SIZE = 14  # sliding-window size
PATCH_STRIDE = 4  # sliding-window stride
ORIGINAL_BIN_MS = 20  # size of one original time bin (ms)


def convert_output_timestamp_to_original(output_start, output_end, patch_size=PATCH_SIZE, patch_stride=PATCH_STRIDE):
    """
    Map a span of output time steps back to original-data time steps.

    Output step ``k`` is treated as a window covering original steps
    ``[k * patch_stride, k * patch_stride + patch_size - 1]``. Three
    alternative mappings of the span are derived from those windows.

    Args:
        output_start: First time step of the span in the output sequence.
        output_end: Last time step of the span in the output sequence
            (inclusive; the reported duration is ``end - start + 1``).
        patch_size: Sliding-window size.
        patch_stride: Sliding-window stride.

    Returns:
        dict with the keys:
            'simple_mapping': (start, end) window centres, rounded down.
            'conservative_mapping': (start, end) from the first original
                step of the first window to the last step of the last window.
            'likely_mapping': (start, end) centres widened by half a stride
                and clamped to the conservative range.
            'metadata': the output range, its duration in output steps, the
                unrounded centre positions and the first/last window ranges.
    """
    # Centre of the window belonging to each output step (may end in .5).
    half_window = (patch_size - 1) / 2
    original_start_center = output_start * patch_stride + half_window
    original_end_center = output_end * patch_stride + half_window

    # 1. Simple mapping: the window centre, rounded down.
    # Flooring (rather than round(), which rounds halves to the nearest even
    # integer) keeps consecutive output steps exactly `patch_stride` apart for
    # every stride; for the default parameters the result is unchanged.
    simple_start = int(original_start_center // 1)
    simple_end = int(original_end_center // 1)

    # 2. Conservative mapping: the full extent of the first and last windows.
    patch_start_first = output_start * patch_stride
    patch_end_first = patch_start_first + patch_size - 1
    patch_start_last = output_end * patch_stride
    patch_end_last = patch_start_last + patch_size - 1

    conservative_start = patch_start_first
    conservative_end = patch_end_last

    # 3. Likely mapping: centre +/- half a stride, clamped to the window bounds.
    likely_start = max(patch_start_first, int(original_start_center - patch_stride / 2))
    likely_end = min(patch_end_last, int(original_end_center + patch_stride / 2))

    return {
        'simple_mapping': (simple_start, simple_end),
        'conservative_mapping': (conservative_start, conservative_end),
        'likely_mapping': (likely_start, likely_end),
        'metadata': {
            'output_range': (output_start, output_end),
            'output_duration': output_end - output_start + 1,
            'center_positions': (original_start_center, original_end_center),
            'patch_ranges': {
                'first_patch': (patch_start_first, patch_end_first),
                'last_patch': (patch_start_last, patch_end_last),
            },
        },
    }
|
|||
|
|
|||
|
def _build_converted_segment(segment):
    """Return a shallow copy of *segment* extended with original-timestamp fields."""
    # Timestamps of the segment, expressed in output time steps.
    conversion = convert_output_timestamp_to_original(segment['start_time'], segment['end_time'])
    simple_range = conversion['simple_mapping']
    conservative_range = conversion['conservative_mapping']
    likely_range = conversion['likely_mapping']
    metadata = conversion['metadata']

    # Keep every existing field and add the conversion results next to them.
    enriched = segment.copy()
    enriched['original_timestamps'] = {
        'simple_start': simple_range[0],
        'simple_end': simple_range[1],
        'conservative_start': conservative_range[0],
        'conservative_end': conservative_range[1],
        'likely_start': likely_range[0],
        'likely_end': likely_range[1]
    }

    # Duration in milliseconds: output steps * stride * original bin size.
    output_duration_ms = metadata['output_duration'] * PATCH_STRIDE * ORIGINAL_BIN_MS
    enriched['duration_info'] = {
        'output_duration_steps': metadata['output_duration'],
        'output_duration_ms': output_duration_ms,
        'simple_duration_steps': simple_range[1] - simple_range[0] + 1,
        'conservative_duration_steps': conservative_range[1] - conservative_range[0] + 1,
        'likely_duration_steps': likely_range[1] - likely_range[0] + 1
    }
    enriched['conversion_metadata'] = metadata
    return enriched


def process_phoneme_dataset(input_file_path, output_file_path):
    """
    Convert the timestamps of a pickled phoneme dataset and save the result.

    The input pickle is iterated as a mapping from phoneme to a list of
    segments; each segment must provide 'start_time' and 'end_time'.
    Segments that fail to convert are reported, counted and skipped.

    Args:
        input_file_path: Path of the input PKL file.
        output_file_path: Path of the output PKL file.

    Returns:
        dict with 'phoneme_data' (the converted segments per phoneme) and
        'conversion_info' (timestamp, parameters, statistics and a
        description of the mapping methods); the same object that is pickled
        to ``output_file_path``.
    """
    print(f"=== 处理phoneme dataset ===")
    print(f"输入文件: {input_file_path}")
    print(f"输出文件: {output_file_path}")

    # NOTE(review): pickle.load can execute arbitrary code; only feed trusted files.
    with open(input_file_path, 'rb') as source:
        phoneme_data = pickle.load(source)

    print(f"原始数据类型: {type(phoneme_data)}")
    if isinstance(phoneme_data, dict):
        print(f"音素数量: {len(phoneme_data)}")
        total_segments = sum(map(len, phoneme_data.values()))
        print(f"总segment数量: {total_segments}")

    # Convert every segment, keeping running statistics.
    converted_data = {}
    conversion_stats = {
        'total_segments': 0,
        'conversion_errors': 0,
        'phoneme_counts': defaultdict(int)
    }

    for phoneme, segments in phoneme_data.items():
        kept_segments = []
        for segment in segments:
            try:
                kept_segments.append(_build_converted_segment(segment))
                conversion_stats['total_segments'] += 1
                conversion_stats['phoneme_counts'][phoneme] += 1
            except Exception as exc:
                # Best effort: report the failure and carry on with the next segment.
                print(f"转换segment时出错 (phoneme: {phoneme}): {exc}")
                conversion_stats['conversion_errors'] += 1
        converted_data[phoneme] = kept_segments

    # Record how and when the conversion was performed.
    parameters = {
        'patch_size': PATCH_SIZE,
        'patch_stride': PATCH_STRIDE,
        'original_bin_ms': ORIGINAL_BIN_MS
    }
    mapping_methods = {
        'simple_mapping': '基于输出时间步中心位置的映射',
        'conservative_mapping': '基于完整patch范围的保守映射',
        'likely_mapping': '基于中心位置但考虑patch边界的调整映射'
    }
    conversion_info = {
        'conversion_timestamp': datetime.now().isoformat(),
        'parameters': parameters,
        'statistics': dict(conversion_stats),
        'mapping_methods': mapping_methods
    }

    # Persist the result.
    output_data = {
        'phoneme_data': converted_data,
        'conversion_info': conversion_info
    }
    with open(output_file_path, 'wb') as sink:
        pickle.dump(output_data, sink)

    print(f"\n=== 转换完成 ===")
    print(f"成功转换: {conversion_stats['total_segments']} 个segments")
    print(f"转换错误: {conversion_stats['conversion_errors']} 个segments")
    print(f"音素分布:")

    # Show the ten most frequent phonemes.
    by_count = sorted(conversion_stats['phoneme_counts'].items(),
                      key=lambda item: item[1], reverse=True)
    for name, count in by_count[:10]:
        print(f"  {name}: {count} segments")

    return output_data
|
|||
|
|
|||
|
def analyze_conversion_results(converted_data_path):
    """
    Print summary statistics for a converted phoneme dataset.

    The pickle must contain 'phoneme_data' (mapping phoneme -> list of
    segments carrying 'duration_info', 'original_timestamps', 'phoneme',
    'start_time' and 'end_time') and 'conversion_info' (with
    'conversion_timestamp' and 'parameters').

    When the dataset holds no segments, the statistics are skipped instead of
    averaging empty lists (which yields NaN and NumPy RuntimeWarnings).

    Args:
        converted_data_path: Path of the converted PKL file.
    """
    print(f"\n=== 分析转换结果 ===")

    # NOTE(review): pickle.load can execute arbitrary code; only feed trusted files.
    with open(converted_data_path, 'rb') as f:
        data = pickle.load(f)

    phoneme_data = data['phoneme_data']
    conversion_info = data['conversion_info']

    print(f"转换时间: {conversion_info['conversion_timestamp']}")
    print(f"转换参数: {conversion_info['parameters']}")

    # Flatten the per-phoneme lists into one list of segments.
    all_segments = [segment for segments in phoneme_data.values() for segment in segments]
    print(f"总segment数量: {len(all_segments)}")

    if not all_segments:
        # Nothing to average; bail out before NumPy produces NaN statistics.
        print("没有可分析的segment,跳过统计")
        return

    # Duration distribution (in time steps) for each mapping.
    duration_infos = [segment['duration_info'] for segment in all_segments]
    output_durations = [info['output_duration_steps'] for info in duration_infos]
    simple_durations = [info['simple_duration_steps'] for info in duration_infos]
    conservative_durations = [info['conservative_duration_steps'] for info in duration_infos]
    likely_durations = [info['likely_duration_steps'] for info in duration_infos]

    print(f"\n时长统计 (时间步):")
    print(f"输出时长: 平均 {np.mean(output_durations):.1f}, 中位数 {np.median(output_durations):.1f}")
    print(f"简单映射时长: 平均 {np.mean(simple_durations):.1f}, 中位数 {np.median(simple_durations):.1f}")
    print(f"保守映射时长: 平均 {np.mean(conservative_durations):.1f}, 中位数 {np.median(conservative_durations):.1f}")
    print(f"可能映射时长: 平均 {np.mean(likely_durations):.1f}, 中位数 {np.median(likely_durations):.1f}")

    # Ratio of original duration to output duration; segments with a
    # non-positive output duration are left out to avoid dividing by zero.
    simple_ratios = [s / o for s, o in zip(simple_durations, output_durations) if o > 0]
    conservative_ratios = [c / o for c, o in zip(conservative_durations, output_durations) if o > 0]
    likely_ratios = [l / o for l, o in zip(likely_durations, output_durations) if o > 0]

    print(f"\n映射比例 (原始/输出):")
    if simple_ratios:
        print(f"简单映射: 平均 {np.mean(simple_ratios):.1f}x")
        print(f"保守映射: 平均 {np.mean(conservative_ratios):.1f}x")
        print(f"可能映射: 平均 {np.mean(likely_ratios):.1f}x")
    else:
        # All three lists share the same filter, so they are empty together.
        print("没有输出时长大于0的segment,跳过比例计算")

    # Show the first few segments as examples.
    print(f"\n=== 转换示例 ===")
    sample_segments = all_segments[:5]

    print(f"{'音素':4s} {'输出时间戳':12s} {'简单映射':12s} {'保守映射':12s} {'可能映射':12s} {'时长(ms)':8s}")
    print("-" * 70)

    for segment in sample_segments:
        phoneme = segment['phoneme']
        output_range = f"{segment['start_time']}-{segment['end_time']}"

        timestamps = segment['original_timestamps']
        simple = f"{timestamps['simple_start']}-{timestamps['simple_end']}"
        conservative = f"{timestamps['conservative_start']}-{timestamps['conservative_end']}"
        likely = f"{timestamps['likely_start']}-{timestamps['likely_end']}"

        duration_ms = segment['duration_info']['output_duration_ms']

        print(f"{phoneme:4s} {output_range:12s} {simple:12s} {conservative:12s} {likely:12s} {duration_ms:6.0f}")
|
|||
|
|
|||
|
def main():
    """
    Command-line entry point: convert a phoneme dataset, then analyse it.

    Without ``--input_file`` the most recently modified
    ``phoneme_dataset_*.pkl`` in the input directory is used. Files whose
    name already ends with the output suffix are ignored during that search:
    they are results of an earlier run (the default output directory is the
    input directory) and must not be converted a second time.
    """
    parser = argparse.ArgumentParser(description='转换phoneme dataset的时间戳')
    parser.add_argument('--input_dir', type=str, default='../phoneme_segmented_data',
                        help='输入PKL文件目录')
    parser.add_argument('--output_dir', type=str, default='../phoneme_segmented_data',
                        help='输出PKL文件目录')
    parser.add_argument('--input_file', type=str, default=None,
                        help='指定输入文件名,如果不指定则处理最新的phoneme_dataset文件')
    parser.add_argument('--output_suffix', type=str, default='_with_original_timestamps',
                        help='输出文件后缀')
    parser.add_argument('--analyze_only', action='store_true',
                        help='只分析已存在的转换结果,不进行转换')

    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)

    # Make sure the output directory exists.
    output_dir.mkdir(parents=True, exist_ok=True)

    if args.analyze_only:
        # Analysis only: pick the most recently modified converted file.
        converted_files = list(output_dir.glob(f"*{args.output_suffix}.pkl"))
        if converted_files:
            latest_converted = max(converted_files, key=lambda x: x.stat().st_mtime)
            analyze_conversion_results(latest_converted)
        else:
            print("未找到转换后的文件")
        return

    # Locate the input file.
    if args.input_file:
        input_file = input_dir / args.input_file
    else:
        # Newest phoneme_dataset file, skipping outputs of previous runs
        # (their names also match the glob pattern).
        suffix = args.output_suffix
        phoneme_files = [
            candidate for candidate in input_dir.glob("phoneme_dataset_*.pkl")
            if not (suffix and candidate.stem.endswith(suffix))
        ]
        if not phoneme_files:
            print(f"在目录 {input_dir} 中未找到phoneme_dataset文件")
            return

        input_file = max(phoneme_files, key=lambda x: x.stat().st_mtime)

    if not input_file.exists():
        print(f"输入文件不存在: {input_file}")
        return

    # Build the output file name from the input stem and the suffix.
    output_filename = input_file.stem + args.output_suffix + '.pkl'
    output_file = output_dir / output_filename

    try:
        # Convert, then analyse what was written.
        process_phoneme_dataset(input_file, output_file)
        analyze_conversion_results(output_file)

        print(f"\n转换完成!结果已保存到: {output_file}")

    except Exception as e:
        # Top-level boundary: report the failure with a traceback.
        print(f"处理过程中出现错误: {e}")
        import traceback
        traceback.print_exc()
|
|||
|
|
|||
|
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
|