from math import sqrt import torch from munch import munchify from torch.autograd import Variable from torch import nn from torch.nn import functional as F from models.diffusion.unet_diffusion import UNetModel from models.tacotron2.layers import ConvNorm, LinearNorm from models.tacotron2.hparams import create_hparams from models.tacotron2.tacotron2 import Prenet, Attention, Encoder from trainer.networks import register_model from models.tacotron2.taco_utils import get_mask_from_lengths from utils.util import opt_get, checkpoint class WavDecoder(nn.Module): def __init__(self, dec_channels, K_ms=40, sample_rate=24000, dropout_probability=.1): super().__init__() self.dec_channels = dec_channels self.K = int(sample_rate * (K_ms/1000)) # 960 with the defaults self.clarifier = UNetModel(image_size=self.K, in_channels=1, model_channels=dec_channels // 4, # This is a requirement to enable to load the embedding produced by the decoder into the unet model. out_channels=2, # 2 channels: eps_pred and variance_pred num_res_blocks=2, attention_resolutions=(16,32), dims=1, dropout=.1, channel_mult=(1,1,1,2,4,8), use_raw_y_as_embedding=True) assert self.K % 64 == 0 # Otherwise the UNetModel breaks. self.pre_rnn = Prenet(self.K, [dec_channels, dec_channels]) self.attention_rnn = nn.LSTMCell(dec_channels*2, dec_channels) self.attention_layer = Attention(dec_channels, dec_channels, dec_channels) self.decoder_rnn = nn.LSTMCell(dec_channels*2, dec_channels, 1) self.linear_projection = LinearNorm(dec_channels*2, self.dec_channels) self.gate_layer = LinearNorm(self.dec_channels*2, 1, bias=True, w_init_gain='sigmoid') self.dropout_probability = dropout_probability def chunk_wav(self, wav): wavs = list(torch.split(wav, self.K, dim=-1)) # Pad the last chunk as needed. padding_needed = self.K - wavs[-1].shape[-1] if padding_needed > 0: wavs[-1] = F.pad(wavs[-1], (0,padding_needed)) wavs = torch.stack(wavs, dim=1) # wavs.shape = (b,s,K) where s=decoder sequence length return wavs, padding_needed def prepare_decoder_inputs(self, inp): # inp.shape = (b,s,K) chunked waveform. b,s,K = inp.shape first_frame = torch.zeros(b,1,K).to(inp.device) x = torch.cat([first_frame, inp[:,:-1]], dim=1) # It is now aligned for teacher forcing. return x def initialize_decoder_states(self, memory, mask): """ Initializes attention rnn states, decoder rnn states, attention weights, attention cumulative weights, attention context, stores memory and stores processed memory PARAMS ------ memory: Encoder outputs mask: Mask for padded data if training, expects None for inference """ B = memory.size(0) MAX_TIME = memory.size(1) self.attention_hidden = Variable(memory.data.new(B, self.dec_channels).zero_()) self.attention_cell = Variable(memory.data.new(B, self.dec_channels).zero_()) self.decoder_hidden = Variable(memory.data.new(B, self.dec_channels).zero_()) self.decoder_cell = Variable(memory.data.new(B, self.dec_channels).zero_()) self.attention_weights = Variable(memory.data.new(B, MAX_TIME).zero_()) self.attention_weights_cum = Variable(memory.data.new(B, MAX_TIME).zero_()) self.attention_context = Variable(memory.data.new(B, self.dec_channels).zero_()) self.memory = memory self.processed_memory = checkpoint(self.attention_layer.memory_layer, memory) self.mask = mask def teardown_states(self): self.attention_hidden = None self.attention_cell = None self.decoder_hidden = None self.decoder_cell = None self.attention_weights = None self.attention_weights_cum = None self.attention_context = None self.memory = None self.processed_memory = None def produce_context(self, decoder_input): """ Produces a context and a stop token prediction using the built-in RNN. PARAMS ------ decoder_input: prior diffusion step that has been resolved. RETURNS ------- mel_output: gate_output: gate output energies attention_weights: """ cell_input = torch.cat((decoder_input, self.attention_context), -1) self.attention_hidden, self.attention_cell = self.attention_rnn(cell_input, (self.attention_hidden, self.attention_cell)) self.attention_hidden = F.dropout(self.attention_hidden, self.dropout_probability, self.training) attention_weights_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1) self.attention_context, self.attention_weights = checkpoint(self.attention_layer, self.attention_hidden, self.memory, self.processed_memory, attention_weights_cat, self.mask) self.attention_weights_cum += self.attention_weights decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1) self.decoder_hidden, self.decoder_cell = self.decoder_rnn(decoder_input, (self.decoder_hidden, self.decoder_cell)) self.decoder_hidden = F.dropout(self.decoder_hidden, self.dropout_probability, self.training) decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1) decoder_output = checkpoint(self.linear_projection, decoder_hidden_attention_context) gate_prediction = self.gate_layer(decoder_hidden_attention_context) return decoder_output, gate_prediction, self.attention_weights def recombine(self, diffusion_eps, gate_outputs, alignments, padding_added): # (T_out, B) -> (B, T_out) alignments = torch.stack(alignments, dim=1).repeat(1, self.K, 1) # (T_out, B) -> (B, T_out) gate_outputs = torch.stack(gate_outputs, dim=1).repeat(1, self.K) b,s,_,K = diffusion_eps.shape # (B, S, 2, K) -> (B, 2, S*K) diffusion_eps = diffusion_eps.permute(0,2,1,3).reshape(b, 2, s*K) return diffusion_eps[:,:,:-padding_added], gate_outputs[:,:-padding_added], alignments[:,:-padding_added] def forward(self, wav, wav_corrected, timesteps, text_enc, memory_lengths): ''' Performs a training forward pass with the given data. :param wav: (b,n) diffused waveform tensor on the interval [-1,1] :param wav_corrected: (b,n) waveform tensor that has had one step of diffusion correction over :param text_enc: (b,e) embedding post-encoder with e=self.dec_channels ''' # Start by splitting up the provided waveforms into discrete segments. wavs, padding_added = self.chunk_wav(wav) wavs_corrected, _ = self.chunk_wav(wav_corrected) wavs_corrected = self.prepare_decoder_inputs(wavs_corrected) wavs_corrected = checkpoint(self.pre_rnn, wavs_corrected) self.initialize_decoder_states(text_enc, mask=~get_mask_from_lengths(memory_lengths)) decoder_contexts, gate_outputs, alignments = [], [], [] while len(decoder_contexts) < wavs_corrected.size(1): decoder_input = wavs_corrected[:, len(decoder_contexts)] dec_context, gate_output, attention_weights = self.produce_context(decoder_input) decoder_contexts += [dec_context.squeeze(1)] gate_outputs += [gate_output.squeeze(1)] alignments += [attention_weights] self.teardown_states() # diffusion_inputs and wavs needs to have the sequence and batch dimensions combined, and needs a channel dimension diffusion_emb = torch.stack(decoder_contexts, dim=1) b,s,c = diffusion_emb.shape diffusion_emb = diffusion_emb.reshape(b*s,c) wavs = wavs.reshape(b*s,1,self.K) diffusion_eps = self.clarifier(wavs, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K) # Recombine diffusion outputs across the sequence into a single prediction. diffusion_eps, gate_outputs, alignments = self.recombine(diffusion_eps, gate_outputs, alignments, padding_added) return diffusion_eps, gate_outputs, alignments class WaveTacotron2(nn.Module): def __init__(self, hparams): super().__init__() self.mask_padding = hparams.mask_padding self.fp16_run = hparams.fp16_run self.n_mel_channels = hparams.n_mel_channels self.n_frames_per_step = hparams.n_frames_per_step self.embedding = nn.Embedding( hparams.n_symbols, hparams.symbols_embedding_dim) std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) val = sqrt(3.0) * std # uniform bounds for std self.embedding.weight.data.uniform_(-val, val) self.encoder = Encoder(hparams) self.decoder = WavDecoder(hparams.encoder_embedding_dim) def parse_output(self, outputs, output_lengths=None): if self.mask_padding and output_lengths is not None: mask_fill = outputs[0].shape[-1] mask = ~get_mask_from_lengths(output_lengths, mask_fill) mask = mask.expand(mask.size(0), 2, mask.size(1)) outputs[0].data.masked_fill_(mask, 0.0) outputs[0] = outputs[0].unsqueeze(1) # Re-add channel dimension. outputs[1].data.masked_fill_(mask[:,0], 1e3) # gate energies return outputs def forward(self, wavs_diffused, wavs_corrected, timesteps, text_inputs, text_lengths, output_lengths): # Squeeze the channel dimension out of the input wavs - we only handle single-channel audio here. wavs_diffused = wavs_diffused.squeeze(dim=1) wavs_corrected = wavs_corrected.squeeze(dim=1) text_lengths, output_lengths = text_lengths.data, output_lengths.data embedded_inputs = self.embedding(text_inputs).transpose(1, 2) encoder_outputs = checkpoint(self.encoder, embedded_inputs, text_lengths) eps_pred, gate_outputs, alignments = self.decoder( wavs_diffused, wavs_corrected, timesteps, encoder_outputs, memory_lengths=text_lengths) return self.parse_output([eps_pred, gate_outputs, alignments], output_lengths) @register_model def register_diffusion_wavetron(opt_net, opt): hparams = create_hparams() hparams.update(opt_net) hparams = munchify(hparams) return WaveTacotron2(hparams) if __name__ == '__main__': tron = register_diffusion_wavetron({}, {}) out = tron(wavs_diffused=torch.randn(2, 1, 22000), wavs_corrected=torch.randn(2, 1, 22000), timesteps=torch.LongTensor([555, 543]), text_inputs=torch.randint(high=24, size=(2,12)), text_lengths=torch.tensor([12, 12]), output_lengths=torch.tensor([21995])) print([o.shape for o in out])