Files
b2txt25/language_model/runtime/core/decoder/params.h
2025-07-02 12:18:09 -07:00

122 lines
4.6 KiB
C++

// Copyright 2021 Mobvoi Inc. All Rights Reserved.
// Author: binbinzhang@mobvoi.com (Binbin Zhang)
// di.wu@mobvoi.com (Di Wu)
#ifndef DECODER_PARAMS_H_
#define DECODER_PARAMS_H_
#include <memory>
#include "decoder/torch_asr_decoder.h"
#include "decoder/torch_asr_model.h"
#include "frontend/feature_pipeline.h"
#include "utils/flags.h"
// Command-line flag definitions consumed by the Init*FromFlags() helpers
// below. Each DEFINE_* creates a FLAGS_<name> variable with the given default
// and help text.
// NOTE(review): the flags are *defined* (not merely declared) in this header,
// so it is presumably meant to be included from a single translation unit per
// binary -- confirm before including it elsewhere.

// TorchAsrModel flags
DEFINE_int32(num_threads, 1, "num threads for GEMM");
DEFINE_string(model_path, "", "pytorch exported model path");

// FeaturePipelineConfig flags
DEFINE_int32(num_bins, 80, "num mel bins for fbank feature");
DEFINE_int32(sample_rate, 16000, "sample rate for audio");

// TLG fst
DEFINE_string(fst_path, "", "TLG fst path");

// DecodeOptions flags
DEFINE_int32(chunk_size, 16, "decoding chunk size");
DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
DEFINE_double(ctc_weight, 0.0,
              "ctc weight when combining ctc score and rescoring score");
DEFINE_double(rescoring_weight, 1.0,
              "rescoring weight when combining ctc score and rescoring score");
// Adjacent string literals are concatenated verbatim, so every fragment that
// is continued on the next line must end with an explicit space.
DEFINE_double(reverse_weight, 0.0,
              "used for bitransformer rescoring. it must be 0.0 if decoder is "
              "conventional transformer decoder; the right to left decoder is "
              "calculated and used only when reverse_weight > 0.0");
DEFINE_int32(max_active, 7000, "max active states in ctc wfst search");
DEFINE_int32(min_active, 200, "min active states in ctc wfst search");
DEFINE_double(beam, 16.0, "beam in ctc wfst search");
DEFINE_double(lattice_beam, 10.0, "lattice beam in ctc wfst search");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale for ctc wfst search");
DEFINE_double(blank_skip_thresh, 1.0,
              "blank skip thresh for ctc wfst search, 1.0 means no skip");
DEFINE_double(length_penalty, 0.0,
              "length penalty ctc wfst search, will not "
              "apply on self-loop arc, for balancing the del/ins ratio, "
              "suggest set to -3.0");
DEFINE_int32(nbest, 10, "nbest for ctc wfst search");

// SymbolTable flags
DEFINE_string(dict_path, "",
              "dict symbol table path, it's same as unit_path when we don't "
              "use LM in decoding");
DEFINE_string(
    unit_path, "",
    "e2e model unit symbol table, used for get timestamp of the result");
namespace wenet {
// Builds the feature pipeline configuration from the command-line flags.
// FLAGS_num_bins and FLAGS_sample_rate are forwarded, in that order, to the
// FeaturePipelineConfig constructor.
std::shared_ptr<FeaturePipelineConfig> InitFeaturePipelineConfigFromFlags() {
  return std::make_shared<FeaturePipelineConfig>(FLAGS_num_bins,
                                                 FLAGS_sample_rate);
}
// Creates a default-constructed DecodeOptions and overwrites its fields with
// the values of the corresponding command-line flags.
//
// NOTE(review): FLAGS_length_penalty is defined in this header but is not
// copied into the options here -- confirm whether that is intentional.
std::shared_ptr<DecodeOptions> InitDecodeOptionsFromFlags() {
  auto options = std::make_shared<DecodeOptions>();

  // Chunking and score-combination weights.
  options->chunk_size = FLAGS_chunk_size;
  options->num_left_chunks = FLAGS_num_left_chunks;
  options->ctc_weight = FLAGS_ctc_weight;
  options->reverse_weight = FLAGS_reverse_weight;
  options->rescoring_weight = FLAGS_rescoring_weight;

  // CTC WFST search settings live in a nested options struct.
  auto& search_opts = options->ctc_wfst_search_opts;
  search_opts.max_active = FLAGS_max_active;
  search_opts.min_active = FLAGS_min_active;
  search_opts.beam = FLAGS_beam;
  search_opts.lattice_beam = FLAGS_lattice_beam;
  search_opts.acoustic_scale = FLAGS_acoustic_scale;
  search_opts.blank_skip_thresh = FLAGS_blank_skip_thresh;
  search_opts.nbest = FLAGS_nbest;

  return options;
}
std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
auto resource = std::make_shared<DecodeResource>();
LOG(INFO) << "Reading model " << FLAGS_model_path;
auto model = std::make_shared<TorchAsrModel>();
model->Read(FLAGS_model_path, FLAGS_num_threads);
resource->model = model;
std::shared_ptr<fst::Fst<fst::StdArc>> fst = nullptr;
if (!FLAGS_fst_path.empty()) {
LOG(INFO) << "Reading fst " << FLAGS_fst_path;
fst.reset(fst::Fst<fst::StdArc>::Read(FLAGS_fst_path));
CHECK(fst != nullptr);
}
resource->fst = fst;
LOG(INFO) << "Reading symbol table " << FLAGS_dict_path;
auto symbol_table = std::shared_ptr<fst::SymbolTable>(
fst::SymbolTable::ReadText(FLAGS_dict_path));
resource->symbol_table = symbol_table;
std::shared_ptr<fst::SymbolTable> unit_table = nullptr;
if (!FLAGS_unit_path.empty()) {
LOG(INFO) << "Reading unit table " << FLAGS_unit_path;
unit_table = std::shared_ptr<fst::SymbolTable>(
fst::SymbolTable::ReadText(FLAGS_unit_path));
CHECK(unit_table != nullptr);
} else if (fst == nullptr) {
LOG(INFO) << "Use symbol table as unit table";
unit_table = symbol_table;
}
resource->unit_table = unit_table;
return resource;
}
} // namespace wenet
#endif // DECODER_PARAMS_H_