import redis
import argparse
import numpy as np
from datetime import datetime
import time
import os
import re
import logging
import torch
import lm_decoder
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer
# set up logging
logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO)
# function for initializing the ngram decoder
def build_lm_decoder(
        model_path,
        max_active=7000,
        min_active=200,
        beam=17.,
        lattice_beam=8.0,
        acoustic_scale=1.5,
        ctc_blank_skip_threshold=1.0,
        length_penalty=0.0,
        nbest=1,
    ):
    decode_opts = lm_decoder.DecodeOptions(
        max_active,
        min_active,
        beam,
        lattice_beam,
        acoustic_scale,
        ctc_blank_skip_threshold,
        length_penalty,
        nbest
    )

    TLG_path = os.path.join(model_path, 'TLG.fst')
    words_path = os.path.join(model_path, 'words.txt')
    G_path = os.path.join(model_path, 'G.fst')
    rescore_G_path = os.path.join(model_path, 'G_no_prune.fst')
    if not os.path.exists(rescore_G_path):
        rescore_G_path = ""
        G_path = ""
    if not os.path.exists(TLG_path):
        raise ValueError('TLG file not found at {}'.format(TLG_path))
    if not os.path.exists(words_path):
        raise ValueError('words file not found at {}'.format(words_path))

    decode_resource = lm_decoder.DecodeResource(
        TLG_path,
        G_path,
        rescore_G_path,
        words_path,
        ""
    )
    decoder = lm_decoder.BrainSpeechDecoder(decode_resource, decode_opts)

    return decoder
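
# Hypothetical usage sketch (the path is a placeholder; the folder must contain
# TLG.fst and words.txt, and optionally G.fst + G_no_prune.fst for rescoring):
#   ngramDecoder = build_lm_decoder('/path/to/lm_folder', acoustic_scale=0.3, nbest=100)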
# function for updating the ngram decoder parameters
def update_ngram_params(
        ngramDecoder,
        max_active=7000,
        min_active=200,
        beam=17.0,
        lattice_beam=8.0,
        acoustic_scale=1.5,
        ctc_blank_skip_threshold=1.0,
        length_penalty=0.0,
        nbest=100,
    ):
    decode_opts = lm_decoder.DecodeOptions(
        max_active,
        min_active,
        beam,
        lattice_beam,
        acoustic_scale,
        ctc_blank_skip_threshold,
        length_penalty,
        nbest,
    )
    ngramDecoder.SetOpt(decode_opts)
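
# For example, to change the search parameters on the fly (values are illustrative):
#   update_ngram_params(ngramDecoder, max_active=7000, min_active=200, beam=17.0,
#                       lattice_beam=8.0, acoustic_scale=0.3, nbest=100)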
# function for initializing the OPT model and tokenizer
def build_opt(
        model_name='facebook/opt-6.7b',
        cache_dir=None,
        device='cuda' if torch.cuda.is_available() else 'cpu',
    ):
    '''
    Load the OPT-6.7b model and tokenizer from Hugging Face.
    We will load the model with 16-bit precision for faster inference. This requires ~13 GB of VRAM.
    Put the model onto the GPU (if available).
    '''
    # load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        cache_dir=cache_dir,
        torch_dtype=torch.float16,
    )

    if device != 'cpu':
        # Move the model to the GPU
        model = model.to(device)

    # Set the model to evaluation mode
    model.eval()

    # ensure padding token
    tokenizer.padding_side = "right"
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
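
# Hypothetical usage sketch (cache_dir is a placeholder; downloads the fp16
# weights on first call):
#   opt_model, opt_tokenizer = build_opt(cache_dir='/path/to/huggingface_cache', device='cuda:0')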
# function for rescoring hypotheses with a causal LLM (named for GPT-2, but used here with OPT)
@torch.inference_mode()
def rescore_with_gpt2(
        model,
        tokenizer,
        device,
        hypotheses,
        length_penalty
    ):
    # set model to evaluation mode
    model.eval()

    # tokenize the batch of hypotheses (padded to the longest) and move to the device
    inputs = tokenizer(hypotheses, return_tensors='pt', padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    outputs = model(**inputs)

    # compute log-probabilities over the vocabulary at every position
    log_probs = torch.nn.functional.log_softmax(outputs.logits, dim=-1)
    log_probs = log_probs.cpu().numpy()
    input_ids = inputs['input_ids'].cpu().numpy()
    attention_mask = inputs['attention_mask'].cpu().numpy()

    batch_size, seq_len, _ = log_probs.shape
    scores = []
    for i in range(batch_size):
        n_tokens = int(attention_mask[i].sum())
        # sum log-probs of each token given the previous context
        # (the logits at position t-1 are the model's prediction for the token at position t)
        score = sum(
            log_probs[i, t - 1, input_ids[i, t]]
            for t in range(1, n_tokens)
        )
        scores.append(score - n_tokens * length_penalty)

    return scores
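
# Scoring sketch: for a hypothesis tokenized as [t_0, ..., t_{n-1}], the returned score is
#   sum over t=1..n-1 of log P(t_t | t_0..t_{t-1})  -  n * length_penalty,
# i.e. the sentence log-likelihood under the causal LM minus a per-token penalty.
# Note that the first token is conditioned on but never itself scored.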
# function for rescoring the nbest list with the LLM and selecting the best hypothesis
def gpt2_lm_decode(
        model,
        tokenizer,
        device,
        nbest,
        acoustic_scale,
        length_penalty,
        alpha,
        returnConfidence=False,
        current_context_str=None,
    ):
    hypotheses = []
    acousticScores = []
    oldLMScores = []
    for out in nbest:
        # get the candidate sentence (hypothesis)
        hyp = out[0].strip()
        if len(hyp) == 0:
            continue

        # add context to the front of each sentence
        if current_context_str is not None and len(current_context_str.split()) > 0:
            hyp = current_context_str + ' ' + hyp

        # clean up stray symbols and spacing around punctuation before tokenizing
        hyp = hyp.replace('>', '')
        hyp = hyp.replace('  ', ' ')
        hyp = hyp.replace(' ,', ',')
        hyp = hyp.replace(' .', '.')
        hyp = hyp.replace(' ?', '?')

        hypotheses.append(hyp)
        acousticScores.append(out[1])
        oldLMScores.append(out[2])

    if len(hypotheses) == 0:
        logging.error('In gpt2_lm_decode, len(hypotheses) == 0')
        return ("", []) if not returnConfidence else ("", [], 0.)

    # convert to numpy arrays
    acousticScores = np.array(acousticScores)
    oldLMScores = np.array(oldLMScores)

    # get new LM scores from LLM
    try:
        # first, try to rescore all at once
        newLMScores = np.array(rescore_with_gpt2(model, tokenizer, device, hypotheses, length_penalty))
    except Exception as e:
        logging.error(f'Error during OPT rescore: {e}')
        try:
            # if that fails, try to rescore in batches (to avoid VRAM issues)
            newLMScores = []
            batch_size = int(np.ceil(len(hypotheses) / 5))
            for i in range(0, len(hypotheses), batch_size):
                newLMScores.extend(rescore_with_gpt2(model, tokenizer, device, hypotheses[i:i + batch_size], length_penalty))
            newLMScores = np.array(newLMScores)
        except Exception as e:
            logging.error(f'Error during OPT rescore: {e}')
            newLMScores = np.zeros(len(hypotheses))

    # remove context from start of each sentence
    if current_context_str is not None and len(current_context_str.split()) > 0:
        hypotheses = [h[(len(current_context_str) + 1):] for h in hypotheses]

    # calculate total scores: scaled acoustic score plus a weighted blend of the
    # old ngram LM score and the new LLM score (alpha = weight on the LLM)
    totalScores = (acoustic_scale * acousticScores) + ((1 - alpha) * oldLMScores) + (alpha * newLMScores)

    # get the best hypothesis
    maxIdx = np.argmax(totalScores)
    bestHyp = hypotheses[maxIdx]

    # create nbest output
    nbest_out = []
    min_len = np.min((len(nbest), len(newLMScores), len(totalScores)))
    for i in range(min_len):
        nbest_out.append(';'.join(map(str, [nbest[i][0], nbest[i][1], nbest[i][2], newLMScores[i], totalScores[i]])))

    # return
    if not returnConfidence:
        return bestHyp, nbest_out
    else:
        # softmax over total scores gives the confidence of the best hypothesis
        totalScores = totalScores - np.max(totalScores)
        probs = np.exp(totalScores)
        return bestHyp, nbest_out, probs[maxIdx] / np.sum(probs)
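
# Hypothetical usage sketch (nbest entries are [sentence, acoustic_score, ngram_lm_score];
# the model/tokenizer come from build_opt above):
#   best, nbest_info, conf = gpt2_lm_decode(
#       opt_model, opt_tokenizer, 'cuda:0',
#       [['hello world', -10.2, -5.1], ['hello word', -10.5, -5.0]],
#       acoustic_scale=0.3, length_penalty=0.0, alpha=0.5, returnConfidence=True)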
def connect_to_redis_server(redis_ip, redis_port):
    try:
        # logging.info("Attempting to connect to redis...")
        redis_conn = redis.Redis(host=redis_ip, port=redis_port)
        redis_conn.ping()
    except redis.exceptions.ConnectionError:
        logging.warning("Can't connect to redis server (ConnectionError).")
        return
    else:
        logging.info("Connected to redis.")
        return redis_conn
def get_current_redis_time_ms(redis_conn):
    t = redis_conn.time()
    return int(t[0] * 1000 + t[1] / 1000)
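
# Note: redis_conn.time() returns (seconds, microseconds), so e.g. (1700000000, 250000)
# converts to 1700000000250 ms via seconds*1000 + microseconds/1000.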
# function to get string differences between two sentences
def get_string_differences(cue, decoder_output):
    decoder_output_words = decoder_output.split()
    cue_words = cue.split()

    @lru_cache(None)
    def reverse_w_backtrace(i, j):
        # word-level edit distance between the first i decoder words and the first j cue
        # words, plus a backtrace path ('I' = insertion, 'D' = deletion, 'R' = replacement;
        # matched words are recorded by their index in decoder_output_words)
        if i == 0:
            return j, ['I'] * j
        elif j == 0:
            return i, ['D'] * i
        elif i > 0 and j > 0 and decoder_output_words[i - 1] == cue_words[j - 1]:
            cost, path = reverse_w_backtrace(i - 1, j - 1)
            return cost, path + [i - 1]
        else:
            insertion_cost, insertion_path = reverse_w_backtrace(i, j - 1)
            deletion_cost, deletion_path = reverse_w_backtrace(i - 1, j)
            substitution_cost, substitution_path = reverse_w_backtrace(i - 1, j - 1)
            if insertion_cost <= deletion_cost and insertion_cost <= substitution_cost:
                return insertion_cost + 1, insertion_path + ['I']
            elif deletion_cost <= insertion_cost and deletion_cost <= substitution_cost:
                return deletion_cost + 1, deletion_path + ['D']
            else:
                return substitution_cost + 1, substitution_path + ['R']

    cost, path = reverse_w_backtrace(len(decoder_output_words), len(cue_words))

    # remove insertions from path (leaving one entry per decoder output word)
    path = [p for p in path if p != 'I']

    # Get the indices in decoder_output of the words that are different from cue
    indices_to_highlight = []
    current_index = 0
    for label, word in zip(path, decoder_output_words):
        if label in ['R', 'D']:
            indices_to_highlight.append((current_index, current_index + len(word)))
        current_index += len(word) + 1

    return cost, path, indices_to_highlight
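
# Illustrative example:
#   get_string_differences('i am fine', 'i am find')
# returns cost = 1, path = [0, 1, 'R'], and indices_to_highlight = [(5, 9)]
# (the character span of 'find' in the decoder output).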
def remove_punctuation(sentence):
    # Remove punctuation (keep letters, hyphens, spaces, and apostrophes)
    sentence = re.sub(r'[^a-zA-Z\- \']', '', sentence)
    sentence = sentence.replace('- ', ' ').lower()
    sentence = sentence.replace('--', '').lower()
    sentence = sentence.replace(" '", "'").lower()
    sentence = sentence.strip()
    sentence = ' '.join(sentence.split())
    return sentence
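
# e.g. remove_punctuation("Hello, world!") -> 'hello world'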
# function to augment the nbest list by swapping words around, artificially increasing the number of candidates
def augment_nbest(nbest, top_candidates_to_augment=20, acoustic_scale=0.3, score_penalty_percent=0.01):
    sentences = []
    ac_scores = []
    lm_scores = []
    total_scores = []
    for i in range(len(nbest)):
        sentences.append(nbest[i][0].strip())
        ac_scores.append(nbest[i][1])
        lm_scores.append(nbest[i][2])
        total_scores.append(acoustic_scale * nbest[i][1] + nbest[i][2])

    # sort by total score
    sorted_indices = np.argsort(total_scores)[::-1]
    sentences = [sentences[i] for i in sorted_indices]
    ac_scores = [ac_scores[i] for i in sorted_indices]
    lm_scores = [lm_scores[i] for i in sorted_indices]
    total_scores = [total_scores[i] for i in sorted_indices]

    # penalized mean of two parent scores: the mean, reduced by
    # score_penalty_percent of its magnitude
    def penalized_mean(s1, s2):
        m = np.mean([s1, s2])
        return m - score_penalty_percent * np.abs(m)

    # new sentences and scores
    new_sentences = []
    new_ac_scores = []
    new_lm_scores = []
    new_total_scores = []

    # swap differing words between pairs of same-length top candidates
    for i1 in range(np.min([len(sentences) - 1, top_candidates_to_augment])):
        words1 = sentences[i1].split()
        for i2 in range(i1 + 1, np.min([len(sentences), top_candidates_to_augment])):
            words2 = sentences[i2].split()
            if len(words1) != len(words2):
                continue
            _, path1, _ = get_string_differences(sentences[i1], sentences[i2])
            _, path2, _ = get_string_differences(sentences[i2], sentences[i1])
            replace_indices1 = [i for i, p in enumerate(path2) if p == 'R']
            replace_indices2 = [i for i, p in enumerate(path1) if p == 'R']
            for r1, r2 in zip(replace_indices1, replace_indices2):
                new_words1 = words1.copy()
                new_words2 = words2.copy()
                new_words1[r1] = words2[r2]
                new_words2[r2] = words1[r1]
                new_sentence1 = ' '.join(new_words1)
                new_sentence2 = ' '.join(new_words2)
                if new_sentence1 not in sentences and new_sentence1 not in new_sentences:
                    new_sentences.append(new_sentence1)
                    new_ac_scores.append(penalized_mean(ac_scores[i1], ac_scores[i2]))
                    new_lm_scores.append(penalized_mean(lm_scores[i1], lm_scores[i2]))
                    new_total_scores.append(acoustic_scale * new_ac_scores[-1] + new_lm_scores[-1])
                if new_sentence2 not in sentences and new_sentence2 not in new_sentences:
                    new_sentences.append(new_sentence2)
                    new_ac_scores.append(penalized_mean(ac_scores[i1], ac_scores[i2]))
                    new_lm_scores.append(penalized_mean(lm_scores[i1], lm_scores[i2]))
                    new_total_scores.append(acoustic_scale * new_ac_scores[-1] + new_lm_scores[-1])

    # combine new sentences and scores with old
    sentences.extend(new_sentences)
    ac_scores.extend(new_ac_scores)
    lm_scores.extend(new_lm_scores)
    total_scores.extend(new_total_scores)

    # sort by total score
    sorted_indices = np.argsort(total_scores)[::-1]
    sentences = [sentences[i] for i in sorted_indices]
    ac_scores = [ac_scores[i] for i in sorted_indices]
    lm_scores = [lm_scores[i] for i in sorted_indices]

    # return nbest as [sentence, acoustic score, lm score] triples
    nbest_out = [[sentences[i], ac_scores[i], lm_scores[i]] for i in range(len(sentences))]
    return nbest_out
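
# Illustrative example: if two equal-length candidates differ at two word positions,
#   'i want to go'  and  'i went to grow',
# the pairwise swaps generate 'i went to go' and 'i want to grow' as new candidates,
# each scored with the (slightly penalized) mean of its parents' scores.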
# main function
def main(args):
    lm_path = args.lm_path
    gpu_number = args.gpu_number
    max_active = args.max_active
    min_active = args.min_active
    beam = args.beam
    lattice_beam = args.lattice_beam
    acoustic_scale = args.acoustic_scale
    ctc_blank_skip_threshold = args.ctc_blank_skip_threshold
    length_penalty = args.length_penalty
    nbest = args.nbest
    top_candidates_to_augment = args.top_candidates_to_augment
    score_penalty_percent = args.score_penalty_percent
    blank_penalty = args.blank_penalty
    do_opt = args.do_opt  # acoustic scale = 0.8, blank penalty = 7, alpha = 0.5
    opt_cache_dir = args.opt_cache_dir
    alpha = args.alpha
    rescore = args.rescore
    redis_ip = args.redis_ip
    redis_port = args.redis_port
    input_stream = args.input_stream
    partial_output_stream = args.partial_output_stream
    final_output_stream = args.final_output_stream

    # expand user on paths
    lm_path = os.path.expanduser(lm_path)
    if not os.path.exists(lm_path):
        raise ValueError(f'Language model path does not exist: {lm_path}')
    if opt_cache_dir is not None:
        opt_cache_dir = os.path.expanduser(opt_cache_dir)

    # create a nice dict of params to put into redis
    lm_args = {
        'lm_path': lm_path,
        'max_active': int(max_active),
        'min_active': int(min_active),
        'beam': float(beam),
        'lattice_beam': float(lattice_beam),
        'acoustic_scale': float(acoustic_scale),
        'ctc_blank_skip_threshold': float(ctc_blank_skip_threshold),
        'length_penalty': float(length_penalty),
        'nbest': int(nbest),
        'blank_penalty': float(blank_penalty),
        'alpha': float(alpha),
        'do_opt': int(do_opt),
        'rescore': int(rescore),
        'top_candidates_to_augment': int(top_candidates_to_augment),
        'score_penalty_percent': float(score_penalty_percent),
    }
    # pick GPU
    device = torch.device(f"cuda:{gpu_number}" if torch.cuda.is_available() else "cpu")
    logging.info(f'Using device: {device}')

    # initialize OPT model
    if do_opt:
        logging.info(f"Building OPT model from {opt_cache_dir}...")
        start_time = time.time()
        lm, lm_tokenizer = build_opt(
            cache_dir=opt_cache_dir,
            device=device,
        )
        logging.info(f'OPT model successfully built in {(time.time() - start_time):0.4f} seconds.')

    # initialize ngram decoder
    logging.info(f'Initializing language model decoder from {lm_path}...')
    start_time = time.time()
    ngramDecoder = build_lm_decoder(
        lm_path,
        max_active=7000,
        min_active=200,
        beam=17.,
        lattice_beam=8.,
        acoustic_scale=acoustic_scale,
        ctc_blank_skip_threshold=1.0,
        length_penalty=0.0,
        nbest=nbest,
    )
    logging.info(f'Language model successfully initialized in {(time.time() - start_time):0.4f} seconds.')

    # connect to redis server (retry until it's available)
    REDIS_STATE = -1
    logging.info(f'Attempting to connect to redis at {redis_ip}:{redis_port}...')
    r = connect_to_redis_server(redis_ip, redis_port)
    while r is None:
        logging.warning(f'At startup, could not connect to redis server at {redis_ip}:{redis_port}. Trying again in 3 seconds...')
        time.sleep(3)
        r = connect_to_redis_server(redis_ip, redis_port)
    logging.info(f'Successfully connected to redis server at {redis_ip}:{redis_port}.')

    timeout_ms = 100
    oldStr = ''
    prev_loop_start_time = 0

    # main loop
    logging.info('Entering main loop...')
    while True:
        # make sure that the loop doesn't run too fast (max 1000 Hz)
        loop_time = time.time() - prev_loop_start_time
        if loop_time < 0.001:
            time.sleep(0.001 - loop_time)
        prev_loop_start_time = time.time()

        # try/except makes sure we're connected to redis, and reconnects if not
        try:
            r.ping()
        except redis.exceptions.ConnectionError:
            if REDIS_STATE != 0:
                logging.error(f'Could not connect to the redis server at {redis_ip}:{redis_port}! I will keep trying...')
                REDIS_STATE = 0
            time.sleep(1)
            continue
        else:
            if REDIS_STATE != 1:
                logging.info('Successfully connected to the redis server.')
                logits_last_entry_seen = get_current_redis_time_ms(r)
                reset_last_entry_seen = get_current_redis_time_ms(r)
                finalize_last_entry_seen = get_current_redis_time_ms(r)
                update_params_last_entry_seen = get_current_redis_time_ms(r)
                REDIS_STATE = 1
        # if the 'remote_lm_args' stream is empty, add the current args
        # (this makes sure it's re-added once redis is flushed at the start of a new block)
        if r.xlen('remote_lm_args') == 0:
            r.xadd('remote_lm_args', lm_args)

        # check if we need to reset
        lm_reset_stream = r.xread(
            {'remote_lm_reset': reset_last_entry_seen},
            count=1,
            block=None,
        )
        if len(lm_reset_stream) > 0:
            for entry_id, entry_data in lm_reset_stream[0][1]:
                reset_last_entry_seen = entry_id

            # Reset the language model and tell redis, then move on to the next loop
            oldStr = ''
            ngramDecoder.Reset()
            r.xadd('remote_lm_done_resetting', {'done': 1})
            logging.info('Reset the language model.')
            continue
        # check if we need to finalize
        lm_finalize_stream = r.xread(
            {'remote_lm_finalize': finalize_last_entry_seen},
            count=1,
            block=None,
        )
        if len(lm_finalize_stream) > 0:
            for entry_id, entry_data in lm_finalize_stream[0][1]:
                finalize_last_entry_seen = entry_id

            # get the current context string (if any) to prepend to candidate sentences
            if r.get('contextual_decoding_current_context') is not None:
                current_context_str = r.get('contextual_decoding_current_context').decode().strip()
                if len(current_context_str.split()) > 0:
                    logging.info('For LLM rescore, adding context str to the beginning of each candidate sentence:')
                    logging.info(f'\t"{current_context_str}"')
            else:
                current_context_str = ''

            # Finalize decoding, add the output to the output stream, and then move on to the next loop
            ngramDecoder.FinishDecoding()
            oldStr = ''

            # Optionally rescore with the unpruned LM
            if rescore:
                startT = time.time()
                ngramDecoder.Rescore()
                logging.info('Rescore time: %.3f' % (time.time() - startT))

            # if nbest > 1, augment those sentences and bias them toward certain words
            if nbest > 1:
                # append the sentence, acoustic score, and lm score to a list
                nbest_out = []
                for d in ngramDecoder.result():
                    nbest_out.append([d.sentence, d.ac_score, d.lm_score])

                # generate some more candidate sentences by swapping words around
                nbest_out_len = len(nbest_out)
                nbest_out = augment_nbest(
                    nbest=nbest_out,
                    top_candidates_to_augment=top_candidates_to_augment,
                    acoustic_scale=acoustic_scale,
                    score_penalty_percent=score_penalty_percent,
                )
                logging.info(f'Augmented nbest from {nbest_out_len} to {len(nbest_out)} candidates.')

            # Optionally rescore with an LLM
            if do_opt:
                startT = time.time()
                decoded_final, nbest_redis, confidences = gpt2_lm_decode(
                    lm,
                    lm_tokenizer,
                    device,
                    nbest_out,
                    acoustic_scale,
                    alpha=alpha,
                    length_penalty=length_penalty,
                    current_context_str=current_context_str,
                    returnConfidence=True,
                )
                logging.info('OPT time: %.3f' % (time.time() - startT))
            elif len(ngramDecoder.result()) > 0:
                # Otherwise just output the best sentence
                decoded_final = ngramDecoder.result()[0].sentence
                nbest_redis = ''
            else:
                logging.error('No output from language model.')
                decoded_final = ''
                nbest_redis = ''

            logging.info(f'Final: {decoded_final}')
            if nbest > 1:
                r.xadd(final_output_stream, {'lm_response_final': decoded_final, 'scoring': ';'.join(nbest_redis), 'context_str': current_context_str})
            else:
                r.xadd(final_output_stream, {'lm_response_final': decoded_final})
            logging.info('Finalized the language model.\n')
            r.xadd('remote_lm_done_finalizing', {'done': 1})
            continue
        # check if we need to update the decoder params
        update_params_stream = r.xread(
            {'remote_lm_update_params': update_params_last_entry_seen},
            count=1,
            block=None,
        )
        if len(update_params_stream) > 0:
            for entry_id, entry_data in update_params_stream[0][1]:
                update_params_last_entry_seen = entry_id
                max_active = int(entry_data.get(b'max_active', max_active))
                min_active = int(entry_data.get(b'min_active', min_active))
                beam = float(entry_data.get(b'beam', beam))
                lattice_beam = float(entry_data.get(b'lattice_beam', lattice_beam))
                acoustic_scale = float(entry_data.get(b'acoustic_scale', acoustic_scale))
                ctc_blank_skip_threshold = float(entry_data.get(b'ctc_blank_skip_threshold', ctc_blank_skip_threshold))
                length_penalty = float(entry_data.get(b'length_penalty', length_penalty))
                nbest = int(entry_data.get(b'nbest', nbest))
                blank_penalty = float(entry_data.get(b'blank_penalty', blank_penalty))
                alpha = float(entry_data.get(b'alpha', alpha))
                do_opt = int(entry_data.get(b'do_opt', do_opt))
                rescore = int(entry_data.get(b'rescore', rescore))
                top_candidates_to_augment = int(entry_data.get(b'top_candidates_to_augment', top_candidates_to_augment))
                score_penalty_percent = float(entry_data.get(b'score_penalty_percent', score_penalty_percent))

            # make sure that the updated remote lm args are put into redis nicely
            lm_args = {
                'lm_path': lm_path,
                'max_active': int(max_active),
                'min_active': int(min_active),
                'beam': float(beam),
                'lattice_beam': float(lattice_beam),
                'acoustic_scale': float(acoustic_scale),
                'ctc_blank_skip_threshold': float(ctc_blank_skip_threshold),
                'length_penalty': float(length_penalty),
                'nbest': int(nbest),
                'blank_penalty': float(blank_penalty),
                'alpha': float(alpha),
                'do_opt': int(do_opt),
                'rescore': int(rescore),
                'top_candidates_to_augment': int(top_candidates_to_augment),
                'score_penalty_percent': float(score_penalty_percent),
            }
            r.xadd('remote_lm_args', lm_args)

            # update ngram parameters
            update_ngram_params(
                ngramDecoder,
                max_active=max_active,
                min_active=min_active,
                beam=beam,
                lattice_beam=lattice_beam,
                acoustic_scale=acoustic_scale,
                ctc_blank_skip_threshold=ctc_blank_skip_threshold,
                length_penalty=length_penalty,
                nbest=nbest,
            )
            logging.info(
                f'Updated language model params:' +
                f'\n\tmax_active = {max_active}' +
                f'\n\tmin_active = {min_active}' +
                f'\n\tbeam = {beam}' +
                f'\n\tlattice_beam = {lattice_beam}' +
                f'\n\tacoustic_scale = {acoustic_scale}' +
                f'\n\tctc_blank_skip_threshold = {ctc_blank_skip_threshold}' +
                f'\n\tlength_penalty = {length_penalty}' +
                f'\n\tnbest = {nbest}' +
                f'\n\tblank_penalty = {blank_penalty}' +
                f'\n\talpha = {alpha}' +
                f'\n\tdo_opt = {do_opt}' +
                f'\n\trescore = {rescore}' +
                f'\n\ttop_candidates_to_augment = {top_candidates_to_augment}' +
                f'\n\tscore_penalty_percent = {score_penalty_percent}'
            )
            r.xadd('remote_lm_done_updating_params', {'done': 1})
            continue
        # ------------------------------------------------------------------------------------------------------------------------
        # ------------ The loop can only get down to here if we're not finalizing, resetting, or updating params -----------------
        # ------------------------------------------------------------------------------------------------------------------------

        # try to read logits from redis stream
        try:
            read_result = r.xread(
                {input_stream: logits_last_entry_seen},
                count=1,
                block=timeout_ms
            )
        except redis.exceptions.ConnectionError:
            if REDIS_STATE != 0:
                logging.error(f'Could not connect to the redis server at {redis_ip}:{redis_port}! I will keep trying...')
                REDIS_STATE = 0
            time.sleep(1)
            continue

        if len(read_result) >= 1:
            # --------------- Read input stream --------------------------------
            for entry_id, entry_data in read_result[0][1]:
                logits_last_entry_seen = entry_id
                logits = np.frombuffer(entry_data[b'logits'], dtype=np.float32)

                # reshape logits to (T, 41)
                logits = logits.reshape(-1, 41)

                # --------------- Run language model -------------------------------
                lm_decoder.DecodeNumpy(
                    ngramDecoder,
                    logits,
                    np.zeros_like(logits),
                    np.log(blank_penalty),
                )

            # display partial decoded sentence if it exists
            if len(ngramDecoder.result()) > 0:
                decoded_partial = ngramDecoder.result()[0].sentence
                newStr = f'Partial: {decoded_partial}'
                if oldStr != newStr:
                    logging.info(newStr)
                    oldStr = newStr
            else:
                logging.info('Partial: [NONE]')
                decoded_partial = ''
            # print(ngramDecoder.result())
            r.xadd(partial_output_stream, {'lm_response_partial': decoded_partial})
        else:
            # timeout: no data received for timeout_ms ms
            # logging.warning(f'No logits came in for {timeout_ms} ms.')
            continue
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--lm_path', type=str, help='Path to language model folder')
    parser.add_argument('--gpu_number', type=int, default=0, help='GPU number to use')
    parser.add_argument('--max_active', type=int, default=7000, help='max_active param for LM')
    parser.add_argument('--min_active', type=int, default=200, help='min_active param for LM')
    parser.add_argument('--beam', type=float, default=17.0, help='beam param for LM')
    parser.add_argument('--lattice_beam', type=float, default=8.0, help='lattice_beam param for LM')
    parser.add_argument('--ctc_blank_skip_threshold', type=float, default=1., help='ctc_blank_skip_threshold param for LM')
    parser.add_argument('--length_penalty', type=float, default=0.0, help='length_penalty param for LM')
    parser.add_argument('--acoustic_scale', type=float, default=0.3, help='Acoustic scale for LM')
    parser.add_argument('--nbest', type=int, default=100, help='# of candidate sentences for LM decoding')
    parser.add_argument('--top_candidates_to_augment', type=int, default=20, help='# of top candidates to augment')
    parser.add_argument('--score_penalty_percent', type=float, default=0.01, help='Score penalty percent for augmented candidates')
    parser.add_argument('--blank_penalty', type=float, default=9.0, help='Blank penalty for LM')
    parser.add_argument('--rescore', action='store_true', help='Use an unpruned ngram model for rescoring?')
    parser.add_argument('--do_opt', action='store_true', help='Use the OPT model for rescoring?')
    parser.add_argument('--opt_cache_dir', type=str, default=None, help='Path to OPT cache')
    parser.add_argument('--alpha', type=float, default=0.5, help='alpha value [0-1]: higher = more weight on OPT rescore, lower = more weight on ngram rescore')
    parser.add_argument('--redis_ip', type=str, default='192.168.150.2', help='IP of the redis stream (string)')
    parser.add_argument('--redis_port', type=int, default=6379, help='Port of the redis stream (int)')
    parser.add_argument('--input_stream', type=str, default="remote_lm_input", help='Input stream containing logits')
    parser.add_argument('--partial_output_stream', type=str, default="remote_lm_output_partial", help='Output stream containing partial decoded sentences')
    parser.add_argument('--final_output_stream', type=str, default="remote_lm_output_final", help='Output stream containing final decoded sentences')
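
    # Example invocation of this script (paths are placeholders):
    #   python <this_script>.py --lm_path /path/to/lm_folder --do_opt --rescore \
    #       --opt_cache_dir /path/to/hf_cache --redis_ip 192.168.150.2 --redis_port 6379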
    args = parser.parse_args()
    main(args)