diff --git a/codes/models/gpt_voice/gpt_asr_hf.py b/codes/models/gpt_voice/gpt_asr_hf.py
index 9b487441..41121d09 100644
--- a/codes/models/gpt_voice/gpt_asr_hf.py
+++ b/codes/models/gpt_voice/gpt_asr_hf.py
@@ -13,10 +13,10 @@ class ResBlock(nn.Module):
         super().__init__()
         self.net = nn.Sequential(
             nn.Conv1d(chan, chan, kernel_size=3, padding=1),
-            nn.BatchNorm1d(chan),
+            nn.GroupNorm(chan//8, chan),
             nn.ReLU(),
             nn.Conv1d(chan, chan, kernel_size=3, padding=1),
-            nn.BatchNorm1d(chan)
+            nn.GroupNorm(chan//8, chan)
         )
 
     def forward(self, x):
@@ -31,11 +31,13 @@ class MelEncoder(nn.Module):
             ResBlock(channels//4),
             ResBlock(channels//4),
             nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
-            nn.BatchNorm1d(channels//2),
+            nn.GroupNorm(channels//16, channels//2),
             nn.ReLU(),
             ResBlock(channels//2),
             ResBlock(channels//2),
             nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
+            nn.GroupNorm(channels//8, channels),
+            nn.ReLU(),
             ResBlock(channels),
             ResBlock(channels)
         )
@@ -48,7 +50,7 @@ class GptAsrHf(nn.Module):
     NUMBER_SYMBOLS = len(symbols)
     NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+1
 
-    def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_frames=1000):
+    def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_frames=1000, checkpointing=True):
         super().__init__()
         self.max_mel_frames = max_mel_frames // 4  # Mel frames are reduced by a factor of 4 during encoding.
         self.max_symbols_per_phrase = max_symbols_per_phrase
@@ -64,7 +66,9 @@ class GptAsrHf(nn.Module):
                                         n_ctx=seq_length,
                                         n_embd=model_dim,
                                         n_layer=layers,
-                                        n_head=heads))
+                                        n_head=heads,
+                                        gradient_checkpointing=checkpointing,
+                                        use_cache=not checkpointing))
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
diff --git a/codes/trainer/injectors/spec_augment.py b/codes/trainer/injectors/spec_augment.py
new file mode 100644
index 00000000..fa03ec71
--- /dev/null
+++ b/codes/trainer/injectors/spec_augment.py
@@ -0,0 +1,65 @@
+# Original source: https://github.com/SeanNaren/deepspeech.pytorch/blob/master/deepspeech_pytorch/loader/sparse_image_warp.py
+# Removes the time_warp augmentation and only implements masking.
+
+import numpy as np
+import random
+import torchvision.utils
+
+from trainer.inject import Injector
+from utils.util import opt_get
+
+
+def spec_augment(mel_spectrogram, frequency_masking_para=27, time_masking_para=70, frequency_mask_num=1, time_mask_num=1):
+
+    v = mel_spectrogram.shape[1]    # Number of mel frequency channels.
+    tau = mel_spectrogram.shape[2]  # Number of time frames.
+
+    # Frequency masking: zero out frequency_mask_num random bands of mel channels.
+    for i in range(frequency_mask_num):
+        f = np.random.uniform(low=0.0, high=frequency_masking_para)
+        f = int(f)
+        if v - f < 0:
+            continue
+        f0 = random.randint(0, v-f)
+        mel_spectrogram[:, f0:f0+f, :] = 0
+
+    # Time masking: zero out time_mask_num random spans of time frames.
+    for i in range(time_mask_num):
+        t = np.random.uniform(low=0.0, high=time_masking_para)
+        t = int(t)
+        if tau - t < 0:
+            continue
+        t0 = random.randint(0, tau-t)
+        mel_spectrogram[:, :, t0:t0+t] = 0
+
+    return mel_spectrogram
+
+class MelMaskInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.freq_mask_sz = opt_get(opt, ['frequency_mask_size_high'], 27)
+        self.n_freq_masks = opt_get(opt, ['frequency_mask_count'], 1)
+        self.time_mask_sz = opt_get(opt, ['time_mask_size_high'], 5)
+        self.n_time_masks = opt_get(opt, ['time_mask_count'], 3)
+
+    def forward(self, state):
+        h = state[self.input]
+        return {self.output: spec_augment(h, self.freq_mask_sz, self.time_mask_sz, self.n_freq_masks, self.n_time_masks)}
+
+def visualization_spectrogram(spec, title):
+    # Turns spec into an image and outputs it to the filesystem.
+    spec = spec.unsqueeze(dim=1)
+    # Normalize so the spectrogram is easier to view.
+    spec = (spec - spec.mean()) / spec.std()
+    spec = ((spec + 1) / 2).clip(0, 1)
+    torchvision.utils.save_image(spec, f'{title}.png')
+
+if __name__ == '__main__':
+    from data.audio.unsupervised_audio_dataset import load_audio
+    from trainer.injectors.base_injectors import MelSpectrogramInjector
+    spec_maker = MelSpectrogramInjector({'in': 'audio', 'out': 'spec'}, {})
+    a = load_audio('D:\\data\\audio\\libritts\\test-clean\\61\\70970\\61_70970_000007_000001.wav', 22050).unsqueeze(0)
+    s = spec_maker({'audio': a})['spec']
+    visualization_spectrogram(s, 'original spec')
+    saug = spec_augment(s, 50, 5, 1, 3)
+    visualization_spectrogram(saug, 'modified spec')
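
For reference, a minimal sketch of how the new injector might be driven standalone, outside a full training configuration. It assumes the base Injector maps the 'in'/'out' opt keys to self.input/self.output and is callable like a torch module (the same convention the __main__ block above relies on for MelSpectrogramInjector); the (batch, mel_channels, frames) shape and the opt values are illustrative, not taken from this diff.

    import torch
    from trainer.injectors.spec_augment import MelMaskInjector

    # Hypothetical standalone usage; the opt keys mirror the opt_get defaults above.
    inj = MelMaskInjector({'in': 'spec', 'out': 'masked_spec',
                           'frequency_mask_size_high': 27, 'frequency_mask_count': 1,
                           'time_mask_size_high': 5, 'time_mask_count': 3}, {})
    spec = torch.randn(4, 80, 400)  # Assumed shape: (batch, mel_channels, time_frames).
    masked = inj({'spec': spec})['masked_spec']  # Same shape, with zeroed bands/spans.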