Add spec_augment injector
This commit is contained in:
parent 4cff774b0e
commit 993bd52d42
@@ -13,10 +13,10 @@ class ResBlock(nn.Module):
         super().__init__()
         self.net = nn.Sequential(
             nn.Conv1d(chan, chan, kernel_size=3, padding=1),
-            nn.BatchNorm1d(chan),
+            nn.GroupNorm(chan//8, chan),
             nn.ReLU(),
             nn.Conv1d(chan, chan, kernel_size=3, padding=1),
-            nn.BatchNorm1d(chan)
+            nn.GroupNorm(chan//8, chan)
         )

     def forward(self, x):
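Note on the BatchNorm1d -> GroupNorm swap above (not part of the diff): nn.GroupNorm accepts the same (batch, channels, time) tensors nn.BatchNorm1d does, and its statistics are computed per sample rather than across the batch, so it behaves the same at small batch sizes. A minimal sketch with illustrative sizes, assuming `chan` is divisible by 8 so `chan//8` groups of 8 channels each are formed:

import torch
from torch import nn

chan = 64                            # illustrative channel count, divisible by 8
x = torch.randn(2, chan, 100)        # (batch, channels, time), as seen inside ResBlock.net

bn = nn.BatchNorm1d(chan)            # old: per-channel stats over batch and time
gn = nn.GroupNorm(chan // 8, chan)   # new: chan//8 groups of 8 channels, batch-size independent

assert bn(x).shape == gn(x).shape == x.shape   # drop-in replacement shape-wise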
@@ -31,11 +31,13 @@ class MelEncoder(nn.Module):
             ResBlock(channels//4),
             ResBlock(channels//4),
             nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
-            nn.BatchNorm1d(channels//2),
+            nn.GroupNorm(channels//16, channels//2),
             nn.ReLU(),
             ResBlock(channels//2),
             ResBlock(channels//2),
             nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
+            nn.GroupNorm(channels//8, channels),
+            nn.ReLU(),
             ResBlock(channels),
             ResBlock(channels)
         )
@@ -48,7 +50,7 @@ class GptAsrHf(nn.Module):
     NUMBER_SYMBOLS = len(symbols)
     NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+1

-    def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_frames=1000):
+    def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_frames=1000, checkpointing=True):
         super().__init__()
         self.max_mel_frames = max_mel_frames // 4 # Mel frames are reduced by a factor of 4 during encoding.
         self.max_symbols_per_phrase = max_symbols_per_phrase
@@ -64,7 +66,9 @@ class GptAsrHf(nn.Module):
                                         n_ctx=seq_length,
                                         n_embd=model_dim,
                                         n_layer=layers,
-                                        n_head=heads))
+                                        n_head=heads,
+                                        gradient_checkpointing=checkpointing,
+                                        use_cache=not checkpointing))
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)

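Note on the checkpointing arguments above (not part of the diff): HuggingFace gradient checkpointing saves activation memory by re-running each block's forward pass during backward, which is why the KV cache is turned off alongside it (use_cache=not checkpointing); cached keys/values cannot be reused across re-computed forwards. A minimal sketch, assuming a transformers version whose GPT2Config still accepts gradient_checkpointing and n_ctx as config fields; the vocab and sequence sizes are placeholders, not the commit's values:

from transformers import GPT2Config, GPT2Model

checkpointing = True
cfg = GPT2Config(vocab_size=256,        # placeholder; the model passes its text-token count here
                 n_positions=1024,      # placeholder for seq_length
                 n_ctx=1024,
                 n_embd=512,
                 n_layer=8,
                 n_head=8,
                 gradient_checkpointing=checkpointing,
                 use_cache=not checkpointing)   # caching conflicts with checkpointed re-computation
model = GPT2Model(cfg)                  # 8-layer GPT-2 trunk, matching the GptAsrHf defaults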
codes/trainer/injectors/spec_augment.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+# Original source: https://github.com/SeanNaren/deepspeech.pytorch/blob/master/deepspeech_pytorch/loader/sparse_image_warp.py
+# Removes the time_warp augmentation and only implements masking.
+
+import numpy as np
+import random
+import torchvision.utils
+
+from trainer.inject import Injector
+from utils.util import opt_get
+
+
+def spec_augment(mel_spectrogram, frequency_masking_para=27, time_masking_para=70, frequency_mask_num=1, time_mask_num=1):
+
+    v = mel_spectrogram.shape[1]
+    tau = mel_spectrogram.shape[2]
+
+    # Step 2 : Frequency masking
+    for i in range(frequency_mask_num):
+        f = np.random.uniform(low=0.0, high=frequency_masking_para)
+        f = int(f)
+        if v - f < 0:
+            continue
+        f0 = random.randint(0, v-f)
+        mel_spectrogram[:, f0:f0+f, :] = 0
+
+    # Step 3 : Time masking
+    for i in range(time_mask_num):
+        t = np.random.uniform(low=0.0, high=time_masking_para)
+        t = int(t)
+        if tau - t < 0:
+            continue
+        t0 = random.randint(0, tau-t)
+        mel_spectrogram[:, :, t0:t0+t] = 0
+
+    return mel_spectrogram
+
+class MelMaskInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.freq_mask_sz = opt_get(opt, ['frequency_mask_size_high'], 27)
+        self.n_freq_masks = opt_get(opt, ['frequency_mask_count'], 1)
+        self.time_mask_sz = opt_get(opt, ['time_mask_size_high'], 5)
+        self.n_time_masks = opt_get(opt, ['time_mask_count'], 3)
+
+    def forward(self, state):
+        h = state[self.input]
+        return {self.output: spec_augment(h, self.freq_mask_sz, self.time_mask_sz, self.n_freq_masks, self.n_time_masks)}
+
+def visualization_spectrogram(spec, title):
+    # Turns spec into an image and outputs it to the filesystem.
+    spec = spec.unsqueeze(dim=1)
+    # Normalize so spectrogram is easier to view.
+    spec = (spec - spec.mean()) / spec.std()
+    spec = ((spec + 1) / 2).clip(0, 1)
+    torchvision.utils.save_image(spec, f'{title}.png')
+
+if __name__ == '__main__':
+    from data.audio.unsupervised_audio_dataset import load_audio
+    from trainer.injectors.base_injectors import MelSpectrogramInjector
+    spec_maker = MelSpectrogramInjector({'in': 'audio', 'out': 'spec'}, {})
+    a = load_audio('D:\\data\\audio\\libritts\\test-clean\\61\\70970\\61_70970_000007_000001.wav', 22050).unsqueeze(0)
+    s = spec_maker({'audio': a})['spec']
+    visualization_spectrogram(s, 'original spec')
+    saug = spec_augment(s, 50, 5, 1, 3)
+    visualization_spectrogram(saug, 'modified spec')
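Usage note for the new injector (not part of the diff): a sketch of driving MelMaskInjector directly, assuming the trainer.inject.Injector base class maps the 'in'/'out' opt keys to self.input/self.output and that calling the injector dispatches to forward(), as the __main__ block's use of MelSpectrogramInjector suggests. The spectrogram shape is illustrative only; in training this would normally be declared in the experiment's injector configuration instead.

import torch

# Hypothetical standalone use of the class defined above.
injector = MelMaskInjector({'in': 'spec', 'out': 'spec_masked',
                            'frequency_mask_size_high': 27, 'frequency_mask_count': 1,
                            'time_mask_size_high': 5, 'time_mask_count': 3}, {})

state = {'spec': torch.randn(1, 80, 400)}   # (batch, n_mel, frames); made-up sizes
state.update(injector(state))               # adds 'spec_masked' with random frequency/time bands zeroed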