159 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Markdown
		
	
	
	
	
	
		
		
			
		
	
	
			159 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Markdown
		
	
	
	
	
	
|   | # 增强 lm_decoder Python 绑定以支持时间戳提取
 | |||
|  | 
 | |||
|  | ## 问题
 | |||
|  | 当前 `lm_decoder` Python 绑定只暴露了句子和分数,没有暴露: | |||
|  | - Token 序列(inputs/outputs) | |||
|  | - 时间戳信息(times) | |||
|  | - 详细的似然度信息 | |||
|  | 
 | |||
|  | ## 解决方案
 | |||
|  | 
 | |||
|  | ### 步骤 1: 修改 brain_speech_decoder.h
 | |||
|  | 
 | |||
|  | 在 `BrainSpeechDecoder` 类中添加公有访问方法(当 `searcher_` 为空指针时,各方法返回一个空容器): | |||
|  | 
 | |||
|  | ```cpp | |||
|  | // 在 class BrainSpeechDecoder 的 public 部分添加 | |||
|  | 
 | |||
|  | const std::vector<std::vector<int>>& GetInputs() const { | |||
|  |     if (searcher_ == nullptr) { | |||
|  |         static std::vector<std::vector<int>> empty; | |||
|  |         return empty; | |||
|  |     } | |||
|  |     return searcher_->Inputs(); | |||
|  | } | |||
|  | 
 | |||
|  | const std::vector<std::vector<int>>& GetOutputs() const { | |||
|  |     if (searcher_ == nullptr) { | |||
|  |         static std::vector<std::vector<int>> empty; | |||
|  |         return empty; | |||
|  |     } | |||
|  |     return searcher_->Outputs(); | |||
|  | } | |||
|  | 
 | |||
|  | const std::vector<std::vector<int>>& GetTimes() const { | |||
|  |     if (searcher_ == nullptr) { | |||
|  |         static std::vector<std::vector<int>> empty; | |||
|  |         return empty; | |||
|  |     } | |||
|  |     return searcher_->Times(); | |||
|  | } | |||
|  | 
 | |||
|  | const std::vector<std::pair<float, float>>& GetLikelihood() const { | |||
|  |     if (searcher_ == nullptr) { | |||
|  |         static std::vector<std::pair<float, float>> empty; | |||
|  |         return empty; | |||
|  |     } | |||
|  |     return searcher_->Likelihood(); | |||
|  | } | |||
|  | ``` | |||
|  | 
 | |||
|  | ### 步骤 2: 修改 lm_decoder.cc Python 绑定
 | |||
|  | 
 | |||
|  | 在 `PYBIND11_MODULE` 中添加新的方法绑定(这些方法返回 `std::vector` / `std::pair`,绑定文件需要包含 `pybind11/stl.h` 才能自动转换为 Python 的 list / tuple): | |||
|  | 
 | |||
|  | ```cpp | |||
|  | py::class_<BrainSpeechDecoder>(m, "BrainSpeechDecoder") | |||
|  |     .def(py::init<std::shared_ptr<DecodeResource>, std::shared_ptr<DecodeOptions> >()) | |||
|  |     .def("SetOpt", &BrainSpeechDecoder::SetOpt) | |||
|  |     .def("Decode", &BrainSpeechDecoder::Decode) | |||
|  |     .def("Rescore", &BrainSpeechDecoder::Rescore) | |||
|  |     .def("Reset", &BrainSpeechDecoder::Reset) | |||
|  |     .def("FinishDecoding", &BrainSpeechDecoder::FinishDecoding) | |||
|  |     .def("DecodedSomething", &BrainSpeechDecoder::DecodedSomething) | |||
|  |     .def("result", &BrainSpeechDecoder::result) | |||
|  |     // 新增方法 | |||
|  |     .def("get_inputs", &BrainSpeechDecoder::GetInputs, | |||
|  |          "Get input token sequences for N-best hypotheses") | |||
|  |     .def("get_outputs", &BrainSpeechDecoder::GetOutputs, | |||
|  |          "Get output token sequences for N-best hypotheses") | |||
|  |     .def("get_times", &BrainSpeechDecoder::GetTimes, | |||
|  |          "Get timestamps for each token in N-best hypotheses") | |||
|  |     .def("get_likelihood", &BrainSpeechDecoder::GetLikelihood, | |||
|  |          "Get (acoustic_score, lm_score) pairs for N-best hypotheses"); | |||
|  | ``` | |||
|  | 
 | |||
|  | ### 步骤 3: 重新编译
 | |||
|  | 
 | |||
|  | ```bash | |||
|  | cd language_model/runtime/server/x86 | |||
|  | mkdir -p build && cd build | |||
|  | cmake .. | |||
|  | make -j$(nproc) | |||
|  | ``` | |||
|  | 
 | |||
|  | ### 步骤 4: 使用增强接口
 | |||
|  | 
 | |||
|  | 修改 `language-model-standalone.py`: | |||
|  | 
 | |||
|  | ```python | |||
|  | # 在 Finalize 阶段获取详细信息
 | |||
|  | if nbest > 1: | |||
|  |     # 获取基本结果 | |||
|  |     nbest_out = [] | |||
|  |     for d in ngramDecoder.result(): | |||
|  |         nbest_out.append([d.sentence, d.ac_score, d.lm_score]) | |||
|  |      | |||
|  |     # 获取时间戳和token序列(新增) | |||
|  |     try: | |||
|  |         inputs = ngramDecoder.get_inputs()      # List[List[int]] | |||
|  |         outputs = ngramDecoder.get_outputs()    # List[List[int]] | |||
|  |         times = ngramDecoder.get_times()        # List[List[int]] | |||
|  |          | |||
|  |         # 为每个候选添加详细信息 | |||
|  |         for i, (inp, out, time_seq) in enumerate(zip(inputs, outputs, times)): | |||
|  |             logging.info(f"Candidate {i}:") | |||
|  |             logging.info(f"  Sentence: {nbest_out[i][0]}") | |||
|  |             logging.info(f"  Token IDs: {out}") | |||
|  |             logging.info(f"  Timestamps (frames): {time_seq}") | |||
|  |              | |||
|  |             # 转换为可读格式(需要词表) | |||
|  |             if symbol_table is not None: | |||
|  |                 tokens = [symbol_table[tid] for tid in out] | |||
|  |                 logging.info(f"  Tokens: {tokens}") | |||
|  |                  | |||
|  |                 # 生成详细的时间对齐 | |||
|  |                 for token, start_frame in zip(tokens, time_seq): | |||
|  |                     time_ms = start_frame * 10  # 假设每帧10ms | |||
|  |                     logging.info(f"    {token} @ {time_ms}ms (frame {start_frame})") | |||
|  |      | |||
|  |     except AttributeError: | |||
|  |         logging.warning("Enhanced decoder methods not available. Please recompile with updated bindings.") | |||
|  | ``` | |||
|  | 
 | |||
|  | ## 示例输出
 | |||
|  | 
 | |||
|  | 使用增强接口后,日志中会输出类似下面的内容: | |||
|  | 
 | |||
|  | ``` | |||
|  | Candidate 0: | |||
|  |   Sentence: hello world | |||
|  |   Token IDs: [15, 8, 12, 12, 15, 0, 23, 15, 18, 12, 4] | |||
|  |   Timestamps (frames): [5, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66] | |||
|  |   Tokens: ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd'] | |||
|  |     h @ 50ms (frame 5) | |||
|  |     e @ 120ms (frame 12) | |||
|  |     l @ 180ms (frame 18) | |||
|  |     l @ 240ms (frame 24) | |||
|  |     o @ 300ms (frame 30) | |||
|  |       @ 360ms (frame 36) | |||
|  |     w @ 420ms (frame 42) | |||
|  |     o @ 480ms (frame 48) | |||
|  |     r @ 540ms (frame 54) | |||
|  |     l @ 600ms (frame 60) | |||
|  |     d @ 660ms (frame 66) | |||
|  | ``` | |||
|  | 
 | |||
|  | ## 注意事项
 | |||
|  | 
 | |||
|  | 1. **Token vs 音素**:这个系统使用的是字符级别(character-level)的建模,不是音素 | |||
|  | 2. **时间戳精度**:时间戳以帧为单位,需要乘以帧移(frame shift,通常为 10ms)才能换算成时间 | |||
|  | 3. **CTC 特性**:解码时会跳过 blank 帧(blank frame skipping),因此相邻 token 的时间戳之间可能存在间隔,并不连续 | |||
|  | 4. **N-best**:每个候选都有独立的时间戳序列 | |||
|  | 
 | |||
|  | ## 参考
 | |||
|  | 
 | |||
|  | - C++ 接口:`runtime/core/decoder/search_interface.h` | |||
|  | - WFST 解码实现:`runtime/core/decoder/ctc_wfst_beam_search.cc` | |||
|  | - 时间戳生成:`ConvertToInputs()` 方法中的 `decoded_frames_mapping_`(从方法名看,时间戳是与输入 token 序列一起生成的;`get_times()` 与 `get_outputs()` 是否逐项对齐,请对照该实现确认) |