From de1a1d501a1c09f6963edb6f2aeb505d1f0f17c1 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 3 Feb 2022 21:42:37 -0700 Subject: [PATCH] Move audio injectors into their own file --- codes/trainer/injectors/audio_injectors.py | 112 +++++++++++++++++++++ codes/trainer/injectors/base_injectors.py | 110 +------------------- 2 files changed, 113 insertions(+), 109 deletions(-) create mode 100644 codes/trainer/injectors/audio_injectors.py diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py new file mode 100644 index 00000000..8ce5209c --- /dev/null +++ b/codes/trainer/injectors/audio_injectors.py @@ -0,0 +1,112 @@ +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +from trainer.inject import Injector +from utils.util import opt_get + + +class MelSpectrogramInjector(Injector): + def __init__(self, opt, env): + super().__init__(opt, env) + from models.tacotron2.layers import TacotronSTFT + # These are the default tacotron values for the MEL spectrogram. + filter_length = opt_get(opt, ['filter_length'], 1024) + hop_length = opt_get(opt, ['hop_length'], 256) + win_length = opt_get(opt, ['win_length'], 1024) + n_mel_channels = opt_get(opt, ['n_mel_channels'], 80) + mel_fmin = opt_get(opt, ['mel_fmin'], 0) + mel_fmax = opt_get(opt, ['mel_fmax'], 8000) + sampling_rate = opt_get(opt, ['sampling_rate'], 22050) + self.stft = TacotronSTFT(filter_length, hop_length, win_length, n_mel_channels, sampling_rate, mel_fmin, mel_fmax) + + def forward(self, state): + inp = state[self.input] + if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio) + inp = inp.squeeze(1) + assert len(inp.shape) == 2 + self.stft = self.stft.to(inp.device) + return {self.output: self.stft.mel_spectrogram(inp)} + + +class TorchMelSpectrogramInjector(Injector): + def __init__(self, opt, env): + super().__init__(opt, env) + # These are the default tacotron values for the MEL spectrogram. + self.filter_length = opt_get(opt, ['filter_length'], 1024) + self.hop_length = opt_get(opt, ['hop_length'], 256) + self.win_length = opt_get(opt, ['win_length'], 1024) + self.n_mel_channels = opt_get(opt, ['n_mel_channels'], 80) + self.mel_fmin = opt_get(opt, ['mel_fmin'], 0) + self.mel_fmax = opt_get(opt, ['mel_fmax'], 8000) + self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050) + norm = opt_get(opt, ['normalize'], False) + self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length, + win_length=self.win_length, power=2, normalized=norm, + sample_rate=self.sampling_rate, f_min=self.mel_fmin, + f_max=self.mel_fmax, n_mels=self.n_mel_channels, + norm="slaney") + self.mel_norm_file = opt_get(opt, ['mel_norm_file'], None) + if self.mel_norm_file is not None: + self.mel_norms = torch.load(self.mel_norm_file) + else: + self.mel_norms = None + + def forward(self, state): + inp = state[self.input] + if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio) + inp = inp.squeeze(1) + assert len(inp.shape) == 2 + self.mel_stft = self.mel_stft.to(inp.device) + mel = self.mel_stft(inp) + # Perform dynamic range compression + mel = torch.log(torch.clamp(mel, min=1e-5)) + if self.mel_norms is not None: + self.mel_norms = self.mel_norms.to(mel.device) + mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1) + return {self.output: mel} + + +class RandomAudioCropInjector(Injector): + def __init__(self, opt, env): + super().__init__(opt, env) + self.crop_sz = opt['crop_size'] + + def forward(self, state): + inp = state[self.input] + len = inp.shape[-1] + margin = len - self.crop_sz + start = random.randint(0, margin) + return {self.output: inp[:, :, start:start+self.crop_sz]} + + +class AudioClipInjector(Injector): + def __init__(self, opt, env): + super().__init__(opt, env) + self.clip_size = opt['clip_size'] + self.ctc_codes = opt['ctc_codes_key'] + self.output_ctc = opt['ctc_out_key'] + + def forward(self, state): + inp = state[self.input] + ctc = state[self.ctc_codes] + len = inp.shape[-1] + if len > self.clip_size: + proportion_inp_remaining = self.clip_size/len + inp = inp[:, :, :self.clip_size] + ctc = ctc[:,:int(proportion_inp_remaining*ctc.shape[-1])] + return {self.output: inp, self.output_ctc: ctc} + + +class AudioResampleInjector(Injector): + def __init__(self, opt, env): + super().__init__(opt, env) + self.input_sr = opt['input_sample_rate'] + self.output_sr = opt['output_sample_rate'] + + def forward(self, state): + inp = state[self.input] + return {self.output: torchaudio.functional.resample(inp, self.input_sr, self.output_sr)} diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py index 5acb61f6..15cd211c 100644 --- a/codes/trainer/injectors/base_injectors.py +++ b/codes/trainer/injectors/base_injectors.py @@ -1,14 +1,11 @@ import random import torch.nn -import torchaudio.functional from kornia.augmentation import RandomResizedCrop from torch.cuda.amp import autocast -from data.audio.unsupervised_audio_dataset import load_audio from trainer.inject import Injector, create_injector from trainer.losses import extract_params_from_state -from utils.audio import plot_spectrogram from utils.util import opt_get from utils.weight_scheduler import get_scheduler_for_opt @@ -567,109 +564,4 @@ class DenormalizeInjector(Injector): def forward(self, state): inp = state[self.input] out = inp * self.scale + self.shift - return {self.output: out} - - -class MelSpectrogramInjector(Injector): - def __init__(self, opt, env): - super().__init__(opt, env) - from models.tacotron2.layers import TacotronSTFT - # These are the default tacotron values for the MEL spectrogram. - filter_length = opt_get(opt, ['filter_length'], 1024) - hop_length = opt_get(opt, ['hop_length'], 256) - win_length = opt_get(opt, ['win_length'], 1024) - n_mel_channels = opt_get(opt, ['n_mel_channels'], 80) - mel_fmin = opt_get(opt, ['mel_fmin'], 0) - mel_fmax = opt_get(opt, ['mel_fmax'], 8000) - sampling_rate = opt_get(opt, ['sampling_rate'], 22050) - self.stft = TacotronSTFT(filter_length, hop_length, win_length, n_mel_channels, sampling_rate, mel_fmin, mel_fmax) - - def forward(self, state): - inp = state[self.input] - if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio) - inp = inp.squeeze(1) - assert len(inp.shape) == 2 - self.stft = self.stft.to(inp.device) - return {self.output: self.stft.mel_spectrogram(inp)} - - -class TorchMelSpectrogramInjector(Injector): - def __init__(self, opt, env): - super().__init__(opt, env) - # These are the default tacotron values for the MEL spectrogram. - self.filter_length = opt_get(opt, ['filter_length'], 1024) - self.hop_length = opt_get(opt, ['hop_length'], 256) - self.win_length = opt_get(opt, ['win_length'], 1024) - self.n_mel_channels = opt_get(opt, ['n_mel_channels'], 80) - self.mel_fmin = opt_get(opt, ['mel_fmin'], 0) - self.mel_fmax = opt_get(opt, ['mel_fmax'], 8000) - self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050) - norm = opt_get(opt, ['normalize'], False) - self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length, - win_length=self.win_length, power=2, normalized=norm, - sample_rate=self.sampling_rate, f_min=self.mel_fmin, - f_max=self.mel_fmax, n_mels=self.n_mel_channels, - norm="slaney") - self.mel_norm_file = opt_get(opt, ['mel_norm_file'], None) - if self.mel_norm_file is not None: - self.mel_norms = torch.load(self.mel_norm_file) - else: - self.mel_norms = None - - def forward(self, state): - inp = state[self.input] - if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio) - inp = inp.squeeze(1) - assert len(inp.shape) == 2 - self.mel_stft = self.mel_stft.to(inp.device) - mel = self.mel_stft(inp) - # Perform dynamic range compression - mel = torch.log(torch.clamp(mel, min=1e-5)) - if self.mel_norms is not None: - self.mel_norms = self.mel_norms.to(mel.device) - mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1) - return {self.output: mel} - - -def test_torch_mel_injector(): - a = load_audio('D:\\data\\audio\\libritts\\train-clean-100\\19\\198\\19_198_000000_000000.wav', 22050) - inj = TorchMelSpectrogramInjector({'in': 'in', 'out': 'out', 'mel_norm_file': '../experiments/clips_mel_norms.pth'}, {}) - f = inj({'in': a.unsqueeze(0)})['out'] - plot_spectrogram(f[0]) - inj = MelSpectrogramInjector({'in': 'in', 'out': 'out'}, {}) - t = inj({'in': a.unsqueeze(0)})['out'] - plot_spectrogram(t[0]) - print('Pause') - - -class RandomAudioCropInjector(Injector): - def __init__(self, opt, env): - super().__init__(opt, env) - self.crop_sz = opt['crop_size'] - - def forward(self, state): - inp = state[self.input] - len = inp.shape[-1] - margin = len - self.crop_sz - start = random.randint(0, margin) - return {self.output: inp[:, :, start:start+self.crop_sz]} - - -class AudioResampleInjector(Injector): - def __init__(self, opt, env): - super().__init__(opt, env) - self.input_sr = opt['input_sample_rate'] - self.output_sr = opt['output_sample_rate'] - - def forward(self, state): - inp = state[self.input] - return {self.output: torchaudio.functional.resample(inp, self.input_sr, self.output_sr)} - - -def test_audio_resample_injector(): - inj = AudioResampleInjector({'in': 'x', 'out': 'y', 'input_sample_rate': 22050, 'output_sample_rate': '1'}, None) - print(inj({'x':torch.rand(10,1,40800)})['y'].shape) - - -if __name__ == '__main__': - test_torch_mel_injector() \ No newline at end of file + return {self.output: out} \ No newline at end of file