311 lines
8.6 KiB
Python
311 lines
8.6 KiB
Python
![]() |
import sys
|
||
|
import random
|
||
|
import math
|
||
|
|
||
|
import torchaudio
|
||
|
import torch
|
||
|
torchaudio.set_audio_backend("sox_io")
|
||
|
|
||
|
|
||
|
def db2amp(db):
|
||
|
return pow(10, db / 20)
|
||
|
|
||
|
def amp2db(amp):
|
||
|
return 20 * math.log10(amp)
|
||
|
|
||
|
def make_poly_distortion(conf):
|
||
|
"""Generate a db-domain ploynomial distortion function
|
||
|
|
||
|
f(x) = a * x^m * (1-x)^n + x
|
||
|
|
||
|
Args:
|
||
|
conf: a dict {'a': #int, 'm': #int, 'n': #int}
|
||
|
|
||
|
Returns:
|
||
|
The ploynomial function, which could be applied on
|
||
|
a float amplitude value
|
||
|
"""
|
||
|
a = conf['a']
|
||
|
m = conf['m']
|
||
|
n = conf['n']
|
||
|
|
||
|
def poly_distortion(x):
|
||
|
abs_x = abs(x)
|
||
|
if abs_x < 0.000001:
|
||
|
x = x
|
||
|
else:
|
||
|
db_norm = amp2db(abs_x) / 100 + 1
|
||
|
if db_norm < 0:
|
||
|
db_norm = 0
|
||
|
db_norm = a * pow(db_norm, m) * pow((1 - db_norm), n) + db_norm
|
||
|
if db_norm > 1:
|
||
|
db_norm = 1
|
||
|
db = (db_norm - 1) * 100
|
||
|
amp = db2amp(db)
|
||
|
if amp >= 0.9997:
|
||
|
amp = 0.9997
|
||
|
if x > 0:
|
||
|
x = amp
|
||
|
else:
|
||
|
x = -amp
|
||
|
return x
|
||
|
return poly_distortion
|
||
|
|
||
|
def make_quad_distortion():
|
||
|
return make_poly_distortion({'a' : 1, 'm' : 1, 'n' : 1})
|
||
|
|
||
|
# the amplitude are set to max for all non-zero point
|
||
|
def make_max_distortion(conf):
|
||
|
"""Generate a max distortion function
|
||
|
|
||
|
Args:
|
||
|
conf: a dict {'max_db': float }
|
||
|
'max_db': the maxium value.
|
||
|
|
||
|
Returns:
|
||
|
The max function, which could be applied on
|
||
|
a float amplitude value
|
||
|
"""
|
||
|
max_db = conf['max_db']
|
||
|
if max_db:
|
||
|
max_amp = db2amp(max_db) # < 0.997
|
||
|
else:
|
||
|
max_amp = 0.997
|
||
|
|
||
|
def max_distortion(x):
|
||
|
if x > 0:
|
||
|
x = max_amp
|
||
|
elif x < 0:
|
||
|
x = -max_amp
|
||
|
else:
|
||
|
x = 0.0
|
||
|
return x
|
||
|
return max_distortion
|
||
|
|
||
|
|
||
|
|
||
|
def make_amp_mask(db_mask=None):
|
||
|
"""Get a amplitude domain mask from db domain mask
|
||
|
|
||
|
Args:
|
||
|
db_mask: Optional. A list of tuple. if None, using default value.
|
||
|
|
||
|
Returns:
|
||
|
A list of tuple. The amplitude domain mask
|
||
|
"""
|
||
|
if db_mask is None:
|
||
|
db_mask = [(-110, -95), (-90, -80), (-65, -60), (-50, -30), (-15, 0)]
|
||
|
amp_mask = [(db2amp(db[0]), db2amp(db[1])) for db in db_mask]
|
||
|
return amp_mask
|
||
|
|
||
|
default_mask = make_amp_mask()
|
||
|
|
||
|
|
||
|
def generate_amp_mask(mask_num):
|
||
|
"""Generate amplitude domain mask randomly in [-100db, 0db]
|
||
|
|
||
|
Args:
|
||
|
mask_num: the slot number of the mask
|
||
|
|
||
|
Returns:
|
||
|
A list of tuple. each tuple defines a slot.
|
||
|
e.g. [(-100, -80), (-65, -60), (-50, -30), (-15, 0)]
|
||
|
for #mask_num = 4
|
||
|
"""
|
||
|
a = [0] * 2 * mask_num
|
||
|
a[0] = 0
|
||
|
m = []
|
||
|
for i in range(1, 2 * mask_num):
|
||
|
a[i] = a[i - 1] + random.uniform(0.5, 1)
|
||
|
max_val = a[2 * mask_num - 1]
|
||
|
for i in range(0, mask_num):
|
||
|
l = ((a[2 * i] - max_val) / max_val) * 100
|
||
|
r = ((a[2 * i + 1] - max_val) / max_val) * 100
|
||
|
m.append((l, r))
|
||
|
return make_amp_mask(m)
|
||
|
|
||
|
|
||
|
def make_fence_distortion(conf):
|
||
|
"""Generate a fence distortion function
|
||
|
|
||
|
In this fence-like shape function, the values in mask slots are
|
||
|
set to maxium, while the values not in mask slots are set to 0.
|
||
|
Use seperated masks for Positive and negetive amplitude.
|
||
|
|
||
|
Args:
|
||
|
conf: a dict {'mask_number': int,'max_db': float }
|
||
|
'mask_number': the slot number in mask.
|
||
|
'max_db': the maxium value.
|
||
|
|
||
|
Returns:
|
||
|
The fence function, which could be applied on
|
||
|
a float amplitude value
|
||
|
"""
|
||
|
mask_number = conf['mask_number']
|
||
|
max_db = conf['max_db']
|
||
|
max_amp = db2amp(max_db) # 0.997
|
||
|
if mask_number <= 0 :
|
||
|
positive_mask = default_mask
|
||
|
negative_mask = make_amp_mask([(-50, 0)])
|
||
|
else:
|
||
|
positive_mask = generate_amp_mask(mask_number)
|
||
|
negative_mask = generate_amp_mask(mask_number)
|
||
|
|
||
|
def fence_distortion(x):
|
||
|
is_in_mask = False
|
||
|
if x > 0:
|
||
|
for mask in positive_mask:
|
||
|
if x >= mask[0] and x <= mask[1]:
|
||
|
is_in_mask = True
|
||
|
return max_amp
|
||
|
if not is_in_mask:
|
||
|
return 0.0
|
||
|
elif x < 0:
|
||
|
abs_x = abs(x)
|
||
|
for mask in negative_mask:
|
||
|
if abs_x >= mask[0] and abs_x <= mask[1]:
|
||
|
is_in_mask = True
|
||
|
return max_amp
|
||
|
if not is_in_mask:
|
||
|
return 0.0
|
||
|
return x
|
||
|
|
||
|
return fence_distortion
|
||
|
|
||
|
#
|
||
|
def make_jag_distortion(conf):
|
||
|
"""Generate a jag distortion function
|
||
|
|
||
|
In this jag-like shape function, the values in mask slots are
|
||
|
not changed, while the values not in mask slots are set to 0.
|
||
|
Use seperated masks for Positive and negetive amplitude.
|
||
|
|
||
|
Args:
|
||
|
conf: a dict {'mask_number': #int}
|
||
|
'mask_number': the slot number in mask.
|
||
|
|
||
|
Returns:
|
||
|
The jag function,which could be applied on
|
||
|
a float amplitude value
|
||
|
"""
|
||
|
mask_number = conf['mask_number']
|
||
|
if mask_number <= 0 :
|
||
|
positive_mask = default_mask
|
||
|
negative_mask = make_amp_mask([(-50, 0)])
|
||
|
else:
|
||
|
positive_mask = generate_amp_mask(mask_number)
|
||
|
negative_mask = generate_amp_mask(mask_number)
|
||
|
|
||
|
def jag_distortion(x):
|
||
|
is_in_mask = False
|
||
|
if x > 0:
|
||
|
for mask in positive_mask:
|
||
|
if x >= mask[0] and x <= mask[1]:
|
||
|
is_in_mask = True
|
||
|
return x
|
||
|
if not is_in_mask:
|
||
|
return 0.0
|
||
|
elif x < 0:
|
||
|
abs_x = abs(x)
|
||
|
for mask in negative_mask:
|
||
|
if abs_x >= mask[0] and abs_x <= mask[1]:
|
||
|
is_in_mask = True
|
||
|
return x
|
||
|
if not is_in_mask:
|
||
|
return 0.0
|
||
|
return x
|
||
|
|
||
|
return jag_distortion
|
||
|
|
||
|
# gaining 20db means amp = amp * 10
|
||
|
# gaining -20db means amp = amp / 10
|
||
|
def make_gain_db(conf):
|
||
|
"""Generate a db domain gain function
|
||
|
|
||
|
Args:
|
||
|
conf: a dict {'db': #float}
|
||
|
'db': the gaining value
|
||
|
|
||
|
Returns:
|
||
|
The db gain function, which could be applied on
|
||
|
a float amplitude value
|
||
|
"""
|
||
|
db = conf['db']
|
||
|
|
||
|
def gain_db(x):
|
||
|
return min(0.997, x * pow(10, db / 20))
|
||
|
|
||
|
return gain_db
|
||
|
|
||
|
|
||
|
def distort(x, func, rate=0.8):
|
||
|
"""Distort a waveform in sample point level
|
||
|
|
||
|
Args:
|
||
|
x: the origin wavefrom
|
||
|
func: the distort function
|
||
|
rate: sample point-level distort probability
|
||
|
|
||
|
Returns:
|
||
|
the distorted waveform
|
||
|
"""
|
||
|
for i in range(0, x.shape[1]):
|
||
|
a = random.uniform(0, 1)
|
||
|
if a < rate:
|
||
|
x[0][i] = func(float(x[0][i]))
|
||
|
return x
|
||
|
|
||
|
def distort_chain(x, funcs, rate=0.8):
|
||
|
for i in range(0, x.shape[1]):
|
||
|
a = random.uniform(0, 1)
|
||
|
if a < rate:
|
||
|
for func in funcs:
|
||
|
x[0][i] = func(float(x[0][i]))
|
||
|
return x
|
||
|
|
||
|
# x is numpy
|
||
|
def distort_wav_conf(x, distort_type, distort_conf, rate=0.1):
|
||
|
if distort_type == 'gain_db':
|
||
|
gain_db = make_gain_db(distort_conf)
|
||
|
x = distort(x, gain_db)
|
||
|
elif distort_type == 'max_distortion':
|
||
|
max_distortion = make_max_distortion(distort_conf)
|
||
|
x = distort(x, max_distortion, rate=rate)
|
||
|
elif distort_type == 'fence_distortion':
|
||
|
fence_distortion = make_fence_distortion(distort_conf)
|
||
|
x = distort(x, fence_distortion, rate=rate)
|
||
|
elif distort_type == 'jag_distortion':
|
||
|
jag_distortion = make_jag_distortion(distort_conf)
|
||
|
x = distort(x, jag_distortion, rate=rate)
|
||
|
elif distort_type == 'poly_distortion':
|
||
|
poly_distortion = make_poly_distortion(distort_conf)
|
||
|
x = distort(x, poly_distortion, rate=rate)
|
||
|
elif distort_type == 'quad_distortion':
|
||
|
quad_distortion = make_quad_distortion()
|
||
|
x = distort(x, quad_distortion, rate=rate)
|
||
|
elif distort_type == 'none_distortion':
|
||
|
pass
|
||
|
else:
|
||
|
print('unsupport type')
|
||
|
return x
|
||
|
|
||
|
def distort_wav_conf_and_save(distort_type, distort_conf, rate, wav_in, wav_out):
|
||
|
x, sr = torchaudio.load(wav_in)
|
||
|
x = x.detach().numpy()
|
||
|
out = distort_wav_conf(x, distort_type, distort_conf, rate)
|
||
|
torchaudio.save(wav_out, torch.from_numpy(out), sr)
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
distort_type = sys.argv[1]
|
||
|
wav_in = sys.argv[2]
|
||
|
wav_out = sys.argv[3]
|
||
|
conf = None
|
||
|
rate = 0.1
|
||
|
if distort_type == 'new_jag_distortion':
|
||
|
conf = {'mask_number' : 4}
|
||
|
elif distort_type == 'new_fence_distortion':
|
||
|
conf = {'mask_number' : 1, 'max_db' : -30}
|
||
|
elif distort_type == 'poly_distortion':
|
||
|
conf = {'a' : 4, 'm' : 2, "n" : 2}
|
||
|
distort_wav_conf_and_save(distort_type, conf, rate, wav_in, wav_out)
|