DL-Art-School/codes/trainer/injectors/spec_augment.py

66 lines
2.5 KiB
Python
Raw Normal View History

2021-11-02 00:43:11 +00:00
# Original source: https://github.com/SeanNaren/deepspeech.pytorch/blob/master/deepspeech_pytorch/loader/sparse_image_warp.py
# This port drops the time_warp augmentation and implements only frequency and time masking.
import numpy as np
import random
import torchvision.utils
from trainer.inject import Injector
from utils.util import opt_get
def spec_augment(mel_spectrogram, frequency_masking_para=27, time_masking_para=70, frequency_mask_num=1,
                 time_mask_num=1, inplace=False):
    """Apply SpecAugment-style frequency and time masking to a rank-3 spectrogram.

    Only the masking half of SpecAugment is implemented; time warping is omitted.
    Every mask spans all of dim 0 (presumably the batch dimension), so all items in
    dim 0 receive identical masks.

    Args:
        mel_spectrogram: rank-3 torch tensor or numpy array. Dim 1 is masked as the
            frequency axis and dim 2 as the time axis.
        frequency_masking_para: exclusive upper bound on each frequency mask's width;
            widths are drawn as ``int(np.random.uniform(0, frequency_masking_para))``.
        time_masking_para: same as above, for time masks.
        frequency_mask_num: number of frequency masks to attempt.
        time_mask_num: number of time masks to attempt.
        inplace: if False (default), mask a copy so the caller's data is left
            untouched; if True, write zeros directly into ``mel_spectrogram``.

    Returns:
        The masked spectrogram (``mel_spectrogram`` itself when ``inplace=True``).
    """
    if not inplace:
        # Masking below writes through slices. Without a copy the caller's tensor
        # (e.g. an entry in a trainer state dict) would be modified behind its back,
        # and a leaf tensor that requires grad would reject the in-place write.
        # torch tensors expose .clone(); numpy arrays expose .copy().
        if hasattr(mel_spectrogram, 'clone'):
            mel_spectrogram = mel_spectrogram.clone()
        else:
            mel_spectrogram = mel_spectrogram.copy()

    num_freq_bins = mel_spectrogram.shape[1]
    num_frames = mel_spectrogram.shape[2]

    # Frequency masking: zero `f` consecutive bins starting at a random `f0`.
    for _ in range(frequency_mask_num):
        f = int(np.random.uniform(low=0.0, high=frequency_masking_para))
        if num_freq_bins - f < 0:
            continue  # Mask wider than the axis: skip it rather than clamp.
        f0 = random.randint(0, num_freq_bins - f)  # inclusive bounds, so f0 + f <= num_freq_bins
        mel_spectrogram[:, f0:f0 + f, :] = 0

    # Time masking: zero `t` consecutive frames starting at a random `t0`.
    for _ in range(time_mask_num):
        t = int(np.random.uniform(low=0.0, high=time_masking_para))
        if num_frames - t < 0:
            continue  # Mask wider than the axis: skip it rather than clamp.
        t0 = random.randint(0, num_frames - t)
        mel_spectrogram[:, :, t0:t0 + t] = 0

    return mel_spectrogram
class MelMaskInjector(Injector):
    """Injector that runs ``spec_augment`` masking on a spectrogram from the state dict.

    ``forward`` reads ``state[self.input]`` and returns ``{self.output: masked}``.
    ``self.input`` / ``self.output`` are presumably populated by the ``Injector`` base
    class from ``opt`` (TODO confirm).

    Options read with ``opt_get`` (default in parentheses):
        frequency_mask_size_high (27): upper bound on each frequency mask's width.
        frequency_mask_count (1): number of frequency masks.
        time_mask_size_high (5): upper bound on each time mask's width.
        time_mask_count (3): number of time masks.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)

        def read_option(key, fallback):
            # Single-key lookup into the options dict, with a fallback default.
            return opt_get(opt, [key], fallback)

        self.freq_mask_sz = read_option('frequency_mask_size_high', 27)
        self.n_freq_masks = read_option('frequency_mask_count', 1)
        self.time_mask_sz = read_option('time_mask_size_high', 5)
        self.n_time_masks = read_option('time_mask_count', 3)

    def forward(self, state):
        spectrogram = state[self.input]
        masked = spec_augment(
            spectrogram,
            frequency_masking_para=self.freq_mask_sz,
            time_masking_para=self.time_mask_sz,
            frequency_mask_num=self.n_freq_masks,
            time_mask_num=self.n_time_masks,
        )
        return {self.output: masked}
def visualization_spectrogram(spec, title):
    """Save a spectrogram tensor as a grayscale image at ``f'{title}.png'``.

    Values are standardized over the whole tensor (zero mean, unit std). The range
    -1..+1 std is then mapped onto 0..1, and anything outside that is clipped.

    Args:
        spec: torch tensor, presumably shaped (batch, freq, time). A channel dim is
            inserted at dim 1 so ``torchvision.utils.save_image`` sees one
            single-channel image per item along dim 0.
        title: output path without the ``.png`` extension.
    """
    # (B, F, T) -> (B, 1, F, T): one single-channel image per item.
    spec = spec.unsqueeze(dim=1)
    # Normalize so the spectrogram is easier to view. A constant input gives
    # std == 0 and a single-element input gives std == NaN; either would turn the
    # whole image into NaN, so fall back to a unit divisor. `not (std > 0)` is
    # True for both 0 and NaN.
    std = float(spec.std())
    if not (std > 0):
        std = 1.0
    spec = (spec - spec.mean()) / std
    spec = ((spec + 1) / 2).clip(0, 1)
    torchvision.utils.save_image(spec, f'{title}.png')
if __name__ == '__main__':
    # Manual smoke test: render a clip's spectrogram before and after masking.
    # Usage: python spec_augment.py [path/to/clip.wav]
    import sys

    from data.audio.unsupervised_audio_dataset import load_audio
    from trainer.injectors.base_injectors import MelSpectrogramInjector

    # The default is a Windows-local LibriTTS path; pass a path on the command line to override it.
    default_clip = 'D:\\data\\audio\\libritts\\test-clean\\61\\70970\\61_70970_000007_000001.wav'
    clip_path = sys.argv[1] if len(sys.argv) > 1 else default_clip

    spec_maker = MelSpectrogramInjector({'in': 'audio', 'out': 'spec'}, {})
    # 22050 is presumably the target sample rate; unsqueeze(0) adds a leading (presumably batch) dim.
    a = load_audio(clip_path, 22050).unsqueeze(0)
    s = spec_maker({'audio': a})['spec']
    visualization_spectrogram(s, 'original spec')
    saug = spec_augment(s, 50, 5, 1, 3)
    visualization_spectrogram(saug, 'modified spec')