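# Enhancing the lm_decoder Python Bindings to Support Timestamp Extraction

## Problem

The current `lm_decoder` Python bindings expose only the decoded sentence and its scores. They do not expose:

- Token sequences (inputs/outputs)
- Timestamp information (times)
- Detailed likelihood information
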
## Solution

### Step 1: Modify brain_speech_decoder.h

Add public accessor methods to the `BrainSpeechDecoder` class:

```cpp
// Add to the public section of class BrainSpeechDecoder.
// Each accessor returns a reference to a static empty container
// when no searcher has been created yet.

const std::vector<std::vector<int>>& GetInputs() const {
  if (searcher_ == nullptr) {
    static const std::vector<std::vector<int>> empty;
    return empty;
  }
  return searcher_->Inputs();
}

const std::vector<std::vector<int>>& GetOutputs() const {
  if (searcher_ == nullptr) {
    static const std::vector<std::vector<int>> empty;
    return empty;
  }
  return searcher_->Outputs();
}

const std::vector<std::vector<int>>& GetTimes() const {
  if (searcher_ == nullptr) {
    static const std::vector<std::vector<int>> empty;
    return empty;
  }
  return searcher_->Times();
}

const std::vector<std::pair<float, float>>& GetLikelihood() const {
  if (searcher_ == nullptr) {
    static const std::vector<std::pair<float, float>> empty;
    return empty;
  }
  return searcher_->Likelihood();
}
```

### Step 2: Modify the Python Bindings in lm_decoder.cc

Add the new method bindings inside `PYBIND11_MODULE`:

```cpp
// Returning std::vector / std::pair to Python requires
// #include <pybind11/stl.h> in this file.
py::class_<BrainSpeechDecoder>(m, "BrainSpeechDecoder")
    .def(py::init<std::shared_ptr<DecodeResource>, std::shared_ptr<DecodeOptions> >())
    .def("SetOpt", &BrainSpeechDecoder::SetOpt)
    .def("Decode", &BrainSpeechDecoder::Decode)
    .def("Rescore", &BrainSpeechDecoder::Rescore)
    .def("Reset", &BrainSpeechDecoder::Reset)
    .def("FinishDecoding", &BrainSpeechDecoder::FinishDecoding)
    .def("DecodedSomething", &BrainSpeechDecoder::DecodedSomething)
    .def("result", &BrainSpeechDecoder::result)
    // New methods
    .def("get_inputs", &BrainSpeechDecoder::GetInputs,
         "Get input token sequences for N-best hypotheses")
    .def("get_outputs", &BrainSpeechDecoder::GetOutputs,
         "Get output token sequences for N-best hypotheses")
    .def("get_times", &BrainSpeechDecoder::GetTimes,
         "Get timestamps for each token in N-best hypotheses")
    .def("get_likelihood", &BrainSpeechDecoder::GetLikelihood,
         "Get (acoustic_score, lm_score) pairs for N-best hypotheses");
```

### Step 3: Recompile

```bash
cd language_model/runtime/server/x86
mkdir -p build && cd build
cmake ..
make -j$(nproc)
```

### Step 4: Use the Enhanced Interface

Modify `language-model-standalone.py`:

```python
# Collect detailed information during the Finalize stage
if nbest > 1:
    # Basic results
    nbest_out = []
    for d in ngramDecoder.result():
        nbest_out.append([d.sentence, d.ac_score, d.lm_score])

    # Timestamps and token sequences (new)
    try:
        inputs = ngramDecoder.get_inputs()    # List[List[int]]
        outputs = ngramDecoder.get_outputs()  # List[List[int]]
        times = ngramDecoder.get_times()      # List[List[int]]

        # Log detailed information for each candidate
        for i, (inp, out, time_seq) in enumerate(zip(inputs, outputs, times)):
            logging.info(f"Candidate {i}:")
            logging.info(f"  Sentence: {nbest_out[i][0]}")
            logging.info(f"  Token IDs: {out}")
            logging.info(f"  Timestamps (frames): {time_seq}")

            # Convert to a readable form (requires the symbol table)
            if symbol_table is not None:
                tokens = [symbol_table[tid] for tid in out]
                logging.info(f"  Tokens: {tokens}")

                # Detailed time alignment (needs `tokens`, so it stays inside this branch)
                for token, start_frame in zip(tokens, time_seq):
                    time_ms = start_frame * 10  # assumes 10 ms per frame
                    logging.info(f"    {token} @ {time_ms}ms (frame {start_frame})")

    except AttributeError:
        logging.warning("Enhanced decoder methods not available. Please recompile with updated bindings.")
```

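The frame-to-time conversion in the loop above can also be factored into a small helper so that the frame shift is a parameter rather than a hard-coded 10. This is a sketch only: `TokenTiming`, `align_tokens`, the `<unk:…>` fallback, and the toy symbol table are hypothetical names introduced for illustration, and the 10 ms default is the same assumption the original snippet makes.

```python
from typing import Dict, List, NamedTuple, Sequence


class TokenTiming(NamedTuple):
    token: str
    frame: int
    time_ms: float


def align_tokens(token_ids: Sequence[int],
                 frames: Sequence[int],
                 symbol_table: Dict[int, str],
                 frame_shift_ms: float = 10.0) -> List[TokenTiming]:
    """Pair each token with its frame index and the corresponding time in ms."""
    if len(token_ids) != len(frames):
        raise ValueError("token_ids and frames must have the same length")
    return [
        # Unknown IDs are kept visible instead of raising KeyError.
        TokenTiming(symbol_table.get(tid, f"<unk:{tid}>"), frame, frame * frame_shift_ms)
        for tid, frame in zip(token_ids, frames)
    ]


# Toy symbol table, made up for this example.
toy_symbols = {8: "h", 9: "i"}
for timing in align_tokens([8, 9], [5, 12], toy_symbols):
    print(f"{timing.token} @ {timing.time_ms:.0f}ms (frame {timing.frame})")
```

Raising on a length mismatch is a deliberate choice here: `zip` would otherwise silently truncate, which could hide a disagreement between the output and time sequences.
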
## Example Output

With the enhanced interface, the log output looks like this:

```
Candidate 0:
  Sentence: hello world
  Token IDs: [8, 5, 12, 12, 15, 0, 23, 15, 18, 12, 4]
  Timestamps (frames): [5, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66]
  Tokens: ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd']
    h @ 50ms (frame 5)
    e @ 120ms (frame 12)
    l @ 180ms (frame 18)
    l @ 240ms (frame 24)
    o @ 300ms (frame 30)
      @ 360ms (frame 36)
    w @ 420ms (frame 42)
    o @ 480ms (frame 48)
    r @ 540ms (frame 54)
    l @ 600ms (frame 60)
    d @ 660ms (frame 66)
```

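Because the tokens are individual characters, per-word timing has to be derived by grouping them. One possible way to do this is sketched below, using the tokens and frames from the example output; `group_into_words` is a hypothetical helper, the separator is assumed to be a space token, and the second time in each tuple is the frame of the word's last character rather than a true end time.

```python
def group_into_words(tokens, frames, frame_shift_ms=10, separator=" "):
    """Collapse character tokens into (word, first_char_ms, last_char_ms) tuples."""
    words = []
    chars, start, last = [], None, None
    for token, frame in zip(tokens, frames):
        if token == separator:
            # A separator closes the current word, if there is one.
            if chars:
                words.append(("".join(chars), start * frame_shift_ms, last * frame_shift_ms))
            chars, start, last = [], None, None
            continue
        if not chars:
            start = frame
        chars.append(token)
        last = frame
    # Flush the final word, which has no trailing separator.
    if chars:
        words.append(("".join(chars), start * frame_shift_ms, last * frame_shift_ms))
    return words


tokens = list("hello world")
frames = [5, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66]
print(group_into_words(tokens, frames))
# [('hello', 50, 300), ('world', 420, 660)]
```
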
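## Notes

1. **Tokens vs. phonemes**: This system models at the character level, not the phoneme level.
2. **Timestamp precision**: Timestamps are frame indices; multiply by the frame shift (typically 10 ms) to convert them to time.
3. **CTC behavior**: Because of blank frame skipping, timestamps may not be contiguous.
4. **N-best**: Each candidate has its own independent timestamp sequence.
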
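To get a feel for how uneven the timestamps are in practice (note 3), the gaps between consecutive tokens can be inspected directly. This is an illustrative sketch: `frame_gaps` and `long_gaps` are hypothetical helpers, the frame values are made up, and a large gap only hints at skipped frames; it does not prove that blanks were skipped there.

```python
def frame_gaps(frames):
    """Return the frame distance between each token and the one before it."""
    return [later - earlier for earlier, later in zip(frames, frames[1:])]


def long_gaps(frames, max_gap):
    """Return (token_index, gap) for tokens that start more than max_gap frames after the previous one."""
    return [(i + 1, gap) for i, gap in enumerate(frame_gaps(frames)) if gap > max_gap]


# Made-up frame indices for one candidate.
frames = [5, 12, 18, 40, 46]
print(frame_gaps(frames))     # [7, 6, 22, 6]
print(long_gaps(frames, 10))  # [(3, 22)]
```

Since each N-best candidate has its own timestamp sequence (note 4), a check like this would be run once per candidate.
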
## References

- C++ interface: `runtime/core/decoder/search_interface.h`
- WFST decoding implementation: `runtime/core/decoder/ctc_wfst_beam_search.cc`
- Timestamp generation: `decoded_frames_mapping_` in the `ConvertToInputs()` method