import numpy as np
import re
from g2p_en import G2p


LOGIT_PHONE_DEF = [
    'BLANK', 'SIL',  # blank and silence
    'AA', 'AE', 'AH', 'AO', 'AW',
    'AY', 'B',  'CH', 'D', 'DH',
    'EH', 'ER', 'EY', 'F', 'G',
    'HH', 'IH', 'IY', 'JH', 'K',
    'L', 'M', 'N', 'NG', 'OW',
    'OY', 'P', 'R', 'S', 'SH',
    'T', 'TH', 'UH', 'UW', 'V',
    'W', 'Y', 'Z', 'ZH'
]
SIL_DEF = ['SIL']


# Remove punctuation from text
def remove_punctuation(sentence):
    # Keep only letters, hyphens, spaces, and apostrophes
    sentence = re.sub(r'[^a-zA-Z\- \']', '', sentence)
    sentence = sentence.replace('--', '').lower()
    sentence = sentence.replace(" '", "'").lower()

    sentence = sentence.strip()
    sentence = ' '.join(sentence.split())

    return sentence


# Convert RNN logits to argmax phonemes
def logits_to_phonemes(logits):
    seq = np.argmax(logits, axis=1)
    # Collapse runs of repeated argmax indices (CTC-style greedy decoding)
    seq2 = np.array([seq[0]] + [seq[i] for i in range(1, len(seq)) if seq[i] != seq[i-1]])

    phones = []
    for i in range(len(seq2)):
        phones.append(LOGIT_PHONE_DEF[seq2[i]])

    # Remove blank and repeated phonemes
    phones = [p for p in phones if p != 'BLANK']
    phones = [phones[0]] + [phones[i] for i in range(1, len(phones)) if phones[i] != phones[i-1]]

    return phones


# Convert text to phonemes
def sentence_to_phonemes(thisTranscription, g2p_instance=None):
    if not g2p_instance:
        g2p_instance = G2p()

    # Remove punctuation
    thisTranscription = remove_punctuation(thisTranscription)

    # Convert to phonemes
    phonemes = []
    if len(thisTranscription) == 0:
        phonemes = SIL_DEF
    else:
        for p in g2p_instance(thisTranscription):
            if p == ' ':
                phonemes.append('SIL')

            p = re.sub(r'[0-9]', '', p)  # Remove stress
            if re.match(r'[A-Z]+', p):  # Only keep phonemes
                phonemes.append(p)

        # Add one SIL symbol at the end so there's one at the end of each word
        phonemes.append('SIL')

    return phonemes, thisTranscription


# Calculate WER or PER
def calculate_error_rate(r, h):
    """
    Count of word or phoneme errors via Levenshtein distance.
    Works only for iterables up to 254 elements (uint8).
    O(nm) time and space complexity.
    ----------
    Parameters:
    r : list of true words or phonemes
    h : list of predicted words or phonemes
    ----------
    Returns:
    Number of word (WER) or phoneme (PER) errors [int]
    ----------
    Examples:
    >>> calculate_error_rate("who is there".split(), "is there".split())
    1
    >>> calculate_error_rate("who is there".split(), "".split())
    3
    >>> calculate_error_rate("".split(), "who is there".split())
    3
    """
    # initialization
    d = np.zeros((len(r)+1)*(len(h)+1), dtype=np.uint8)
    d = d.reshape((len(r)+1, len(h)+1))
    for i in range(len(r)+1):
        for j in range(len(h)+1):
            if i == 0:
                d[0][j] = j
            elif j == 0:
                d[i][0] = i

    # computation
    for i in range(1, len(r)+1):
        for j in range(1, len(h)+1):
            if r[i-1] == h[j-1]:
                d[i][j] = d[i-1][j-1]
            else:
                substitution = d[i-1][j-1] + 1
                insertion    = d[i][j-1] + 1
                deletion     = d[i-1][j] + 1
                d[i][j] = min(substitution, insertion, deletion)

    return d[len(r)][len(h)]


# Calculate aggregate WER or PER
def calculate_aggregate_error_rate(r, h):

    # list setup
    err_count = []
    item_count = []
    error_rate_ind = []

    # calculate individual error rates
    for x in range(len(h)):
        r_x = r[x]
        h_x = h[x]

        n_err = calculate_error_rate(r_x, h_x)

        item_count.append(len(r_x))
        err_count.append(n_err)
        error_rate_ind.append(n_err / len(r_x))

    # Calculate aggregate error rate
    error_rate_agg = np.sum(err_count) / np.sum(item_count)

    # Calculate 95% CI by bootstrap resampling over trials
    item_count = np.array(item_count)
    err_count = np.array(err_count)
    nResamples = 10000
    resampled_error_rate = np.zeros([nResamples,])
    for n in range(nResamples):
        resampleIdx = np.random.randint(0, item_count.shape[0], [item_count.shape[0]])
        resampled_error_rate[n] = np.sum(err_count[resampleIdx]) / np.sum(item_count[resampleIdx])
    error_rate_agg_CI = np.percentile(resampled_error_rate, [2.5, 97.5])

    # Return everything as a tuple
    return (error_rate_agg, error_rate_agg_CI[0], error_rate_agg_CI[1], error_rate_ind)
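

# Minimal usage sketch with made-up reference/hypothesis lists and a toy logits
# matrix; the g2p-based sentence_to_phonemes path additionally assumes the
# g2p_en model data is available locally, so it is left out of this example.
if __name__ == '__main__':
    # Single-pair edit distance: one deleted word ("who") -> 1 error
    ref = 'who is there'.split()
    hyp = 'is there'.split()
    print('errors:', calculate_error_rate(ref, hyp))

    # Aggregate WER over two trials, with a bootstrap 95% confidence interval
    refs = [['who', 'is', 'there'], ['hello', 'world']]
    hyps = [['is', 'there'], ['hello', 'word']]
    wer, ci_low, ci_high, wer_per_trial = calculate_aggregate_error_rate(refs, hyps)
    print('aggregate WER: %.3f (95%% CI %.3f-%.3f)' % (wer, ci_low, ci_high))
    print('per-trial WER:', wer_per_trial)

    # Greedy decoding of a toy one-hot logits matrix (time steps x phones)
    toy_logits = np.zeros((5, len(LOGIT_PHONE_DEF)))
    toy_logits[0, 0] = 1.0   # BLANK
    toy_logits[1, 17] = 1.0  # HH
    toy_logits[2, 17] = 1.0  # HH (repeat, collapsed)
    toy_logits[3, 19] = 1.0  # IY
    toy_logits[4, 1] = 1.0   # SIL
    print(logits_to_phonemes(toy_logits))

    # Text cleanup applied before grapheme-to-phoneme conversion
    print(remove_punctuation("Hello, world -- it's me!"))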
