# Source listing metadata (from the original file viewer): Python, 44 lines, 1.2 KiB
# Example script: run the language-model decoder over handwriting RNN logits.
#
# Steps:
#   1. Configure the decoder (search options + TLG graph / word table).
#   2. Load a batch of logits from ``test_logits.npy``.
#   3. Reorder the class axis into the character order the decoder expects.
#   4. Decode each batch entry and print its top hypothesis.
import os

import numpy as np
import lm_decoder

# Positional search options; the name of each argument is given inline.
decode_opts = lm_decoder.DecodeOptions(
    7000,   # max_active
    200,    # min_active
    17.,    # beam
    8.,     # lattice_beam
    1.0,    # acoustic_scale
    0.98,   # ctc_blank_skip_threshold
    10      # nbest
)

# NOTE(review): absolute, machine-specific path. Point this at a local
# ``lang_test`` directory containing ``TLG.fst`` and ``words.txt`` before
# running elsewhere.
model_path = '/oak/stanford/groups/shenoy/stfan/code/nptlrig2/LanguageModelDecoder/examples/handwriting/s0/3gram_no_prune/data/lang_test'

# Only the decoding graph and the word table are supplied; the remaining
# resource slots are passed as empty strings (presumably optional resources
# that this example does not use -- TODO confirm against lm_decoder).
decode_resource = lm_decoder.DecodeResource(
    os.path.join(model_path, 'TLG.fst'),
    "",
    "",
    os.path.join(model_path, 'words.txt'),
    ""
)
decoder = lm_decoder.BrainSpeechDecoder(decode_resource, decode_opts)

# Load handwriting RNN logits output.
# The indexing below requires a 3-D array with at least 32 entries on the
# last axis (assumed layout: batch, time, classes -- TODO confirm).
logits = np.load('test_logits.npy')
print(logits.shape)

# Rearrange logits to Kaldi character order:
# [ctc_blank, ">", ",", "?", "~", "'", a, b, ..., z]
# Source index 31 is moved to the front, followed by source indices
# 26, 27, 30, 29, 28, and then the 26 entries at indices 0..25.
char_range = list(range(26))
logits = logits[:, :, [31] + [26, 27, 30, 29, 28] + char_range]

# Decode each entry along the first axis independently.
for entry_logits in logits:
    lm_decoder.DecodeNumpy(decoder, entry_logits)
    decoder.FinishDecoding()
    # Fetch the hypotheses once rather than calling result() twice.
    hypotheses = decoder.result()
    if hypotheses:
        # Print only the first hypothesis in the returned list.
        print(hypotheses[0].sentence)
    else:
        print("No result")
    # Reset the decoder before feeding it the next entry.
    decoder.Reset()
