b2txt25/language_model/PATCH_lm_decoder_enhanced.md
2025-10-12 09:11:32 +08:00


Enhancing the lm_decoder Python bindings to support timestamp extraction

Problem

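The current lm_decoder Python bindings expose only each hypothesis's sentence and scores. They do not expose:

  • Token sequences (inputs/outputs)
  • Timestamp information (times)
  • Detailed likelihood information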

Solution

Step 1: Modify brain_speech_decoder.h

Add public accessor methods to the BrainSpeechDecoder class:

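// Add to the public section of class BrainSpeechDecoder.
// Each getter forwards to the searcher's current N-best data and returns an
// empty vector if no searcher has been created yet.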

const std::vector<std::vector<int>>& GetInputs() const {
    if (searcher_ == nullptr) {
        static std::vector<std::vector<int>> empty;
        return empty;
    }
    return searcher_->Inputs();
}

const std::vector<std::vector<int>>& GetOutputs() const {
    if (searcher_ == nullptr) {
        static std::vector<std::vector<int>> empty;
        return empty;
    }
    return searcher_->Outputs();
}

const std::vector<std::vector<int>>& GetTimes() const {
    if (searcher_ == nullptr) {
        static std::vector<std::vector<int>> empty;
        return empty;
    }
    return searcher_->Times();
}

const std::vector<std::pair<float, float>>& GetLikelihood() const {
    if (searcher_ == nullptr) {
        static std::vector<std::pair<float, float>> empty;
        return empty;
    }
    return searcher_->Likelihood();
}

Step 2: Modify the Python bindings in lm_decoder.cc

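Add the new method bindings inside PYBIND11_MODULE. pybind11 converts the returned std::vector and std::pair values into Python lists and tuples only through its STL type casters. If lm_decoder.cc does not already include <pybind11/stl.h>, add that include. With these casters the data is copied into new Python objects on each call.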

py::class_<BrainSpeechDecoder>(m, "BrainSpeechDecoder")
    .def(py::init<std::shared_ptr<DecodeResource>, std::shared_ptr<DecodeOptions> >())
    .def("SetOpt", &BrainSpeechDecoder::SetOpt)
    .def("Decode", &BrainSpeechDecoder::Decode)
    .def("Rescore", &BrainSpeechDecoder::Rescore)
    .def("Reset", &BrainSpeechDecoder::Reset)
    .def("FinishDecoding", &BrainSpeechDecoder::FinishDecoding)
    .def("DecodedSomething", &BrainSpeechDecoder::DecodedSomething)
    .def("result", &BrainSpeechDecoder::result)
    // New methods
    .def("get_inputs", &BrainSpeechDecoder::GetInputs,
         "Get input token sequences for N-best hypotheses")
    .def("get_outputs", &BrainSpeechDecoder::GetOutputs,
         "Get output token sequences for N-best hypotheses")
    .def("get_times", &BrainSpeechDecoder::GetTimes,
         "Get timestamps for each token in N-best hypotheses")
    .def("get_likelihood", &BrainSpeechDecoder::GetLikelihood,
         "Get (acoustic_score, lm_score) pairs for N-best hypotheses");
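Once bound, the getters return parallel per-hypothesis lists. Each `times[i]` is built alongside `inputs[i]`, with one frame index per input token. A quick consistency check can therefore catch binding or indexing mistakes early. The sketch below is minimal and assumes the data has already been fetched as plain Python lists. `check_nbest_alignment` is a hypothetical helper and is not part of the bindings.

```python
from typing import List, Sequence


def check_nbest_alignment(inputs: Sequence[Sequence[int]],
                          times: Sequence[Sequence[int]]) -> List[str]:
    """Return a list of problems found; an empty list means the data looks consistent.

    Assumes times[i][j] is the frame index of inputs[i][j], as produced by a
    ConvertToInputs-style CTC collapse.
    """
    problems = []
    if len(inputs) != len(times):
        problems.append(
            f"hypothesis count differs: {len(inputs)} inputs vs {len(times)} times")
    for i, (inp, t) in enumerate(zip(inputs, times)):
        if len(inp) != len(t):
            problems.append(f"hypothesis {i}: {len(inp)} tokens but {len(t)} timestamps")
        # Each token comes from a distinct, later frame, so frames should increase.
        if any(b <= a for a, b in zip(t, t[1:])):
            problems.append(f"hypothesis {i}: timestamps are not strictly increasing")
    return problems
```

Call it on `ngramDecoder.get_inputs()` and `ngramDecoder.get_times()` after `FinishDecoding()`. A WeNet-style searcher may fill inputs for partial results without filling timestamps, and this check would flag that mismatch.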

Step 3: Rebuild

cd language_model/runtime/server/x86
mkdir -p build && cd build
cmake ..
make -j$(nproc)
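After rebuilding, confirm that Python loads the new extension and not a stale copy elsewhere on `sys.path`. The sketch below assumes the module can be imported as `lm_decoder`; adjust the name and the build output location to match your setup. Both helper functions are hypothetical.

```python
import importlib
from typing import List, Sequence

NEW_METHODS = ("get_inputs", "get_outputs", "get_times", "get_likelihood")


def missing_methods(cls: type, names: Sequence[str] = NEW_METHODS) -> List[str]:
    """Return the names from `names` that `cls` does not define."""
    return [n for n in names if not hasattr(cls, n)]


def check_module(module_name: str = "lm_decoder") -> List[str]:
    # Import the compiled extension and report where it was loaded from,
    # which helps spot an old .so shadowing the freshly built one.
    mod = importlib.import_module(module_name)
    print("loaded from:", getattr(mod, "__file__", "<unknown>"))
    return missing_methods(mod.BrainSpeechDecoder)
```

`check_module()` returns an empty list when all four new methods are bound.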

Step 4: Use the enhanced interface

Modify language-model-standalone.py:

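# In the Finalize stage, fetch the detailed information.
# Call the getters after FinishDecoding() and before Reset(): the searcher fills
# timestamps only in its final search, and Reset() clears the N-best data.
# ngramDecoder, nbest, logging and symbol_table (token ID -> symbol, or None)
# come from the surrounding script.
if nbest > 1:
    # Basic results
    nbest_out = []
    for d in ngramDecoder.result():
        nbest_out.append([d.sentence, d.ac_score, d.lm_score])

    # Timestamps and token sequences (new)
    if hasattr(ngramDecoder, "get_times"):
        inputs = ngramDecoder.get_inputs()      # List[List[int]]: modeling-unit (CTC token) IDs
        outputs = ngramDecoder.get_outputs()    # List[List[int]]: word (graph output label) IDs
        times = ngramDecoder.get_times()        # List[List[int]]: one frame per entry of inputs[i]

        # Assumes result() lists hypotheses in the same order as the searcher;
        # a step that re-sorts results (e.g. Rescore()) could break this pairing.
        for i, (inp, out, time_seq) in enumerate(zip(inputs, outputs, times)):
            logging.info(f"Candidate {i}:")
            if i < len(nbest_out):
                logging.info(f"  Sentence: {nbest_out[i][0]}")
            logging.info(f"  Token IDs: {inp}")
            logging.info(f"  Word IDs: {out}")
            logging.info(f"  Timestamps (frames): {time_seq}")

            # Convert to readable form (requires the unit table the graph was built with)
            if symbol_table is not None:
                tokens = [symbol_table[tid] for tid in inp]
                logging.info(f"  Tokens: {tokens}")

                # Per-token time alignment: times line up with inputs, not outputs
                for token, start_frame in zip(tokens, time_seq):
                    time_ms = start_frame * 10  # assumes 10 ms per frame; adjust to the model
                    logging.info(f"    {token} @ {time_ms}ms (frame {start_frame})")
    else:
        logging.warning("Enhanced decoder methods not available. Please recompile with updated bindings.")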
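Each timestamp marks the frame at which a token is first emitted. A rough span for each token can therefore run from its own start frame to the next token's start frame. The sketch below is minimal, and `token_spans`, `frame_ms` and `num_frames` are hypothetical names. Ending the last token at the end of the utterance (`num_frames`) is an assumption; the decoder does not report token end times.

```python
from typing import List, Sequence, Tuple


def token_spans(tokens: Sequence[str], frames: Sequence[int],
                frame_ms: float, num_frames: int) -> List[Tuple[str, float, float]]:
    """Return (token, start_ms, end_ms) for each token.

    A token's span runs from its own start frame up to the next token's start
    frame; the last token is assumed to run to the end of the utterance.
    """
    if len(tokens) != len(frames):
        raise ValueError("tokens and frames must have the same length")
    spans = []
    for j, (tok, start) in enumerate(zip(tokens, frames)):
        end = frames[j + 1] if j + 1 < len(frames) else num_frames
        spans.append((tok, start * frame_ms, end * frame_ms))
    return spans
```

CTC outputs tend to be peaky, so these spans are coarse estimates, not exact onsets and offsets.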

Example output

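The enhanced interface produces output like the following. The values are illustrative, not from an actual run, and the millisecond values assume 10 ms frames: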

Candidate 0:
  Sentence: hello world
  Token IDs: [15, 8, 12, 12, 15, 0, 23, 15, 18, 12, 4]
  Timestamps (frames): [5, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66]
  Tokens: ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd']
    h @ 50ms (frame 5)
    e @ 120ms (frame 12)
    l @ 180ms (frame 18)
    l @ 240ms (frame 24)
    o @ 300ms (frame 30)
      @ 360ms (frame 36)
    w @ 420ms (frame 42)
    o @ 480ms (frame 48)
    r @ 540ms (frame 54)
    l @ 600ms (frame 60)
    d @ 660ms (frame 66)
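With character-level units like those in the example, word start times can be recovered by grouping the tokens between separators. The sketch below is minimal and assumes the word separator is a single space token, as shown above; some unit tables use another symbol such as `▁` or `|`. `word_start_times` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def word_start_times(tokens: Sequence[str], frames: Sequence[int],
                     frame_ms: float, sep: str = " ") -> List[Tuple[str, float]]:
    """Group character tokens into words; return (word, start_ms) pairs.

    Each word starts at the frame of its first character; separator tokens
    only close the current word.
    """
    words = []
    chars: List[str] = []
    start = 0
    for tok, frame in zip(tokens, frames):
        if tok == sep:
            if chars:
                words.append(("".join(chars), start * frame_ms))
            chars = []
            continue
        if not chars:
            start = frame  # first character of a new word
        chars.append(tok)
    if chars:
        words.append(("".join(chars), start * frame_ms))
    return words
```

Applied to the example above with 10 ms frames, this yields `hello` at 50 ms and `world` at 420 ms. Word IDs from `get_outputs()` carry no timing of their own, which is why the timing is derived from the input tokens.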

Notes

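  1. Tokens vs. phonemes: the example above assumes character-level modeling units, not phonemes. get_inputs() returns IDs from whatever unit table the decoding graph was built with, so check whether that table contains characters or phonemes and map the IDs with it.
  2. Timestamp precision: timestamps are frame indices. To convert them to time, multiply by the duration of one model output frame. The 10 ms used above is an assumption; the real value depends on the feature bin width and any temporal downsampling in the model.
  3. CTC behavior: because blank frames are skipped, timestamps may be non-contiguous. Each timestamp is the first frame at which its token was emitted, after blanks are dropped and repeated labels merged.
  4. N-best: each candidate has its own timestamp sequence, aligned with its get_inputs() entry, not its get_outputs() (word) entry.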
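Note 3 can be illustrated with a simplified sketch of blank-frame skipping. Frames whose blank probability exceeds a threshold are dropped before the search. A mapping from kept-frame index back to original frame index keeps timestamps in original frame units, which is the role `decoded_frames_mapping_` plays in the C++ code. This is not the searcher's actual code, and the blank ID of 0 and the threshold value are assumptions.

```python
import math
from typing import List, Sequence, Tuple


def skip_blank_frames(logp: Sequence[Sequence[float]], blank_id: int = 0,
                      threshold: float = 0.98) -> Tuple[List[Sequence[float]], List[int]]:
    """Drop frames whose blank probability exceeds `threshold`.

    `logp` holds per-frame log-probabilities. Returns the kept frames and,
    for each kept frame, its index in the original sequence.
    """
    kept, mapping = [], []
    for t, frame in enumerate(logp):
        if math.exp(frame[blank_id]) > threshold:
            continue  # confidently blank: skip this frame
        kept.append(frame)
        mapping.append(t)
    return kept, mapping
```

Because skipped frames leave gaps in `mapping`, consecutive tokens can carry timestamps that jump by more than one frame.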

References

  • C++ interface: runtime/core/decoder/search_interface.h
  • WFST decoding implementation: runtime/core/decoder/ctc_wfst_beam_search.cc
  • Timestamp generation: decoded_frames_mapping_ in the ConvertToInputs() method
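The timestamp generation referenced above follows the usual CTC collapse rule. Blank labels are dropped, consecutive repeats are merged, and each surviving token records the original frame at which it first appears. The sketch below implements that rule in Python and is not a port of the C++ source. It assumes blank ID 0 and per-frame labels already in token-ID space; the C++ code works on FST input labels, which may be offset by one. `ctc_collapse` is hypothetical.

```python
from typing import List, Optional, Sequence, Tuple


def ctc_collapse(alignment: Sequence[int],
                 frames_mapping: Optional[Sequence[int]] = None,
                 blank_id: int = 0) -> Tuple[List[int], List[int]]:
    """Collapse a per-frame CTC label sequence into (tokens, times).

    frames_mapping[k] is the original frame index of decoded frame k
    (identity if None, i.e. no frames were skipped).
    """
    tokens, times = [], []
    for cur, label in enumerate(alignment):
        if label == blank_id:
            continue  # ignore blank
        if cur > 0 and label == alignment[cur - 1]:
            continue  # merge a repeat of the previous frame's label
        tokens.append(label)
        times.append(frames_mapping[cur] if frames_mapping is not None else cur)
    return tokens, times
```

For the alignment `[0, 8, 8, 0, 12, 0, 12, 12, 0]`, the two 12s separated by a blank stay distinct tokens, giving times `[1, 4, 6]`. With a skip mapping such as `[0, 2, 5, 7, 9, 11, 13, 15, 17]`, the same tokens map to original frames `[2, 9, 13]`.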