competition update
This commit is contained in:
85
language_model/runtime/core/grpc/grpc_client.cc
Normal file
85
language_model/runtime/core/grpc/grpc_client.cc
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "grpc/grpc_client.h"
|
||||
|
||||
#include "utils/log.h"
|
||||
|
||||
namespace wenet {
|
||||
using grpc::Channel;
|
||||
using grpc::ClientContext;
|
||||
using grpc::ClientReaderWriter;
|
||||
using grpc::Status;
|
||||
using wenet::Request;
|
||||
using wenet::Response;
|
||||
|
||||
// Creates a client for the ASR service at `host`:`port`.
//
// The constructor opens the Recognize stream (which also sends the decode
// configuration built from `nbest` and `continuous_decoding`) and then
// launches the background thread that consumes server responses.
GrpcClient::GrpcClient(const std::string& host, int port, int nbest,
                       bool continuous_decoding)
    : host_(host),
      port_(port),
      nbest_(nbest),
      continuous_decoding_(continuous_decoding) {
  // The stream has to exist before the reader thread starts using it.
  Connect();
  t_ = std::make_unique<std::thread>([this] { ReadLoopFunc(); });
}
|
||||
|
||||
void GrpcClient::Connect() {
|
||||
channel_ = grpc::CreateChannel(host_ + ":" + std::to_string(port_),
|
||||
grpc::InsecureChannelCredentials());
|
||||
stub_ = ASR::NewStub(channel_);
|
||||
context_ = std::make_shared<ClientContext>();
|
||||
stream_ = stub_->Recognize(context_.get());
|
||||
request_ = std::make_shared<Request>();
|
||||
response_ = std::make_shared<Response>();
|
||||
request_->mutable_decode_config()->set_nbest_config(nbest_);
|
||||
request_->mutable_decode_config()->set_continuous_decoding_config(
|
||||
continuous_decoding_);
|
||||
stream_->Write(*request_);
|
||||
}
|
||||
|
||||
// Sends one chunk of audio to the server.
//
// The `size` bytes starting at `data` are copied into the request's
// `audio_data` field, and the request is then written to the stream.
void GrpcClient::SendBinaryData(const void* data, size_t size) {
  const auto* samples = static_cast<const int16_t*>(data);
  request_->set_audio_data(samples, size);
  stream_->Write(*request_);
}
|
||||
|
||||
void GrpcClient::ReadLoopFunc() {
|
||||
try {
|
||||
while (stream_->Read(response_.get())) {
|
||||
for (int i = 0; i < response_->nbest_size(); i++) {
|
||||
// you can also traverse wordpieces like demonstrated above
|
||||
LOG(INFO) << i + 1 << "best " << response_->nbest(i).sentence();
|
||||
}
|
||||
if (response_->status() != Response_Status_ok) {
|
||||
break;
|
||||
}
|
||||
if (response_->type() == Response_Type_speech_end) {
|
||||
done_ = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (std::exception const& e) {
|
||||
LOG(ERROR) << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
void GrpcClient::Join() {
|
||||
stream_->WritesDone();
|
||||
t_->join();
|
||||
Status status = stream_->Finish();
|
||||
if (!status.ok()) {
|
||||
LOG(INFO) << "Recognize rpc failed.";
|
||||
}
|
||||
}
|
||||
} // namespace wenet
|
70
language_model/runtime/core/grpc/grpc_client.h
Normal file
70
language_model/runtime/core/grpc/grpc_client.h
Normal file
@@ -0,0 +1,70 @@
|
||||
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef GRPC_GRPC_CLIENT_H_
|
||||
#define GRPC_GRPC_CLIENT_H_
|
||||
|
||||
#include <grpc/grpc.h>
#include <grpcpp/channel.h>
#include <grpcpp/client_context.h>
#include <grpcpp/create_channel.h>

#include <atomic>
#include <iostream>
#include <memory>
#include <string>
#include <thread>

#include "utils/utils.h"
#include "grpc/wenet.grpc.pb.h"
|
||||
|
||||
namespace wenet {
|
||||
|
||||
using grpc::Channel;
|
||||
using grpc::ClientContext;
|
||||
using grpc::ClientReaderWriter;
|
||||
using wenet::ASR;
|
||||
using wenet::Request;
|
||||
using wenet::Response;
|
||||
|
||||
class GrpcClient {
|
||||
public:
|
||||
GrpcClient(const std::string& host, int port, int nbest,
|
||||
bool continuous_decoding);
|
||||
|
||||
void SendBinaryData(const void* data, size_t size);
|
||||
void ReadLoopFunc();
|
||||
void Join();
|
||||
bool done() const { return done_; }
|
||||
|
||||
private:
|
||||
void Connect();
|
||||
std::string host_;
|
||||
int port_;
|
||||
std::shared_ptr<Channel> channel_{nullptr};
|
||||
std::unique_ptr<ASR::Stub> stub_{nullptr};
|
||||
std::shared_ptr<ClientContext> context_{nullptr};
|
||||
std::unique_ptr<ClientReaderWriter<Request, Response>> stream_{nullptr};
|
||||
std::shared_ptr<Request> request_{nullptr};
|
||||
std::shared_ptr<Response> response_{nullptr};
|
||||
int nbest_ = 1;
|
||||
bool continuous_decoding_ = false;
|
||||
bool done_ = false;
|
||||
std::unique_ptr<std::thread> t_{nullptr};
|
||||
|
||||
WENET_DISALLOW_COPY_AND_ASSIGN(GrpcClient);
|
||||
};
|
||||
|
||||
} // namespace wenet
|
||||
|
||||
#endif // GRPC_GRPC_CLIENT_H_
|
185
language_model/runtime/core/grpc/grpc_server.cc
Normal file
185
language_model/runtime/core/grpc/grpc_server.cc
Normal file
@@ -0,0 +1,185 @@
|
||||
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "grpc/grpc_server.h"
|
||||
|
||||
namespace wenet {
|
||||
|
||||
using grpc::ServerReaderWriter;
|
||||
using wenet::Request;
|
||||
using wenet::Response;
|
||||
|
||||
// Stores the stream and the shared resources needed to serve one Recognize
// call. Nothing is started here; the feature pipeline, decoder and decode
// thread are created later in OnSpeechStart().
GrpcConnectionHandler::GrpcConnectionHandler(
    ServerReaderWriter<Response, Request>* stream,
    std::shared_ptr<Request> request, std::shared_ptr<Response> response,
    std::shared_ptr<FeaturePipelineConfig> feature_config,
    std::shared_ptr<DecodeOptions> decode_config,
    std::shared_ptr<fst::SymbolTable> symbol_table,
    std::shared_ptr<TorchAsrModel> model,
    std::shared_ptr<fst::Fst<fst::StdArc>> fst)
    : stream_(stream),  // non-owning raw pointer, copied as is
      request_(std::move(request)),
      response_(std::move(response)),
      feature_config_(std::move(feature_config)),
      decode_config_(std::move(decode_config)),
      symbol_table_(std::move(symbol_table)),
      model_(std::move(model)),
      fst_(std::move(fst)) {}
|
||||
|
||||
void GrpcConnectionHandler::OnSpeechStart() {
|
||||
LOG(INFO) << "Recieved speech start signal, start reading speech";
|
||||
got_start_tag_ = true;
|
||||
response_->set_status(Response::ok);
|
||||
response_->set_type(Response::server_ready);
|
||||
stream_->Write(*response_);
|
||||
feature_pipeline_ = std::make_shared<FeaturePipeline>(*feature_config_);
|
||||
decoder_ = std::make_shared<TorchAsrDecoder>(
|
||||
feature_pipeline_, model_, symbol_table_, *decode_config_, fst_);
|
||||
// Start decoder thread
|
||||
decode_thread_ = std::make_shared<std::thread>(
|
||||
&GrpcConnectionHandler::DecodeThreadFunc, this);
|
||||
}
|
||||
|
||||
// Called once the client has stopped sending requests: marks the feature
// pipeline input as finished and records that the end was seen.
// Requires that OnSpeechStart() has already created the feature pipeline.
void GrpcConnectionHandler::OnSpeechEnd() {
  LOG(INFO) << "Recieved speech end signal";

  CHECK(feature_pipeline_ != nullptr);
  feature_pipeline_->set_input_finished();

  got_end_tag_ = true;
}
|
||||
|
||||
void GrpcConnectionHandler::OnPartialResult() {
|
||||
LOG(INFO) << "Partial result";
|
||||
response_->set_status(Response::ok);
|
||||
response_->set_type(Response::partial_result);
|
||||
stream_->Write(*response_);
|
||||
}
|
||||
|
||||
void GrpcConnectionHandler::OnFinalResult() {
|
||||
LOG(INFO) << "Final result";
|
||||
response_->set_status(Response::ok);
|
||||
response_->set_type(Response::final_result);
|
||||
stream_->Write(*response_);
|
||||
}
|
||||
|
||||
void GrpcConnectionHandler::OnFinish() {
|
||||
// Send finish tag
|
||||
response_->set_status(Response::ok);
|
||||
response_->set_type(Response::speech_end);
|
||||
stream_->Write(*response_);
|
||||
}
|
||||
|
||||
void GrpcConnectionHandler::OnSpeechData() {
|
||||
// Read binary PCM data
|
||||
const int16_t* pdata =
|
||||
reinterpret_cast<const int16_t*>(request_->audio_data().c_str());
|
||||
int num_samples = request_->audio_data().length() / sizeof(int16_t);
|
||||
std::vector<float> pcm_data(num_samples);
|
||||
for (int i = 0; i < num_samples; i++) {
|
||||
pcm_data[i] = static_cast<float>(*pdata);
|
||||
pdata++;
|
||||
}
|
||||
VLOG(2) << "Recieved " << num_samples << " samples";
|
||||
CHECK(feature_pipeline_ != nullptr);
|
||||
CHECK(decoder_ != nullptr);
|
||||
feature_pipeline_->AcceptWaveform(pcm_data);
|
||||
}
|
||||
|
||||
void GrpcConnectionHandler::SerializeResult(bool finish) {
|
||||
for (const DecodeResult& path : decoder_->result()) {
|
||||
Response_OneBest* one_best_ = response_->add_nbest();
|
||||
one_best_->set_sentence(path.sentence);
|
||||
if (finish) {
|
||||
for (const WordPiece& word_piece : path.word_pieces) {
|
||||
Response_OnePiece* one_piece_ = one_best_->add_wordpieces();
|
||||
one_piece_->set_word(word_piece.word);
|
||||
one_piece_->set_start(word_piece.start);
|
||||
one_piece_->set_end(word_piece.end);
|
||||
}
|
||||
}
|
||||
if (response_->nbest_size() == nbest_) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void GrpcConnectionHandler::DecodeThreadFunc() {
|
||||
while (true) {
|
||||
DecodeState state = decoder_->Decode();
|
||||
response_->clear_status();
|
||||
response_->clear_type();
|
||||
response_->clear_nbest();
|
||||
if (state == DecodeState::kEndFeats) {
|
||||
decoder_->Rescoring();
|
||||
SerializeResult(true);
|
||||
OnFinalResult();
|
||||
OnFinish();
|
||||
stop_recognition_ = true;
|
||||
break;
|
||||
} else if (state == DecodeState::kEndpoint) {
|
||||
decoder_->Rescoring();
|
||||
SerializeResult(true);
|
||||
OnFinalResult();
|
||||
// If it's not continuous decoidng, continue to do next recognition
|
||||
// otherwise stop the recognition
|
||||
if (continuous_decoding_) {
|
||||
decoder_->ResetContinuousDecoding();
|
||||
} else {
|
||||
OnFinish();
|
||||
stop_recognition_ = true;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
if (decoder_->DecodedSomething()) {
|
||||
SerializeResult(false);
|
||||
OnPartialResult();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Serves one Recognize stream from start to end.
//
// The first request is treated as the configuration message: its
// decode_config supplies `nbest_` and `continuous_decoding_`, and decoding
// is started. Every later request is treated as audio data. When the client
// stops sending, the feature pipeline is marked finished and the decode
// thread is joined.
void GrpcConnectionHandler::operator()() {
  try {
    while (stream_->Read(request_.get())) {
      if (!got_start_tag_) {
        nbest_ = request_->decode_config().nbest_config();
        continuous_decoding_ =
            request_->decode_config().continuous_decoding_config();
        OnSpeechStart();
      } else {
        OnSpeechData();
      }
    }

    // If the stream closed before any request arrived, no pipeline or decode
    // thread was ever created. OnSpeechEnd() CHECKs the pipeline, so calling
    // it here would abort the whole server on an empty stream.
    if (!got_start_tag_) {
      LOG(WARNING) << "Stream closed before any request was received";
      return;
    }

    OnSpeechEnd();
    LOG(INFO) << "Read all pcm data, wait for decoding thread";
    if (decode_thread_ != nullptr) {
      decode_thread_->join();
    }
  } catch (std::exception const& e) {
    LOG(ERROR) << e.what();
  }
}
|
||||
|
||||
// gRPC entry point for one Recognize call.
//
// Builds a connection handler with fresh request/response messages and the
// server's shared resources, runs it on its own thread and blocks until it
// is done. Always returns Status::OK.
Status GrpcServer::Recognize(ServerContext* context,
                             ServerReaderWriter<Response, Request>* stream) {
  LOG(INFO) << "Get Recognize request" << std::endl;

  std::thread worker{GrpcConnectionHandler(
      stream, std::make_shared<Request>(), std::make_shared<Response>(),
      feature_config_, decode_config_, symbol_table_, model_, fst_)};
  worker.join();

  return Status::OK;
}
|
||||
} // namespace wenet
|
109
language_model/runtime/core/grpc/grpc_server.h
Normal file
109
language_model/runtime/core/grpc/grpc_server.h
Normal file
@@ -0,0 +1,109 @@
|
||||
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef GRPC_GRPC_SERVER_H_
|
||||
#define GRPC_GRPC_SERVER_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "decoder/torch_asr_decoder.h"
|
||||
#include "decoder/torch_asr_model.h"
|
||||
#include "frontend/feature_pipeline.h"
|
||||
#include "utils/log.h"
|
||||
|
||||
#include "grpc/wenet.grpc.pb.h"
|
||||
|
||||
namespace wenet {
|
||||
|
||||
using grpc::ServerContext;
|
||||
using grpc::ServerReaderWriter;
|
||||
using grpc::Status;
|
||||
using wenet::ASR;
|
||||
using wenet::Request;
|
||||
using wenet::Response;
|
||||
|
||||
class GrpcConnectionHandler {
|
||||
public:
|
||||
GrpcConnectionHandler(ServerReaderWriter<Response, Request> *stream,
|
||||
std::shared_ptr<Request> request,
|
||||
std::shared_ptr<Response> response,
|
||||
std::shared_ptr<FeaturePipelineConfig> feature_config,
|
||||
std::shared_ptr<DecodeOptions> decode_config,
|
||||
std::shared_ptr<fst::SymbolTable> symbol_table,
|
||||
std::shared_ptr<TorchAsrModel> model,
|
||||
std::shared_ptr<fst::Fst<fst::StdArc>> fst);
|
||||
void operator()();
|
||||
|
||||
private:
|
||||
void OnSpeechStart();
|
||||
void OnSpeechEnd();
|
||||
void OnFinish();
|
||||
void OnSpeechData();
|
||||
void OnPartialResult();
|
||||
void OnFinalResult();
|
||||
void DecodeThreadFunc();
|
||||
void SerializeResult(bool finish);
|
||||
|
||||
bool continuous_decoding_ = false;
|
||||
int nbest_ = 1;
|
||||
ServerReaderWriter<Response, Request> *stream_;
|
||||
std::shared_ptr<Request> request_;
|
||||
std::shared_ptr<Response> response_;
|
||||
std::shared_ptr<FeaturePipelineConfig> feature_config_;
|
||||
std::shared_ptr<DecodeOptions> decode_config_;
|
||||
std::shared_ptr<fst::SymbolTable> symbol_table_;
|
||||
std::shared_ptr<TorchAsrModel> model_;
|
||||
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
|
||||
|
||||
bool got_start_tag_ = false;
|
||||
bool got_end_tag_ = false;
|
||||
// When endpoint is detected, stop recognition, and stop receiving data.
|
||||
bool stop_recognition_ = false;
|
||||
std::shared_ptr<FeaturePipeline> feature_pipeline_ = nullptr;
|
||||
std::shared_ptr<TorchAsrDecoder> decoder_ = nullptr;
|
||||
std::shared_ptr<std::thread> decode_thread_ = nullptr;
|
||||
};
|
||||
|
||||
class GrpcServer final : public ASR::Service {
|
||||
public:
|
||||
GrpcServer(std::shared_ptr<FeaturePipelineConfig> feature_config,
|
||||
std::shared_ptr<DecodeOptions> decode_config,
|
||||
std::shared_ptr<fst::SymbolTable> symbol_table,
|
||||
std::shared_ptr<TorchAsrModel> model,
|
||||
std::shared_ptr<fst::Fst<fst::StdArc>> fst)
|
||||
: feature_config_(std::move(feature_config)),
|
||||
decode_config_(std::move(decode_config)),
|
||||
symbol_table_(std::move(symbol_table)),
|
||||
model_(std::move(model)),
|
||||
fst_(std::move(fst)) {}
|
||||
Status Recognize(ServerContext *context,
|
||||
ServerReaderWriter<Response, Request> *reader) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<FeaturePipelineConfig> feature_config_;
|
||||
std::shared_ptr<DecodeOptions> decode_config_;
|
||||
std::shared_ptr<fst::SymbolTable> symbol_table_;
|
||||
std::shared_ptr<TorchAsrModel> model_;
|
||||
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
|
||||
DISALLOW_COPY_AND_ASSIGN(GrpcServer);
|
||||
};
|
||||
|
||||
} // namespace wenet
|
||||
|
||||
#endif // GRPC_GRPC_SERVER_H_
|
66
language_model/runtime/core/grpc/wenet.proto
Normal file
66
language_model/runtime/core/grpc/wenet.proto
Normal file
@@ -0,0 +1,66 @@
|
||||
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
syntax = "proto3";
|
||||
|
||||
option java_package = "ex.grpc";
|
||||
option objc_class_prefix = "wenet";
|
||||
|
||||
package wenet;
|
||||
|
||||
// Speech recognition service.
service ASR {
  // Bidirectional stream: the client sends Request messages and the server
  // streams Response messages back on the same call.
  rpc Recognize (stream Request) returns (stream Response) {}
}
|
||||
|
||||
// One client-to-server message. Each message carries exactly one of the
// payloads below.
message Request {

  // Decoding options for the session.
  message DecodeConfig {
    // Number of n-best hypotheses requested.
    int32 nbest_config = 1;
    // Whether the server keeps decoding after an endpoint is detected.
    bool continuous_decoding_config = 2;
  }

  oneof RequestPayload {
    // Session configuration.
    DecodeConfig decode_config = 1;
    // Raw audio bytes.
    bytes audio_data = 2;
  }
}
|
||||
|
||||
// One server-to-client message.
message Response {

  // A single recognition hypothesis.
  message OneBest {
    string sentence = 1;
    // Word-level details of the sentence.
    repeated OnePiece wordpieces = 2;
  }

  // One word piece with its boundaries.
  // NOTE(review): the unit of start/end is not stated here - confirm.
  message OnePiece {
    string word = 1;
    int32 start = 2;
    int32 end = 3;
  }

  // Outcome reported with the message.
  enum Status {
    ok = 0;
    failed = 1;
  }

  // Kind of message being sent.
  enum Type {
    server_ready = 0;
    partial_result = 1;
    final_result = 2;
    speech_end = 3;
  }

  Status status = 1;
  Type type = 2;
  // Recognition hypotheses carried by this message.
  repeated OneBest nbest = 3;
}
|
Reference in New Issue
Block a user