Files
b2txt25/TTA-E/simple_search.py
2025-10-06 15:17:44 +08:00

299 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
简化版TTA-E参数搜索
专门搜索6个关键参数gru_weight 和 5个TTA权重
"""
import os
import sys
import argparse
import json
import numpy as np
from itertools import product
from concurrent.futures import ThreadPoolExecutor, as_completed
import subprocess
import time
from tqdm import tqdm
def parse_arguments(argv=None):
    """Define the command-line interface and parse the arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the behaviour of the
            original zero-argument call. Passing a list allows the parser to
            be driven programmatically.

    Returns:
        argparse.Namespace holding the six weight lists as raw
        comma-separated strings, the evaluation settings forwarded to the
        base script, and the output / parallelism controls.
    """
    parser = argparse.ArgumentParser(description='简化版TTA-E参数搜索')

    # Search-space definition (step of 0.1).
    parser.add_argument('--gru_weights', type=str,
                        default='0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='GRU权重搜索范围')
    parser.add_argument('--original_weights', type=str, default='1.0',
                        help='原始数据权重通常固定为1.0')

    # The four augmentation weights share the same default grid, so they are
    # declared from a table instead of four copy-pasted calls.
    augmentation_grid = '0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0'
    augmentation_flags = (
        ('--noise_weights', '噪声增强权重搜索范围'),
        ('--scale_weights', '缩放增强权重搜索范围'),
        ('--shift_weights', '偏移增强权重搜索范围'),
        ('--smooth_weights', '平滑增强权重搜索范围'),
    )
    for flag, help_text in augmentation_flags:
        parser.add_argument(flag, type=str, default=augmentation_grid,
                            help=help_text)

    # Settings forwarded to the base evaluation script.
    parser.add_argument('--base_script', type=str, default='evaluate_model.py',
                        help='基础评估脚本路径')
    parser.add_argument('--data_dir', type=str,
                        default='../data/hdf5_data_final',
                        help='数据目录')
    parser.add_argument('--eval_type', type=str, default='val',
                        help='评估类型')
    parser.add_argument('--gpu_number', type=int, default=0,
                        help='GPU编号')

    # Output and execution control.
    parser.add_argument('--output_file', type=str,
                        default='parameter_search_results.json',
                        help='搜索结果输出文件')
    parser.add_argument('--dry_run', action='store_true',
                        help='只显示搜索空间,不实际运行')
    parser.add_argument('--max_workers', type=int, default=25,
                        help='最大并行工作线程数')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='每批处理的配置数量')

    return parser.parse_args(argv)
def _parse_weight_list(raw, name):
    """Turn a comma-separated string into a non-empty list of floats.

    Blank tokens (for example from a trailing comma) are skipped. A token
    that is not a number still raises the ``ValueError`` produced by
    ``float``.

    Raises:
        ValueError: if no value is left after skipping blank tokens.
    """
    tokens = (piece.strip() for piece in raw.split(','))
    values = [float(token) for token in tokens if token]
    if not values:
        raise ValueError(f"No values given for {name!r}: {raw!r}")
    return values


def generate_search_space(args):
    """Build the full grid of configurations to evaluate.

    Args:
        args: Object exposing the comma-separated string attributes
            ``gru_weights``, ``original_weights``, ``noise_weights``,
            ``scale_weights``, ``shift_weights`` and ``smooth_weights``.

    Returns:
        A list of 6-tuples ``(gru, original, noise, scale, shift, smooth)``
        forming the Cartesian product of the six value lists, with the last
        axis varying fastest.

    Raises:
        ValueError: if a token is not a valid float or a list is empty.
    """
    # Axis order defines the position of each weight inside a config tuple.
    axis_names = (
        'gru_weights',
        'original_weights',
        'noise_weights',
        'scale_weights',
        'shift_weights',
        'smooth_weights',
    )
    axes = [_parse_weight_list(getattr(args, name), name) for name in axis_names]
    return list(product(*axes))
def _extract_per(stdout):
    """Return the first parseable PER value found in *stdout*, or ``None``.

    A matching line contains the marker text followed by a number and an
    optional percent sign. Lines whose value cannot be parsed are reported
    and skipped so a later line can still match.
    """
    marker = 'Aggregate Phoneme Error Rate (PER):'
    for line in stdout.split('\n'):
        if marker not in line:
            continue
        per_text = line.split(marker)[-1].replace('%', '').strip()
        try:
            return float(per_text)
        except ValueError as e:
            print(f"Error parsing PER from line: {line}, error: {e}")
    return None


def _build_result(config, cmd, per, **extra):
    """Assemble a result record so every outcome has the same base keys."""
    gru_w, orig_w, noise_w, scale_w, shift_w, smooth_w = config
    result = {
        'config': config,
        'gru_weight': gru_w,
        'tta_weights': {
            'original': orig_w,
            'noise': noise_w,
            'scale': scale_w,
            'shift': shift_w,
            'smooth': smooth_w
        },
        'per': per,
        'command': ' '.join(cmd)
    }
    result.update(extra)
    return result


def run_single_evaluation(config, args):
    """Evaluate one configuration by running the base script as a subprocess.

    Args:
        config: 6-tuple ``(gru, original, noise, scale, shift, smooth)``.
        args: Object exposing ``base_script``, ``data_dir``, ``eval_type``,
            ``gpu_number`` and ``dry_run``.

    Returns:
        A dict with the keys ``config``, ``gru_weight``, ``tta_weights``,
        ``per`` and ``command``. Real runs add ``success``; failed runs also
        add ``error`` and report ``per`` as ``float('inf')``. In dry-run
        mode nothing is executed and ``per`` is a random placeholder drawn
        from [20, 40).
    """
    gru_w, orig_w, noise_w, scale_w, shift_w, smooth_w = config

    # The base script receives the five TTA weights as one comma-joined value.
    tta_weights_str = f"{orig_w},{noise_w},{scale_w},{shift_w},{smooth_w}"

    # Launch the child with the interpreter running this search: a bare
    # 'python' may be missing from PATH or resolve to another environment.
    cmd = [
        sys.executable or 'python', args.base_script,
        '--gru_weight', str(gru_w),
        '--tta_weights', tta_weights_str,
        '--data_dir', args.data_dir,
        '--eval_type', args.eval_type,
        '--gpu_number', str(args.gpu_number)
    ]

    if args.dry_run:
        print(f"Would run: {' '.join(cmd)}")
        # Simulated PER; no evaluation takes place in dry-run mode.
        return _build_result(config, cmd, np.random.uniform(20, 40))

    try:
        # 30-minute limit per evaluation.
        completed = subprocess.run(cmd, capture_output=True, text=True,
                                   timeout=1800)
    except subprocess.TimeoutExpired:
        return _build_result(config, cmd, float('inf'),
                             success=False, error='Timeout')
    except Exception as e:
        # Best effort: a launch failure (e.g. OSError) must not take down
        # the worker thread, so it is recorded as a failed configuration.
        return _build_result(config, cmd, float('inf'),
                             success=False, error=str(e))

    per = _extract_per(completed.stdout)
    if per is None:
        print(f"Warning: Could not parse PER from output for config {config}")
        # Keep the tail of stderr so the cause of the failure is not lost.
        stderr_tail = (completed.stderr or '').strip()[-500:]
        detail = 'PER not found in output'
        if stderr_tail:
            detail = f"{detail}: {stderr_tail}"
        return _build_result(config, cmd, float('inf'),
                             success=completed.returncode == 0, error=detail)

    return _build_result(config, cmd, per,
                         success=completed.returncode == 0)
def _print_search_space(args, total_configs):
    """Print the size of the search space and the raw range of each weight."""
    print(f"搜索空间大小: {total_configs} 个配置")
    print("参数范围:")
    print(f" GRU权重: {args.gru_weights}")
    print(f" 原始权重: {args.original_weights}")
    print(f" 噪声权重: {args.noise_weights}")
    print(f" 缩放权重: {args.scale_weights}")
    print(f" 偏移权重: {args.shift_weights}")
    print(f" 平滑权重: {args.smooth_weights}")
    print()


def _run_search(search_space, args):
    """Evaluate every configuration on a thread pool and return all results."""
    total_configs = len(search_space)
    results = []
    best_per = float('inf')
    completed_count = 0
    # A batch size of 0 would make the modulo below raise; treat it as
    # "no detailed progress reports".
    batch_size = args.batch_size

    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        future_to_config = {
            executor.submit(run_single_evaluation, config, args): config
            for config in search_space
        }

        for future in as_completed(future_to_config):
            config = future_to_config[future]
            completed_count += 1

            # Only the retrieval of the result is guarded. Keeping the
            # bookkeeping outside the try means an error there can no longer
            # be mistaken for a failed configuration and counted twice.
            try:
                result = future.result()
            except Exception as e:
                print(f"\n❌ 配置失败: {config}, 错误: {e}")
                results.append({
                    'config': config,
                    'per': float('inf'),
                    'error': str(e)
                })
                continue

            results.append(result)

            if result['per'] < best_per:
                best_per = result['per']
                print(f"\n🎯 新最优配置[{completed_count}/{total_configs}]: PER={best_per:.3f}%")
                print(f" GRU={config[0]:.1f}, TTA=({config[1]},{config[2]},{config[3]},{config[4]},{config[5]})")

            if batch_size and completed_count % batch_size == 0:
                progress = 100 * completed_count / total_configs
                print(f"\n📊 进度: {completed_count}/{total_configs} ({progress:.1f}%)")
                print(f" 当前最优PER: {best_per:.3f}%")
            elif completed_count % 50 == 0:
                # Lightweight heartbeat between the detailed reports.
                print(f"... {completed_count}/{total_configs} ...", end='', flush=True)

    return results


def _print_summary(valid_results, best_config, total_configs):
    """Print the best configuration, the top-10 table and the statistics."""
    print("\n" + "=" * 50)
    print("🏆 搜索完成!")

    if best_config is not None:
        print("最佳配置:")
        print(f" PER: {best_config['per']:.3f}%")
        print(f" GRU权重: {best_config['gru_weight']:.1f}")
        print(f" TTA权重: {best_config['tta_weights']}")
        print(f" 命令: {best_config['command']}")

        ranked = sorted(valid_results, key=lambda r: r['per'])
        print("\n📊 前10个最佳配置:")
        print("排名 | PER(%) | GRU | Original | Noise | Scale | Shift | Smooth")
        print("-" * 70)
        for rank, result in enumerate(ranked[:10], start=1):
            tw = result['tta_weights']
            print(f"{rank:3d} | {result['per']:6.3f} | {result['gru_weight']:3.1f} | "
                  f"{tw['original']:8.1f} | {tw['noise']:5.1f} | {tw['scale']:5.1f} | "
                  f"{tw['shift']:5.1f} | {tw['smooth']:6.1f}")
    else:
        print("❌ 未找到有效的配置结果!所有配置都失败了。")
        print("请检查评估脚本是否正常工作。")

    print("\n📈 搜索统计:")
    print(f" 总配置数: {total_configs}")
    print(f" 成功配置数: {len(valid_results)}")
    print(f" 失败配置数: {total_configs - len(valid_results)}")
    if valid_results:
        valid_pers = [r['per'] for r in valid_results]
        print(f" PER范围: {min(valid_pers):.3f}% - {max(valid_pers):.3f}%")
        print(f" 平均PER: {sum(valid_pers)/len(valid_pers):.3f}%")


def main():
    """Run the TTA-E weight grid search from the command line.

    Parses the arguments, builds the search space, then either shows the
    first five commands (dry run) or evaluates every configuration in
    parallel, writes all results to ``args.output_file`` as JSON and prints
    a summary.
    """
    args = parse_arguments()

    print("🔍 TTA-E参数搜索")
    print("=" * 50)

    search_space = generate_search_space(args)
    total_configs = len(search_space)
    _print_search_space(args, total_configs)

    if args.dry_run:
        print("🧪 Dry run模式 - 显示前5个配置示例:")
        for index, config in enumerate(search_space[:5], start=1):
            result = run_single_evaluation(config, args)
            print(f"{index}. {result['command']}")
        print(f"\n总共会运行 {total_configs} 个配置")
        return

    print("🚀 开始参数搜索...")
    print(f"使用 {args.max_workers} 个线程并行处理...")
    results = _run_search(search_space, args)
    print("\n✅ 所有任务完成!")

    # A result is valid when it carries no error and a finite PER. The
    # "< inf" comparison also rejects NaN, which "!= inf" would let through.
    valid_results = [
        r for r in results
        if 'error' not in r and r['per'] < float('inf')
    ]
    best_config = min(valid_results, key=lambda r: r['per']) if valid_results else None

    search_results = {
        'best_config': best_config,
        'all_results': results,
        'search_space_size': total_configs,
        'args': vars(args),
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
    }
    # NOTE(review): failed configurations carry per=float('inf'), which the
    # json module writes as the non-standard token ``Infinity``; strict JSON
    # parsers reject it. The format is kept for existing consumers.
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(search_results, f, indent=2)

    _print_summary(valid_results, best_config, total_configs)
    print(f"\n✅ 结果已保存到: {args.output_file}")
# Script entry point: start the search only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()