#!/usr/bin/env python3
"""
将PKL文件中的输出时间戳转换为原始数据时间戳
支持三种映射方式简单映射保守映射可能映射
"""
import pickle
import numpy as np
from pathlib import Path
from collections import defaultdict
import argparse
from datetime import datetime

# Sliding-window parameters (from rnn_args.yaml)
PATCH_SIZE = 14       # sliding window size
PATCH_STRIDE = 4      # sliding window stride
ORIGINAL_BIN_MS = 20  # original time-bin size (ms)
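
# Illustration of the window geometry implied by the constants above (derived from
# PATCH_SIZE / PATCH_STRIDE, not taken from rnn_args.yaml itself):
# output step t is produced from original bins [t * PATCH_STRIDE, t * PATCH_STRIDE + PATCH_SIZE - 1],
# so output step 0 covers original bins 0-13 (0-280 ms) and output step 1 covers bins 4-17 (80-360 ms).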


def convert_output_timestamp_to_original(output_start, output_end, patch_size=PATCH_SIZE, patch_stride=PATCH_STRIDE):
    """
    Convert output timestamps to original-data timestamps using three mapping methods.

    Args:
        output_start: start time step in the output sequence
        output_end: end time step in the output sequence
        patch_size: sliding window size
        patch_stride: sliding window stride

    Returns:
        dict: original-timestamp information for the three mapping methods
    """
    # Center positions (in original time steps) of the patches that produced the output steps
    original_start_center = output_start * patch_stride + (patch_size - 1) / 2
    original_end_center = output_end * patch_stride + (patch_size - 1) / 2

    # 1. Simple mapping: use the center positions
    simple_start = int(round(original_start_center))
    simple_end = int(round(original_end_center))

    # 2. Conservative mapping: cover the full patch ranges
    patch_start_first = output_start * patch_stride
    patch_end_first = patch_start_first + patch_size - 1
    patch_start_last = output_end * patch_stride
    patch_end_last = patch_start_last + patch_size - 1
    conservative_start = patch_start_first
    conservative_end = patch_end_last

    # 3. Likely mapping: based on the center positions, clamped to the patch boundaries
    likely_start = max(patch_start_first, int(original_start_center - patch_stride / 2))
    likely_end = min(patch_end_last, int(original_end_center + patch_stride / 2))

    return {
        'simple_mapping': (simple_start, simple_end),
        'conservative_mapping': (conservative_start, conservative_end),
        'likely_mapping': (likely_start, likely_end),
        'metadata': {
            'output_range': (output_start, output_end),
            'output_duration': output_end - output_start + 1,
            'center_positions': (original_start_center, original_end_center),
            'patch_ranges': {
                'first_patch': (patch_start_first, patch_end_first),
                'last_patch': (patch_start_last, patch_end_last)
            }
        }
    }
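
# Minimal usage sketch (the output steps below are hypothetical, not from a real decoding run):
#   conversion = convert_output_timestamp_to_original(10, 12)
#   conversion['conservative_mapping']   # (40, 61): original bins covered by every contributing patch
#   conversion['likely_mapping']         # narrower range clamped around the patch centers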


def process_phoneme_dataset(input_file_path, output_file_path):
    """
    Process a phoneme dataset: convert the timestamps and save the result.

    Args:
        input_file_path: path to the input PKL file
        output_file_path: path to the output PKL file
    """
    print("=== Processing phoneme dataset ===")
    print(f"Input file: {input_file_path}")
    print(f"Output file: {output_file_path}")

    # Load the original data
    with open(input_file_path, 'rb') as f:
        phoneme_data = pickle.load(f)

    print(f"Original data type: {type(phoneme_data)}")
    if isinstance(phoneme_data, dict):
        print(f"Number of phonemes: {len(phoneme_data)}")
        total_segments = sum(len(segments) for segments in phoneme_data.values())
        print(f"Total number of segments: {total_segments}")

    # Convert the data structure
    converted_data = {}
    conversion_stats = {
        'total_segments': 0,
        'conversion_errors': 0,
        'phoneme_counts': defaultdict(int)
    }

    for phoneme, segments in phoneme_data.items():
        converted_segments = []
        for segment in segments:
            try:
                # Get the output-sequence timestamps stored in the segment
                output_start = segment['start_time']
                output_end = segment['end_time']

                # Convert the timestamps
                conversion = convert_output_timestamp_to_original(output_start, output_end)

                # Create a new segment: keep the existing fields and add the conversion results
                new_segment = segment.copy()

                # Add the original-timestamp fields
                new_segment['original_timestamps'] = {
                    'simple_start': conversion['simple_mapping'][0],
                    'simple_end': conversion['simple_mapping'][1],
                    'conservative_start': conversion['conservative_mapping'][0],
                    'conservative_end': conversion['conservative_mapping'][1],
                    'likely_start': conversion['likely_mapping'][0],
                    'likely_end': conversion['likely_mapping'][1]
                }

                # Add duration information (milliseconds)
                output_duration_ms = conversion['metadata']['output_duration'] * PATCH_STRIDE * ORIGINAL_BIN_MS
                new_segment['duration_info'] = {
                    'output_duration_steps': conversion['metadata']['output_duration'],
                    'output_duration_ms': output_duration_ms,
                    'simple_duration_steps': conversion['simple_mapping'][1] - conversion['simple_mapping'][0] + 1,
                    'conservative_duration_steps': conversion['conservative_mapping'][1] - conversion['conservative_mapping'][0] + 1,
                    'likely_duration_steps': conversion['likely_mapping'][1] - conversion['likely_mapping'][0] + 1
                }

                # Add the conversion metadata
                new_segment['conversion_metadata'] = conversion['metadata']

                converted_segments.append(new_segment)
                conversion_stats['total_segments'] += 1
                conversion_stats['phoneme_counts'][phoneme] += 1
            except Exception as e:
                print(f"Error while converting a segment (phoneme: {phoneme}): {e}")
                conversion_stats['conversion_errors'] += 1
                continue

        converted_data[phoneme] = converted_segments

    # Record the conversion parameters
    conversion_info = {
        'conversion_timestamp': datetime.now().isoformat(),
        'parameters': {
            'patch_size': PATCH_SIZE,
            'patch_stride': PATCH_STRIDE,
            'original_bin_ms': ORIGINAL_BIN_MS
        },
        'statistics': dict(conversion_stats),
        'mapping_methods': {
            'simple_mapping': 'mapping based on the center positions of the output time steps',
            'conservative_mapping': 'conservative mapping covering the full patch ranges',
            'likely_mapping': 'center-based mapping adjusted toward the patch boundaries'
        }
    }

    # Save the result
    output_data = {
        'phoneme_data': converted_data,
        'conversion_info': conversion_info
    }
    with open(output_file_path, 'wb') as f:
        pickle.dump(output_data, f)

    print("\n=== Conversion finished ===")
    print(f"Successfully converted: {conversion_stats['total_segments']} segments")
    print(f"Conversion errors: {conversion_stats['conversion_errors']} segments")
    print("Phoneme distribution (top 10):")
    sorted_phonemes = sorted(conversion_stats['phoneme_counts'].items(),
                             key=lambda x: x[1], reverse=True)
    for phoneme, count in sorted_phonemes[:10]:
        print(f"  {phoneme}: {count} segments")

    return output_data
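
# Sketch of how a converted file can be read back (the file name and the 'AH' phoneme key
# are hypothetical; the actual name depends on the input file and --output_suffix):
#   with open('phoneme_dataset_..._with_original_timestamps.pkl', 'rb') as f:
#       data = pickle.load(f)
#   segment = data['phoneme_data']['AH'][0]
#   segment['original_timestamps']['likely_start']   # converted start bin in original time steps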


def analyze_conversion_results(converted_data_path):
    """
    Analyze statistics of the conversion results.

    Args:
        converted_data_path: path to the converted PKL file
    """
    print("\n=== Analyzing conversion results ===")
    with open(converted_data_path, 'rb') as f:
        data = pickle.load(f)

    phoneme_data = data['phoneme_data']
    conversion_info = data['conversion_info']

    print(f"Conversion time: {conversion_info['conversion_timestamp']}")
    print(f"Conversion parameters: {conversion_info['parameters']}")

    # Collect all segments
    all_segments = []
    for phoneme, segments in phoneme_data.items():
        all_segments.extend(segments)
    print(f"Total number of segments: {len(all_segments)}")

    # Duration distributions
    output_durations = []
    simple_durations = []
    conservative_durations = []
    likely_durations = []
    for segment in all_segments:
        duration_info = segment['duration_info']
        output_durations.append(duration_info['output_duration_steps'])
        simple_durations.append(duration_info['simple_duration_steps'])
        conservative_durations.append(duration_info['conservative_duration_steps'])
        likely_durations.append(duration_info['likely_duration_steps'])

    print("\nDuration statistics (time steps):")
    print(f"Output duration: mean {np.mean(output_durations):.1f}, median {np.median(output_durations):.1f}")
    print(f"Simple mapping: mean {np.mean(simple_durations):.1f}, median {np.median(simple_durations):.1f}")
    print(f"Conservative mapping: mean {np.mean(conservative_durations):.1f}, median {np.median(conservative_durations):.1f}")
    print(f"Likely mapping: mean {np.mean(likely_durations):.1f}, median {np.median(likely_durations):.1f}")

    # Mapping ratios
    simple_ratios = [s / o for s, o in zip(simple_durations, output_durations) if o > 0]
    conservative_ratios = [c / o for c, o in zip(conservative_durations, output_durations) if o > 0]
    likely_ratios = [l / o for l, o in zip(likely_durations, output_durations) if o > 0]

    print("\nMapping ratios (original/output):")
    print(f"Simple mapping: mean {np.mean(simple_ratios):.1f}x")
    print(f"Conservative mapping: mean {np.mean(conservative_ratios):.1f}x")
    print(f"Likely mapping: mean {np.mean(likely_ratios):.1f}x")

    # Show a few examples
    print("\n=== Conversion examples ===")
    sample_segments = all_segments[:5]
    print(f"{'Phoneme':8s} {'Output':12s} {'Simple':12s} {'Conservative':14s} {'Likely':12s} {'Dur (ms)':8s}")
    print("-" * 71)
    for segment in sample_segments:
        phoneme = segment['phoneme']
        output_range = f"{segment['start_time']}-{segment['end_time']}"
        timestamps = segment['original_timestamps']
        simple = f"{timestamps['simple_start']}-{timestamps['simple_end']}"
        conservative = f"{timestamps['conservative_start']}-{timestamps['conservative_end']}"
        likely = f"{timestamps['likely_start']}-{timestamps['likely_end']}"
        duration_ms = segment['duration_info']['output_duration_ms']
        print(f"{phoneme:8s} {output_range:12s} {simple:12s} {conservative:14s} {likely:12s} {duration_ms:8.0f}")


def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description='Convert the timestamps of a phoneme dataset')
    parser.add_argument('--input_dir', type=str, default='../phoneme_segmented_data',
                        help='directory containing the input PKL files')
    parser.add_argument('--output_dir', type=str, default='../phoneme_segmented_data',
                        help='directory for the output PKL files')
    parser.add_argument('--input_file', type=str, default=None,
                        help='input file name; if omitted, the most recent phoneme_dataset file is processed')
    parser.add_argument('--output_suffix', type=str, default='_with_original_timestamps',
                        help='suffix appended to the output file name')
    parser.add_argument('--analyze_only', action='store_true',
                        help='only analyze an existing conversion result without converting')
    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)

    # Make sure the output directory exists
    output_dir.mkdir(parents=True, exist_ok=True)

    if args.analyze_only:
        # Analysis only
        converted_files = list(output_dir.glob(f"*{args.output_suffix}.pkl"))
        if converted_files:
            latest_converted = max(converted_files, key=lambda x: x.stat().st_mtime)
            analyze_conversion_results(latest_converted)
        else:
            print("No converted file found")
        return

    # Locate the input file
    if args.input_file:
        input_file = input_dir / args.input_file
    else:
        # Use the most recent phoneme_dataset file
        phoneme_files = list(input_dir.glob("phoneme_dataset_*.pkl"))
        if not phoneme_files:
            print(f"No phoneme_dataset file found in directory {input_dir}")
            return
        input_file = max(phoneme_files, key=lambda x: x.stat().st_mtime)

    if not input_file.exists():
        print(f"Input file does not exist: {input_file}")
        return

    # Build the output file name
    output_filename = input_file.stem + args.output_suffix + '.pkl'
    output_file = output_dir / output_filename

    try:
        # Run the conversion
        converted_data = process_phoneme_dataset(input_file, output_file)

        # Analyze the result
        analyze_conversion_results(output_file)

        print(f"\nConversion finished! Result saved to: {output_file}")
    except Exception as e:
        print(f"An error occurred during processing: {e}")
        import traceback
        traceback.print_exc()
if __name__ == "__main__":
main()