299 lines
11 KiB
Python
299 lines
11 KiB
Python
![]() |
#!/usr/bin/env python3
"""
Simplified TTA-E parameter search.

Grid-searches the six key parameters: gru_weight and the five TTA weights
(original, noise, scale, shift, smooth).
"""

import argparse
import json
import os
import re
import subprocess
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import product

import numpy as np
from tqdm import tqdm


def parse_arguments(argv=None):
    """Parse the command-line options for the TTA-E parameter search.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``;
            passing a list makes the parser usable from tests and other code.

    Returns:
        argparse.Namespace holding the comma-separated search ranges, the
        options forwarded to the base evaluation script, and run controls.

    Invalid run controls are rejected via ``parser.error``, which prints the
    usage and exits with status 2. These are ``--max_workers`` or
    ``--batch_size`` below 1, and a non-positive ``--timeout``.
    """
    parser = argparse.ArgumentParser(description='简化版TTA-E参数搜索')

    # Search-space definitions: comma-separated candidates, 0.1 step by default.
    parser.add_argument('--gru_weights', type=str, default='0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='GRU权重搜索范围')
    parser.add_argument('--original_weights', type=str, default='1.0',
                        help='原始数据权重(通常固定为1.0)')
    parser.add_argument('--noise_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='噪声增强权重搜索范围')
    parser.add_argument('--scale_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='缩放增强权重搜索范围')
    parser.add_argument('--shift_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='偏移增强权重搜索范围')
    parser.add_argument('--smooth_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='平滑增强权重搜索范围')

    # Options forwarded to the base evaluation script.
    parser.add_argument('--base_script', type=str, default='evaluate_model.py',
                        help='基础评估脚本路径')
    parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                        help='数据目录')
    parser.add_argument('--eval_type', type=str, default='val',
                        help='评估类型')
    parser.add_argument('--gpu_number', type=int, default=0,
                        help='GPU编号')

    # Output and execution control.
    parser.add_argument('--output_file', type=str, default='parameter_search_results.json',
                        help='搜索结果输出文件')
    parser.add_argument('--dry_run', action='store_true',
                        help='只显示搜索空间,不实际运行')
    parser.add_argument('--max_workers', type=int, default=25,
                        help='最大并行工作线程数')
    # This value is only used as the interval between detailed progress reports.
    parser.add_argument('--batch_size', type=int, default=100,
                        help='每完成多少个配置打印一次详细进度报告')
    parser.add_argument('--timeout', type=float, default=1800,
                        help='单个配置评估的超时时间(秒)')

    args = parser.parse_args(argv)

    # Reject values that would otherwise fail deep inside the search.
    # ThreadPoolExecutor raises on max_workers <= 0, and the progress
    # check `completed % batch_size` divides by zero when batch_size == 0.
    if args.max_workers < 1:
        parser.error('--max_workers must be >= 1')
    if args.batch_size < 1:
        parser.error('--batch_size must be >= 1')
    if args.timeout <= 0:
        parser.error('--timeout must be > 0')
    return args


def _parse_weight_list(spec):
    """Parse a comma-separated list of floats such as ``"0.0, 0.5,1.0"``.

    Blank entries (from a trailing comma or ``",,"``) are skipped instead of
    crashing on ``float('')``. Repeated values are kept only once, in
    first-seen order, so no configuration is evaluated twice.

    Raises:
        ValueError: if an entry is not a number; the message names the entry.
    """
    values = []
    for token in spec.split(','):
        token = token.strip()
        if not token:
            continue
        try:
            values.append(float(token))
        except ValueError:
            raise ValueError(f"invalid weight {token!r} in {spec!r}") from None
    return list(dict.fromkeys(values))


def generate_search_space(args):
    """Build the Cartesian grid of all weight combinations to evaluate.

    Args:
        args: Namespace with the comma-separated string attributes
            gru_weights, original_weights, noise_weights, scale_weights,
            shift_weights and smooth_weights.

    Returns:
        list of 6-tuples (gru, original, noise, scale, shift, smooth) in
        itertools.product order, so the gru weight varies slowest.
    """
    option_names = (
        'gru_weights', 'original_weights', 'noise_weights',
        'scale_weights', 'shift_weights', 'smooth_weights',
    )
    axes = [_parse_weight_list(getattr(args, name)) for name in option_names]
    return list(product(*axes))


# Marker line printed by the evaluation script; the PER number follows it.
_PER_MARKER = 'Aggregate Phoneme Error Rate (PER):'
# A leading (optionally signed / fractional / exponent) number.
_PER_NUMBER = re.compile(r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?')


def _parse_per(stdout):
    """Return the PER from the first parseable marker line in *stdout*, else None.

    Only the number directly after the marker is taken, with or without a
    '%' sign, so trailing text such as ``"25.31% (1234/4876)"`` no longer
    breaks parsing. A marker line without a number is reported and skipped.
    """
    for line in stdout.splitlines():
        if _PER_MARKER not in line:
            continue
        tail = line.split(_PER_MARKER)[-1].strip()
        match = _PER_NUMBER.match(tail)
        if match:
            return float(match.group())
        print(f"Error parsing PER from line: {line}")
    return None


def run_single_evaluation(config, args):
    """Evaluate one weight configuration by running the base evaluation script.

    Args:
        config: 6-tuple (gru_w, orig_w, noise_w, scale_w, shift_w, smooth_w).
        args: Parsed CLI namespace. Reads base_script, data_dir, eval_type,
            gpu_number and dry_run, plus ``timeout`` in seconds when present
            (1800 otherwise).

    Returns:
        dict with 'config', 'gru_weight', 'tta_weights', 'per' and 'command'.
        - A completed run also has 'success' (exit code == 0). On a non-zero
          exit it additionally has 'returncode' and 'stderr_tail', the last
          2000 characters of stderr.
        - If no PER could be parsed, 'per' is inf.
        - Timeouts and launch failures carry 'error' and per=inf.
        - In dry-run mode nothing is executed and 'per' is a random
          placeholder in [20, 40).
    """
    gru_w, orig_w, noise_w, scale_w, shift_w, smooth_w = config
    tta_weights_str = f"{orig_w},{noise_w},{scale_w},{shift_w},{smooth_w}"

    # sys.executable runs the child under this same interpreter/environment.
    # A bare 'python' may resolve to a different install, or to none at all.
    cmd = [
        sys.executable, args.base_script,
        '--gru_weight', str(gru_w),
        '--tta_weights', tta_weights_str,
        '--data_dir', args.data_dir,
        '--eval_type', args.eval_type,
        '--gpu_number', str(args.gpu_number),
    ]
    command = ' '.join(cmd)
    # Fields shared by every result shape, so downstream code can rely on them.
    base = {
        'config': config,
        'gru_weight': gru_w,
        'tta_weights': {
            'original': orig_w,
            'noise': noise_w,
            'scale': scale_w,
            'shift': shift_w,
            'smooth': smooth_w,
        },
    }

    if args.dry_run:
        print(f"Would run: {command}")
        # Simulated PER; no evaluation is actually performed.
        return {**base, 'per': np.random.uniform(20, 40), 'command': command}

    timeout = getattr(args, 'timeout', 1800)
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        return {**base, 'per': float('inf'), 'error': 'Timeout', 'command': command}
    except Exception as e:  # e.g. the script or interpreter cannot be launched
        return {**base, 'per': float('inf'), 'error': str(e), 'command': command}

    per = _parse_per(result.stdout)
    if per is None:
        print(f"Warning: Could not parse PER from output for config {config}")
        per = float('inf')

    outcome = {**base, 'per': per, 'command': command, 'success': result.returncode == 0}
    if result.returncode != 0:
        # Keep a bounded stderr excerpt so failures can be diagnosed from the JSON.
        outcome['returncode'] = result.returncode
        outcome['stderr_tail'] = (result.stderr or '')[-2000:]
    return outcome


def _run_search(search_space, args):
    """Evaluate every configuration in a thread pool and collect the results.

    Prints:
    - every new best PER as soon as it appears;
    - a detailed progress report every ``args.batch_size`` completions;
    - a short progress marker every 50 completions.

    Returns:
        (results, interrupted): the result dicts gathered so far, and True if
        the search was stopped with Ctrl-C (KeyboardInterrupt). On interrupt,
        all not-yet-started configurations are cancelled, so the partial
        results can still be saved by the caller.
    """
    total_configs = len(search_space)
    results = []
    best_per = float('inf')
    completed_count = 0
    interrupted = False

    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        # Submit all tasks up front.
        future_to_config = {
            executor.submit(run_single_evaluation, config, args): config
            for config in search_space
        }
        try:
            for future in as_completed(future_to_config):
                completed_count += 1
                config = future_to_config[future]
                # Only future.result() is guarded, so a failure can never be
                # counted or appended twice.
                try:
                    result = future.result()
                except Exception as e:
                    print(f"\n❌ 配置失败: {config}, 错误: {e}")
                    results.append({
                        'config': config,
                        'per': float('inf'),
                        'error': str(e)
                    })
                    continue
                results.append(result)

                # Report a new best configuration.
                if result['per'] < best_per:
                    best_per = result['per']
                    print(f"\n🎯 新最优配置[{completed_count}/{total_configs}]: PER={best_per:.3f}%")
                    print(f" GRU={config[0]:.1f}, TTA=({config[1]},{config[2]},{config[3]},{config[4]},{config[5]})")

                # Periodic progress reports.
                if completed_count % args.batch_size == 0:
                    progress = 100 * completed_count / total_configs
                    print(f"\n📊 进度: {completed_count}/{total_configs} ({progress:.1f}%)")
                    print(f" 当前最优PER: {best_per:.3f}%")
                elif completed_count % 50 == 0:
                    print(f"... {completed_count}/{total_configs} ...", end='', flush=True)
        except KeyboardInterrupt:
            interrupted = True
            print("\n⚠️ 收到中断信号, 取消剩余任务并保存已完成的结果...")
            # Pending futures are dropped. Leaving the `with` block then waits
            # only for the evaluations that are already running.
            for pending in future_to_config:
                pending.cancel()

    return results, interrupted


def main():
    """Entry point: parse options, run the grid search, save and summarize results.

    With ``--dry_run``, only the commands for the first 5 configurations are
    printed. Otherwise:
    - every configuration is evaluated in parallel;
    - all results are written as JSON to ``--output_file``, also after
      Ctrl-C with whatever finished;
    - the best and top-10 configurations are printed.
    """
    args = parse_arguments()

    print("🔍 TTA-E参数搜索")
    print("=" * 50)

    search_space = generate_search_space(args)
    total_configs = len(search_space)

    print(f"搜索空间大小: {total_configs} 个配置")
    print("参数范围:")
    print(f" GRU权重: {args.gru_weights}")
    print(f" 原始权重: {args.original_weights}")
    print(f" 噪声权重: {args.noise_weights}")
    print(f" 缩放权重: {args.scale_weights}")
    print(f" 偏移权重: {args.shift_weights}")
    print(f" 平滑权重: {args.smooth_weights}")
    print()

    if args.dry_run:
        print("🧪 Dry run模式 - 显示前5个配置示例:")
        for i, config in enumerate(search_space[:5]):
            result = run_single_evaluation(config, args)
            print(f"{i+1}. {result['command']}")
        print(f"\n总共会运行 {total_configs} 个配置")
        return

    print("🚀 开始参数搜索...")
    print(f"使用 {args.max_workers} 个线程并行处理...")

    results, interrupted = _run_search(search_space, args)

    if interrupted:
        print(f"\n⚠️ 搜索被中断: 已完成 {len(results)}/{total_configs} 个配置")
    else:
        print("\n✅ 所有任务完成!")

    # Only runs that produced a finite PER without an error count as valid.
    valid_results = [r for r in results if 'error' not in r and r['per'] != float('inf')]
    best_config = min(valid_results, key=lambda r: r['per']) if valid_results else None

    search_results = {
        'best_config': best_config,
        'all_results': results,
        'search_space_size': total_configs,
        'completed_configs': len(results),
        'interrupted': interrupted,
        'args': vars(args),
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(search_results, f, indent=2)

    print("\n" + "=" * 50)
    print("🏆 搜索完成!")

    if best_config is not None:
        print("最佳配置:")
        print(f" PER: {best_config['per']:.3f}%")
        print(f" GRU权重: {best_config['gru_weight']:.1f}")
        print(f" TTA权重: {best_config['tta_weights']}")
        print(f" 命令: {best_config['command']}")

        print("\n📊 前10个最佳配置:")
        print("排名 | PER(%) | GRU | Original | Noise | Scale | Shift | Smooth")
        print("-" * 70)
        top_results = sorted(valid_results, key=lambda r: r['per'])[:10]
        for rank, result in enumerate(top_results, start=1):
            tw = result['tta_weights']
            print(f"{rank:3d} | {result['per']:6.3f} | {result['gru_weight']:3.1f} | "
                  f"{tw['original']:8.1f} | {tw['noise']:5.1f} | {tw['scale']:5.1f} | "
                  f"{tw['shift']:5.1f} | {tw['smooth']:6.1f}")
    else:
        print("❌ 未找到有效的配置结果!所有配置都失败了。")
        print("请检查评估脚本是否正常工作。")

    print("\n📈 搜索统计:")
    print(f" 总配置数: {total_configs}")
    print(f" 成功配置数: {len(valid_results)}")
    print(f" 失败配置数: {len(results) - len(valid_results)}")
    if interrupted:
        print(f" 未运行配置数: {total_configs - len(results)}")

    if valid_results:
        valid_pers = [r['per'] for r in valid_results]
        print(f" PER范围: {min(valid_pers):.3f}% - {max(valid_pers):.3f}%")
        print(f" 平均PER: {sum(valid_pers)/len(valid_pers):.3f}%")

    print(f"\n✅ 结果已保存到: {args.output_file}")


if __name__ == '__main__':
    # Start the search only when executed as a script, not when imported.
    main()