competition update

This commit is contained in:
nckcard
2025-07-02 12:18:09 -07:00
parent 9e17716a4a
commit 77dbcf868f
2615 changed files with 1648116 additions and 125 deletions

View File

@@ -0,0 +1,85 @@
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "grpc/grpc_client.h"
#include "utils/log.h"
namespace wenet {
using grpc::Channel;
using grpc::ClientContext;
using grpc::ClientReaderWriter;
using grpc::Status;
using wenet::Request;
using wenet::Response;
// Builds a client for the ASR service at host:port.
//
// Construction is eager: Connect() opens the bidirectional stream and sends
// the decoding configuration (nbest, continuous_decoding), and then a
// background thread is started that runs ReadLoopFunc() to consume the
// server's responses.
GrpcClient::GrpcClient(const std::string& host, int port, int nbest,
                       bool continuous_decoding)
    : host_(host),
      port_(port),
      nbest_(nbest),
      continuous_decoding_(continuous_decoding) {
  Connect();
  // Started only after Connect(), because the read loop uses stream_.
  t_.reset(new std::thread([this] { ReadLoopFunc(); }));
}
// Opens an insecure channel to host_:port_, starts the Recognize streaming
// call, and sends the first request, which carries only the decoding
// configuration (nbest_ and continuous_decoding_).
void GrpcClient::Connect() {
  const std::string target = host_ + ":" + std::to_string(port_);
  channel_ = grpc::CreateChannel(target, grpc::InsecureChannelCredentials());
  stub_ = ASR::NewStub(channel_);

  // context_ is kept as a member so that it outlives the stream built on it.
  context_ = std::make_shared<ClientContext>();
  stream_ = stub_->Recognize(context_.get());

  request_ = std::make_shared<Request>();
  response_ = std::make_shared<Response>();

  auto* decode_config = request_->mutable_decode_config();
  decode_config->set_nbest_config(nbest_);
  decode_config->set_continuous_decoding_config(continuous_decoding_);
  stream_->Write(*request_);
}
// Sends one chunk of audio to the server.
//
// `data` points to the raw audio buffer and `size` is forwarded unchanged as
// the length of the `audio_data` bytes field. The shared request_ object is
// reused; setting audio_data replaces whatever payload it held before.
void GrpcClient::SendBinaryData(const void* data, size_t size) {
  request_->set_audio_data(reinterpret_cast<const int16_t*>(data), size);
  stream_->Write(*request_);
}
// Body of the background reader thread.
//
// Keeps reading responses from the stream and logs every hypothesis as
// "<rank>best <sentence>". The loop ends when the stream is exhausted, when
// the server reports a non-ok status, or when a speech_end response arrives;
// only the last case sets done_. Exceptions are logged, never propagated,
// so they cannot escape the thread.
void GrpcClient::ReadLoopFunc() {
  try {
    for (;;) {
      if (!stream_->Read(response_.get())) {
        return;
      }
      // Each nbest entry also carries wordpieces, which are not printed here.
      const int num_hypotheses = response_->nbest_size();
      for (int rank = 1; rank <= num_hypotheses; ++rank) {
        LOG(INFO) << rank << "best " << response_->nbest(rank - 1).sentence();
      }
      if (response_->status() != Response_Status_ok) {
        return;
      }
      if (response_->type() == Response_Type_speech_end) {
        done_ = true;
        return;
      }
    }
  } catch (std::exception const& e) {
    LOG(ERROR) << e.what();
  }
}
void GrpcClient::Join() {
stream_->WritesDone();
t_->join();
Status status = stream_->Finish();
if (!status.ok()) {
LOG(INFO) << "Recognize rpc failed.";
}
}
} // namespace wenet

View File

@@ -0,0 +1,70 @@
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GRPC_GRPC_CLIENT_H_
#define GRPC_GRPC_CLIENT_H_
#include <grpc/grpc.h>
#include <grpcpp/channel.h>
#include <grpcpp/client_context.h>
#include <grpcpp/create_channel.h>
#include <atomic>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include "utils/utils.h"
#include "grpc/wenet.grpc.pb.h"
namespace wenet {
using grpc::Channel;
using grpc::ClientContext;
using grpc::ClientReaderWriter;
using wenet::ASR;
using wenet::Request;
using wenet::Response;
// Streaming client for the wenet ASR gRPC service.
//
// The constructor connects, sends the decoding configuration and starts a
// background thread running ReadLoopFunc(). Audio is pushed with
// SendBinaryData(); Join() closes the write side, waits for the reader
// thread and collects the final RPC status.
class GrpcClient {
 public:
  // host/port: address of the server.
  // nbest: sent to the server as DecodeConfig.nbest_config.
  // continuous_decoding: sent as DecodeConfig.continuous_decoding_config.
  GrpcClient(const std::string& host, int port, int nbest,
             bool continuous_decoding);
  // Sends `size` bytes starting at `data` as one audio_data request.
  void SendBinaryData(const void* data, size_t size);
  // Reader-thread body; logs results until the stream ends.
  void ReadLoopFunc();
  // Half-closes the stream, joins the reader thread, finishes the RPC.
  void Join();
  // True once the reader thread has received a speech_end response.
  // Safe to poll from any thread: done_ is atomic.
  bool done() const { return done_.load(); }

 private:
  // Creates channel, stub and stream, then sends the config request.
  void Connect();

  std::string host_;
  int port_;
  std::shared_ptr<Channel> channel_{nullptr};
  std::unique_ptr<ASR::Stub> stub_{nullptr};
  // Must stay alive for as long as stream_ is in use.
  std::shared_ptr<ClientContext> context_{nullptr};
  std::unique_ptr<ClientReaderWriter<Request, Response>> stream_{nullptr};
  // Reused for every outgoing message.
  std::shared_ptr<Request> request_{nullptr};
  // Reused by the reader thread for every incoming message.
  std::shared_ptr<Response> response_{nullptr};
  int nbest_ = 1;
  bool continuous_decoding_ = false;
  // Written by the reader thread and read through done() by other threads,
  // hence atomic rather than a plain bool.
  std::atomic<bool> done_{false};
  // Background reader thread running ReadLoopFunc().
  std::unique_ptr<std::thread> t_{nullptr};

  WENET_DISALLOW_COPY_AND_ASSIGN(GrpcClient);
};
} // namespace wenet
#endif // GRPC_GRPC_CLIENT_H_

View File

@@ -0,0 +1,185 @@
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "grpc/grpc_server.h"
namespace wenet {
using grpc::ServerReaderWriter;
using wenet::Request;
using wenet::Response;
// Stores everything one Recognize call needs.
//
// `stream` is a non-owning pointer to the call's reader/writer. The
// remaining arguments are shared handles that are moved into the
// corresponding members.
GrpcConnectionHandler::GrpcConnectionHandler(
    ServerReaderWriter<Response, Request>* stream,
    std::shared_ptr<Request> request, std::shared_ptr<Response> response,
    std::shared_ptr<FeaturePipelineConfig> feature_config,
    std::shared_ptr<DecodeOptions> decode_config,
    std::shared_ptr<fst::SymbolTable> symbol_table,
    std::shared_ptr<TorchAsrModel> model,
    std::shared_ptr<fst::Fst<fst::StdArc>> fst)
    : stream_(stream),  // raw pointer: a plain copy, nothing to move
      request_(std::move(request)),
      response_(std::move(response)),
      feature_config_(std::move(feature_config)),
      decode_config_(std::move(decode_config)),
      symbol_table_(std::move(symbol_table)),
      model_(std::move(model)),
      fst_(std::move(fst)) {}
void GrpcConnectionHandler::OnSpeechStart() {
LOG(INFO) << "Recieved speech start signal, start reading speech";
got_start_tag_ = true;
response_->set_status(Response::ok);
response_->set_type(Response::server_ready);
stream_->Write(*response_);
feature_pipeline_ = std::make_shared<FeaturePipeline>(*feature_config_);
decoder_ = std::make_shared<TorchAsrDecoder>(
feature_pipeline_, model_, symbol_table_, *decode_config_, fst_);
// Start decoder thread
decode_thread_ = std::make_shared<std::thread>(
&GrpcConnectionHandler::DecodeThreadFunc, this);
}
void GrpcConnectionHandler::OnSpeechEnd() {
LOG(INFO) << "Recieved speech end signal";
CHECK(feature_pipeline_ != nullptr);
feature_pipeline_->set_input_finished();
got_end_tag_ = true;
}
// Sends the current contents of response_ to the client, tagged as an
// intermediate (partial) result with status ok.
void GrpcConnectionHandler::OnPartialResult() {
  LOG(INFO) << "Partial result";
  response_->set_type(Response::partial_result);
  response_->set_status(Response::ok);
  stream_->Write(*response_);
}
// Sends the current contents of response_ to the client, tagged as a final
// result with status ok.
void GrpcConnectionHandler::OnFinalResult() {
  LOG(INFO) << "Final result";
  response_->set_type(Response::final_result);
  response_->set_status(Response::ok);
  stream_->Write(*response_);
}
// Sends the speech_end tag with status ok. The client's read loop treats
// this response as the signal that recognition is complete.
void GrpcConnectionHandler::OnFinish() {
  response_->set_type(Response::speech_end);
  response_->set_status(Response::ok);
  stream_->Write(*response_);
}
// Handles an audio request: interprets the `audio_data` bytes as 16-bit
// samples, widens them to float and feeds them to the feature pipeline.
// A trailing odd byte, if any, is dropped by the integer division.
void GrpcConnectionHandler::OnSpeechData() {
  const std::string& payload = request_->audio_data();
  const int num_samples = payload.length() / sizeof(int16_t);
  const int16_t* samples = reinterpret_cast<const int16_t*>(payload.c_str());

  // The range constructor converts every int16 sample to float.
  std::vector<float> pcm_data(samples, samples + num_samples);
  VLOG(2) << "Recieved " << num_samples << " samples";

  // Both are created by OnSpeechStart(), which must have run first.
  CHECK(feature_pipeline_ != nullptr);
  CHECK(decoder_ != nullptr);
  feature_pipeline_->AcceptWaveform(pcm_data);
}
void GrpcConnectionHandler::SerializeResult(bool finish) {
for (const DecodeResult& path : decoder_->result()) {
Response_OneBest* one_best_ = response_->add_nbest();
one_best_->set_sentence(path.sentence);
if (finish) {
for (const WordPiece& word_piece : path.word_pieces) {
Response_OnePiece* one_piece_ = one_best_->add_wordpieces();
one_piece_->set_word(word_piece.word);
one_piece_->set_start(word_piece.start);
one_piece_->set_end(word_piece.end);
}
}
if (response_->nbest_size() == nbest_) {
break;
}
}
return;
}
void GrpcConnectionHandler::DecodeThreadFunc() {
while (true) {
DecodeState state = decoder_->Decode();
response_->clear_status();
response_->clear_type();
response_->clear_nbest();
if (state == DecodeState::kEndFeats) {
decoder_->Rescoring();
SerializeResult(true);
OnFinalResult();
OnFinish();
stop_recognition_ = true;
break;
} else if (state == DecodeState::kEndpoint) {
decoder_->Rescoring();
SerializeResult(true);
OnFinalResult();
// If it's not continuous decoidng, continue to do next recognition
// otherwise stop the recognition
if (continuous_decoding_) {
decoder_->ResetContinuousDecoding();
} else {
OnFinish();
stop_recognition_ = true;
break;
}
} else {
if (decoder_->DecodedSomething()) {
SerializeResult(false);
OnPartialResult();
}
}
}
}
// Entry point for one call: reads requests until the client closes its side.
//
// The first request is taken as the decoding configuration and starts the
// session; every later request is treated as audio. Once the stream ends,
// the input is marked finished and the decoding thread is joined.
// Exceptions are logged and swallowed.
void GrpcConnectionHandler::operator()() {
  try {
    while (stream_->Read(request_.get())) {
      if (got_start_tag_) {
        OnSpeechData();
        continue;
      }
      const auto& config = request_->decode_config();
      nbest_ = config.nbest_config();
      continuous_decoding_ = config.continuous_decoding_config();
      OnSpeechStart();
    }

    OnSpeechEnd();
    LOG(INFO) << "Read all pcm data, wait for decoding thread";
    if (decode_thread_ != nullptr) {
      decode_thread_->join();
    }
  } catch (std::exception const& e) {
    LOG(ERROR) << e.what();
  }
}
// Serves one Recognize call.
//
// Builds a handler around the call's stream and the server's shared
// model resources, runs it on a dedicated thread, and blocks until that
// thread finishes. Always returns Status::OK; `context` is unused.
Status GrpcServer::Recognize(ServerContext* context,
                             ServerReaderWriter<Response, Request>* stream) {
  LOG(INFO) << "Get Recognize request" << std::endl;
  GrpcConnectionHandler handler(stream, std::make_shared<Request>(),
                                std::make_shared<Response>(), feature_config_,
                                decode_config_, symbol_table_, model_, fst_);
  // The handler is moved into the worker thread, which owns it from here on.
  std::thread worker(std::move(handler));
  worker.join();
  return Status::OK;
}
} // namespace wenet

View File

@@ -0,0 +1,109 @@
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GRPC_GRPC_SERVER_H_
#define GRPC_GRPC_SERVER_H_
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "decoder/torch_asr_decoder.h"
#include "decoder/torch_asr_model.h"
#include "frontend/feature_pipeline.h"
#include "utils/log.h"
#include "grpc/wenet.grpc.pb.h"
namespace wenet {
using grpc::ServerContext;
using grpc::ServerReaderWriter;
using grpc::Status;
using wenet::ASR;
using wenet::Request;
using wenet::Response;
// Handles a single Recognize call on the server side.
//
// operator() reads the client's requests: the first one configures and
// starts the session, the following ones carry audio. A separate decoding
// thread (DecodeThreadFunc) runs the decoder and writes responses back on
// the same stream.
class GrpcConnectionHandler {
public:
// `stream` is a non-owning pointer to the call's reader/writer; all other
// arguments are shared handles that are stored in the matching members.
GrpcConnectionHandler(ServerReaderWriter<Response, Request> *stream,
std::shared_ptr<Request> request,
std::shared_ptr<Response> response,
std::shared_ptr<FeaturePipelineConfig> feature_config,
std::shared_ptr<DecodeOptions> decode_config,
std::shared_ptr<fst::SymbolTable> symbol_table,
std::shared_ptr<TorchAsrModel> model,
std::shared_ptr<fst::Fst<fst::StdArc>> fst);
// Runs the request-reading loop until the client closes its side, then
// waits for the decoding thread.
void operator()();
private:
// First request received: acknowledge, create pipeline/decoder, start
// the decoding thread.
void OnSpeechStart();
// Request stream exhausted: mark the feature pipeline input as finished.
void OnSpeechEnd();
// Sends a speech_end response.
void OnFinish();
// Audio request received: convert the bytes to float samples and feed the
// feature pipeline.
void OnSpeechData();
// Sends response_ tagged as partial_result.
void OnPartialResult();
// Sends response_ tagged as final_result.
void OnFinalResult();
// Decoding-thread body: decode, serialize, and send results in a loop.
void DecodeThreadFunc();
// Copies decoder hypotheses into response_; word pieces are included only
// when `finish` is true.
void SerializeResult(bool finish);
// Taken from the first request's DecodeConfig.
bool continuous_decoding_ = false;
// Taken from the first request's DecodeConfig; limits how many hypotheses
// SerializeResult copies.
int nbest_ = 1;
// Not owned.
ServerReaderWriter<Response, Request> *stream_;
// Reused buffer for incoming messages.
std::shared_ptr<Request> request_;
// Reused buffer for outgoing messages.
std::shared_ptr<Response> response_;
std::shared_ptr<FeaturePipelineConfig> feature_config_;
std::shared_ptr<DecodeOptions> decode_config_;
std::shared_ptr<fst::SymbolTable> symbol_table_;
std::shared_ptr<TorchAsrModel> model_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
// Set once the first request has been handled by OnSpeechStart().
bool got_start_tag_ = false;
// Set by OnSpeechEnd().
bool got_end_tag_ = false;
// Set by the decoding thread when it sends speech_end and exits its loop.
// NOTE(review): nothing in the handler's implementation reads this flag, so
// audio that arrives afterwards is still fed to the pipeline - confirm
// whether receiving was meant to stop.
bool stop_recognition_ = false;
// Created per call in OnSpeechStart(); null until then.
std::shared_ptr<FeaturePipeline> feature_pipeline_ = nullptr;
std::shared_ptr<TorchAsrDecoder> decoder_ = nullptr;
std::shared_ptr<std::thread> decode_thread_ = nullptr;
};
class GrpcServer final : public ASR::Service {
public:
GrpcServer(std::shared_ptr<FeaturePipelineConfig> feature_config,
std::shared_ptr<DecodeOptions> decode_config,
std::shared_ptr<fst::SymbolTable> symbol_table,
std::shared_ptr<TorchAsrModel> model,
std::shared_ptr<fst::Fst<fst::StdArc>> fst)
: feature_config_(std::move(feature_config)),
decode_config_(std::move(decode_config)),
symbol_table_(std::move(symbol_table)),
model_(std::move(model)),
fst_(std::move(fst)) {}
Status Recognize(ServerContext *context,
ServerReaderWriter<Response, Request> *reader) override;
private:
std::shared_ptr<FeaturePipelineConfig> feature_config_;
std::shared_ptr<DecodeOptions> decode_config_;
std::shared_ptr<fst::SymbolTable> symbol_table_;
std::shared_ptr<TorchAsrModel> model_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
DISALLOW_COPY_AND_ASSIGN(GrpcServer);
};
} // namespace wenet
#endif // GRPC_GRPC_SERVER_H_

View File

@@ -0,0 +1,66 @@
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
option java_package = "ex.grpc";
option objc_class_prefix = "wenet";
package wenet;
// Speech recognition service.
service ASR {
// Bidirectional streaming call: the client streams Request messages (first
// the decoding configuration, then audio chunks) while the server streams
// Response messages back.
rpc Recognize (stream Request) returns (stream Response) {}
}
// One client-to-server message. It carries exactly one of the two payloads.
message Request {
// Decoding options sent in the first message of a call.
message DecodeConfig {
// Number of hypotheses the server should return per result.
int32 nbest_config = 1;
// If true, the server keeps decoding after an endpoint instead of ending
// the session.
bool continuous_decoding_config = 2;
}
oneof RequestPayload {
// Sent first; the server treats the first message of a call as the config.
DecodeConfig decode_config = 1;
// Raw audio. The server reinterprets these bytes as 16-bit integer samples
// in its native byte order.
bytes audio_data = 2;
}
}
// One server-to-client message.
message Response {
// A single recognition hypothesis.
message OneBest {
// Recognized text.
string sentence = 1;
// Per-piece details; the server fills these only for final results.
repeated OnePiece wordpieces = 2;
}
// One word piece of a hypothesis with its position in the audio.
message OnePiece {
string word = 1;
// NOTE(review): the unit of start/end is not defined here - presumably
// milliseconds, confirm against the decoder.
int32 start = 2;
int32 end = 3;
}
enum Status {
ok = 0;
// The client stops reading the stream when it sees a non-ok status.
failed = 1;
}
// What this message represents.
enum Type {
// Sent once after the configuration request has been accepted.
server_ready = 0;
// Intermediate hypotheses while decoding is in progress.
partial_result = 1;
// Rescored hypotheses for a finished utterance.
final_result = 2;
// Last message of a session; the client marks itself done on it.
speech_end = 3;
}
Status status = 1;
Type type = 2;
// Hypotheses for this result.
repeated OneBest nbest = 3;
}