import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

from trainer.inject import Injector
from utils.util import opt_get


class MelSpectrogramInjector(Injector):
    """Compute a MEL spectrogram from raw audio using ``TacotronSTFT``.

    Options read from ``opt`` (all optional, Tacotron defaults in parentheses):
    ``filter_length`` (1024), ``hop_length`` (256), ``win_length`` (1024),
    ``n_mel_channels`` (80), ``mel_fmin`` (0), ``mel_fmax`` (8000) and
    ``sampling_rate`` (22050).
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        # Imported lazily so the tacotron2 package is only needed when this
        # injector is actually constructed.
        from models.tacotron2.layers import TacotronSTFT

        # These are the default tacotron values for the MEL spectrogram.
        filter_length = opt_get(opt, ['filter_length'], 1024)
        hop_length = opt_get(opt, ['hop_length'], 256)
        win_length = opt_get(opt, ['win_length'], 1024)
        n_mel_channels = opt_get(opt, ['n_mel_channels'], 80)
        mel_fmin = opt_get(opt, ['mel_fmin'], 0)
        mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
        sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
        self.stft = TacotronSTFT(filter_length, hop_length, win_length,
                                 n_mel_channels, sampling_rate, mel_fmin,
                                 mel_fmax)

    def forward(self, state):
        """Return ``{self.output: mel}`` computed from ``state[self.input]``.

        A 3-D input has dimension 1 squeezed out (mono audio is assumed);
        after that the input must be 2-D.
        """
        inp = state[self.input]
        if len(inp.shape) == 3:
            # Automatically squeeze out the channels dimension if it is
            # present (assuming mono-audio).
            inp = inp.squeeze(1)
        assert len(inp.shape) == 2, \
            f"expected 2-D audio after squeezing, got shape {tuple(inp.shape)}"
        # Keep the STFT on the same device as the incoming audio.
        self.stft = self.stft.to(inp.device)
        return {self.output: self.stft.mel_spectrogram(inp)}


class TorchMelSpectrogramInjector(Injector):
    """Compute a log-MEL spectrogram using ``torchaudio``.

    Options read from ``opt`` (all optional): ``filter_length`` (1024),
    ``hop_length`` (256), ``win_length`` (1024), ``n_mel_channels`` (80),
    ``mel_fmin`` (0), ``mel_fmax`` (8000), ``sampling_rate`` (22050),
    ``normalize`` (False) and ``mel_norm_file`` (None). When
    ``mel_norm_file`` is given it is loaded with ``torch.load`` and the
    log-MEL output is divided by it.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        # These are the default tacotron values for the MEL spectrogram.
        self.filter_length = opt_get(opt, ['filter_length'], 1024)
        self.hop_length = opt_get(opt, ['hop_length'], 256)
        self.win_length = opt_get(opt, ['win_length'], 1024)
        self.n_mel_channels = opt_get(opt, ['n_mel_channels'], 80)
        self.mel_fmin = opt_get(opt, ['mel_fmin'], 0)
        self.mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
        self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
        norm = opt_get(opt, ['normalize'], False)
        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            n_fft=self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            power=2,
            normalized=norm,
            sample_rate=self.sampling_rate,
            f_min=self.mel_fmin,
            f_max=self.mel_fmax,
            n_mels=self.n_mel_channels,
            norm="slaney",
        )
        self.mel_norm_file = opt_get(opt, ['mel_norm_file'], None)
        if self.mel_norm_file is not None:
            # NOTE(security): torch.load unpickles the file, so
            # `mel_norm_file` must only ever point at a trusted file.
            self.mel_norms = torch.load(self.mel_norm_file)
        else:
            self.mel_norms = None

    def forward(self, state):
        """Return ``{self.output: log_mel}`` computed from ``state[self.input]``.

        A 3-D input has dimension 1 squeezed out (mono audio is assumed);
        after that the input must be 2-D.
        """
        inp = state[self.input]
        if len(inp.shape) == 3:
            # Automatically squeeze out the channels dimension if it is
            # present (assuming mono-audio).
            inp = inp.squeeze(1)
        assert len(inp.shape) == 2, \
            f"expected 2-D audio after squeezing, got shape {tuple(inp.shape)}"
        # Keep the transform on the same device as the incoming audio.
        self.mel_stft = self.mel_stft.to(inp.device)
        mel = self.mel_stft(inp)
        # Perform dynamic range compression; the clamp avoids log(0).
        mel = torch.log(torch.clamp(mel, min=1e-5))
        if self.mel_norms is not None:
            self.mel_norms = self.mel_norms.to(mel.device)
            # The norms are broadcast across the first and last dimensions.
            mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
        return {self.output: mel}


class RandomAudioCropInjector(Injector):
    """Crop a random window of ``opt['crop_size']`` samples from the input."""

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.crop_sz = opt['crop_size']

    def forward(self, state):
        """Return ``{self.output: crop}`` taken along the last dimension.

        The input is indexed with three dimensions, so it must be 3-D.

        Raises:
            ValueError: if the input is shorter than the crop size.
        """
        inp = state[self.input]
        total_len = inp.shape[-1]
        margin = total_len - self.crop_sz
        if margin < 0:
            # random.randint would raise an opaque "empty range" ValueError
            # here; fail with the same exception type but name the cause.
            raise ValueError(
                f"Cannot crop {self.crop_sz} samples from an input of "
                f"length {total_len}"
            )
        # randint is inclusive on both ends, so the last valid start is
        # reachable.
        start = random.randint(0, margin)
        return {self.output: inp[:, :, start:start + self.crop_sz]}


class AudioClipInjector(Injector):
    """Truncate audio to ``opt['clip_size']`` and shorten its codes to match.

    ``opt['ctc_codes_key']`` names the state entry holding the codes and
    ``opt['ctc_out_key']`` names the output entry they are written to.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.clip_size = opt['clip_size']
        self.ctc_codes = opt['ctc_codes_key']
        self.output_ctc = opt['ctc_out_key']

    def forward(self, state):
        """Return the (possibly clipped) audio and codes.

        When the last dimension of the input exceeds ``clip_size`` the audio
        is cut to ``clip_size`` and the codes are cut to the same proportion
        of their own last dimension (rounded down). Otherwise both pass
        through unchanged.
        """
        inp = state[self.input]
        ctc = state[self.ctc_codes]
        total_len = inp.shape[-1]
        if total_len > self.clip_size:
            proportion_inp_remaining = self.clip_size / total_len
            inp = inp[:, :, :self.clip_size]
            ctc = ctc[:, :int(proportion_inp_remaining * ctc.shape[-1])]
        return {self.output: inp, self.output_ctc: ctc}


class AudioResampleInjector(Injector):
    """Resample audio from ``input_sample_rate`` to ``output_sample_rate``."""

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.input_sr = opt['input_sample_rate']
        self.output_sr = opt['output_sample_rate']

    def forward(self, state):
        """Return ``{self.output: resampled}`` via torchaudio's resampler."""
        inp = state[self.input]
        return {
            self.output: torchaudio.functional.resample(
                inp, self.input_sr, self.output_sr)
        }