# Forked from mrq/DL-Art-School (238 lines, 12 KiB).
from math import sqrt
import torch
from munch import munchify
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F, Flatten
from models.arch_util import ConvGnSilu
from models.diffusion.unet_diffusion import UNetModel, AttentionPool2d
from models.tacotron2.layers import ConvNorm, LinearNorm
from models.tacotron2.hparams import create_hparams
from models.tacotron2.tacotron2 import Prenet, Attention, Encoder
from trainer.networks import register_model
from models.tacotron2.taco_utils import get_mask_from_lengths
from utils.util import opt_get, checkpoint
class WavDecoder(nn.Module):
# Autoregressive waveform decoder: the waveform is cut into chunks of self.K samples, an
# attention + LSTM decoder step runs once per chunk (produce_context), and each step's output
# conditions self.clarifier, which predicts 2 channels per chunk (see forward).
def __init__(self, dec_channels, K_ms=40, sample_rate=8000, dropout_probability=.1):
# dec_channels: hidden size of both LSTM cells and width of each decoder output.
# K_ms / sample_rate: chunk duration in milliseconds and audio sample rate in Hz.
# dropout_probability: dropout applied to the LSTM hidden states in produce_context.
# NOTE(review): no super().__init__() call is visible here; nn.Module requires it before any
#   submodule is assigned below -- confirm it was not lost from this copy of the file.
self.dec_channels = dec_channels
# Samples per chunk: 8000 Hz * 40 ms = 320 with the defaults.
self.K = int(sample_rate * (K_ms/1000))
self.clarifier = UNetModel(image_size=self.K,
model_channels=dec_channels // 4, # Required so the embedding produced by the decoder can be loaded into the unet model.
out_channels=2, # 2 channels: eps_pred and variance_pred
# NOTE(review): the remaining UNetModel arguments and the closing parenthesis are missing in
#   this copy, so the assert below would sit inside the call -- restore them from upstream.
assert self.K % 64 == 0 # Otherwise the UNetModel breaks.
# Per-chunk embedding network: forward() reshapes its output to (b, s, dec_channels), so the
# full stack must yield dec_channels values per (1, K) input chunk.
# NOTE(review): only the first layer is visible and the Sequential is never closed here.
self.pre_rnn = nn.Sequential(ConvGnSilu(1,32,kernel_size=5,convnd=nn.Conv1d),
# Input width 2*dec_channels: presumably the chunk embedding concatenated with the attention
# context (produce_context concatenates something with self.attention_context) -- confirm.
self.attention_rnn = nn.LSTMCell(dec_channels*2, dec_channels)
self.attention_layer = Attention(dec_channels, dec_channels, dec_channels)
# Third positional LSTMCell argument is `bias`; 1 is truthy, so the bias is enabled.
self.decoder_rnn = nn.LSTMCell(dec_channels*2, dec_channels, 1)
# Both heads read cat(decoder_hidden, attention_context): a per-step embedding for the
# clarifier and a single stop-token (gate) energy.
self.linear_projection = LinearNorm(dec_channels*2, self.dec_channels)
self.gate_layer = LinearNorm(self.dec_channels*2, 1, bias=True, w_init_gain='sigmoid')
self.dropout_probability = dropout_probability
def chunk_wav(self, wav):
    """Cut ``wav`` along its last axis into consecutive chunks of ``self.K`` samples.

    The final chunk is right-padded with zeros so every chunk has exactly ``self.K`` samples.

    Returns:
        ``(chunks, pad)``: the chunks stacked on a new axis 1, i.e. (b, s, K) for a (b, n)
        input, and the number of zeros appended to the final chunk (0 if none).
    """
    *head, tail = torch.split(wav, self.K, dim=-1)
    pad = self.K - tail.shape[-1]
    if pad > 0:
        tail = F.pad(tail, (0, pad))
    chunks = torch.stack([*head, tail], dim=1)
    return chunks, pad
def prepare_decoder_inputs(self, inp):
    """Shift chunked audio right by one chunk for teacher forcing.

    Args:
        inp: (b, s, K) chunked waveform, e.g. as produced by ``chunk_wav``.

    Returns:
        (b, s, K) tensor whose step 0 is all zeros and whose step i equals ``inp[:, i - 1]``,
        so decoding step i only sees chunks strictly before i.
    """
    b, _, k = inp.shape
    # new_zeros follows inp's dtype as well as its device; a float32 zeros tensor would
    # otherwise promote half-precision input to float32 in torch.cat.
    first_frame = inp.new_zeros(b, 1, k)
    return torch.cat([first_frame, inp[:, :-1]], dim=1)
def initialize_decoder_states(self, memory, mask):
    """Reset the recurrent decoding state before decoding a new batch.

    Zero-initializes the attention and decoder LSTM states, the attention weights and their
    running sum, and the attention context. Stores ``memory`` and its projection through
    ``self.attention_layer.memory_layer``.

    Args:
        memory: encoder outputs; ``memory.size(0)`` is the batch size and ``memory.size(1)``
            the number of encoder steps attended over.
        mask: mask for padded data if training, expects None for inference.
    """
    batch = memory.size(0)
    max_time = memory.size(1)
    # new_zeros uses memory's dtype and device and creates tensors that do not require grad.
    self.attention_hidden = memory.new_zeros(batch, self.dec_channels)
    self.attention_cell = memory.new_zeros(batch, self.dec_channels)
    self.decoder_hidden = memory.new_zeros(batch, self.dec_channels)
    self.decoder_cell = memory.new_zeros(batch, self.dec_channels)
    self.attention_weights = memory.new_zeros(batch, max_time)
    self.attention_weights_cum = memory.new_zeros(batch, max_time)
    self.attention_context = memory.new_zeros(batch, self.dec_channels)
    self.memory = memory
    self.processed_memory = checkpoint(self.attention_layer.memory_layer, memory)
    self.mask = mask
def teardown_states(self):
    """Drop every per-batch tensor cached by ``initialize_decoder_states``.

    This includes the mask. The original version left it attached to the module, so it
    stayed alive after the rest of the state was released.
    """
    for name in ('attention_hidden', 'attention_cell', 'decoder_hidden', 'decoder_cell',
                 'attention_weights', 'attention_weights_cum', 'attention_context',
                 'memory', 'processed_memory', 'mask'):
        setattr(self, name, None)
def produce_context(self, decoder_input):
    """Advance the attention and decoder LSTMs by one step.

    Produces a context and a stop-token prediction using the built-in RNNs.

    Args:
        decoder_input: embedding of the prior chunk that has already been resolved; it is
            concatenated with the current attention context along the last axis.

    Returns:
        ``(decoder_output, gate_prediction, attention_weights)``: the projected decoder
        state (used by ``forward`` to condition the clarifier), the gate output energies,
        and this step's attention weights.
    """
    # Attention LSTM: previous chunk embedding + last attention context.
    cell_input = torch.cat((decoder_input, self.attention_context), -1)
    self.attention_hidden, self.attention_cell = self.attention_rnn(
        cell_input, (self.attention_hidden, self.attention_cell))
    self.attention_hidden = F.dropout(self.attention_hidden, self.dropout_probability, self.training)

    # Previous and cumulative attention weights are stacked on a new axis 1 for the attention layer.
    attention_weights_cat = torch.cat(
        (self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1)
    self.attention_context, self.attention_weights = checkpoint(
        self.attention_layer, self.attention_hidden, self.memory,
        self.processed_memory, attention_weights_cat, self.mask)
    self.attention_weights_cum += self.attention_weights

    # Decoder LSTM: attention state + freshly computed context.
    decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1)
    self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
        decoder_input, (self.decoder_hidden, self.decoder_cell))
    self.decoder_hidden = F.dropout(self.decoder_hidden, self.dropout_probability, self.training)

    decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1)
    decoder_output = checkpoint(self.linear_projection, decoder_hidden_attention_context)
    gate_prediction = self.gate_layer(decoder_hidden_attention_context)
    return decoder_output, gate_prediction, self.attention_weights
def recombine(self, diffusion_eps, gate_outputs, alignments, padding_added):
    """Merge per-chunk decoder outputs back onto the original sample axis.

    Args:
        diffusion_eps: (B, S, C, K) per-chunk clarifier outputs (C == 2: eps and variance).
        gate_outputs: list of S tensors of shape (B,), one gate energy per chunk.
        alignments: list of S tensors of shape (B, T_in), one attention map per chunk.
        padding_added: zeros appended to the last chunk by ``chunk_wav`` (may be 0).

    Returns:
        ``(eps, gates, alignments)`` shaped (B, C, N), (B, N) and (B, N, T_in), where
        N = S * K - padding_added. Sample j carries the gate and alignment of chunk j // K.
    """
    b, s, c, k = diffusion_eps.shape
    n_valid = s * k - padding_added
    # (B, S, C, K) -> (B, C, S*K): flat sample j lies in chunk j // K.
    diffusion_eps = diffusion_eps.permute(0, 2, 1, 3).reshape(b, c, s * k)
    # Expand each chunk's value over its K samples. repeat_interleave keeps the copies
    # contiguous (c0 c0 .. c1 c1 ..); repeat() would tile the whole sequence instead and
    # misalign gates/alignments with the eps samples.
    gate_outputs = torch.stack(gate_outputs, dim=1).repeat_interleave(k, dim=1)
    alignments = torch.stack(alignments, dim=1).repeat_interleave(k, dim=1)
    # Use an explicit end index: x[..., :-0] would be empty when no padding was added.
    return diffusion_eps[:, :, :n_valid], gate_outputs[:, :n_valid], alignments[:, :n_valid]
def forward(self, wav_noised, wav_real, timesteps, text_enc, memory_lengths):
    """Perform a teacher-forced training forward pass.

    Args:
        wav_noised: (b, n) diffused waveform tensor on the interval [-1, 1].
        wav_real: (b, n) actual waveform tensor; fed back, shifted by one chunk, as the
            autoregressive decoder input.
        timesteps: (b,) diffusion timesteps, one per batch element.
        text_enc: (b, t, e) encoder outputs used as attention memory, with e = self.dec_channels.
        memory_lengths: (b,) valid lengths of ``text_enc`` along t, used to mask attention.

    Returns:
        ``(diffusion_eps, gate_outputs, alignments)`` over the unpadded sample axis:
        (b, 2, n), (b, n) and (b, n, t).
    """
    # Split both waveforms into K-sample chunks: (b, s, K).
    wav_noised, padding_added = self.chunk_wav(wav_noised)
    wav_real, _ = self.chunk_wav(wav_real)
    # Shift right by one chunk so step i only sees real chunks < i.
    wav_real = self.prepare_decoder_inputs(wav_real)
    b, s, K = wav_real.shape
    # Embed every chunk independently, then restore the (b, s, dec_channels) layout.
    wav_real = checkpoint(self.pre_rnn, wav_real.reshape(b * s, 1, K)).reshape(b, s, self.dec_channels)

    self.initialize_decoder_states(text_enc, mask=~get_mask_from_lengths(memory_lengths))
    try:
        decoder_contexts, gate_outputs, alignments = [], [], []
        for step in range(s):
            dec_context, gate_output, attention_weights = self.produce_context(wav_real[:, step])
            decoder_contexts.append(dec_context.squeeze(1))
            gate_outputs.append(gate_output.squeeze(1))
            alignments.append(attention_weights)
    finally:
        # Do not keep the encoder memory / recurrent state (and their graph) alive on the module.
        self.teardown_states()

    # Fold the sequence axis into the batch axis (batch-major) for the diffusion UNet.
    diffusion_emb = torch.stack(decoder_contexts, dim=1)
    b, s, c = diffusion_emb.shape
    diffusion_emb = diffusion_emb.reshape(b * s, c)
    wav_noised = wav_noised.reshape(b * s, 1, self.K)
    # Folded row i belongs to batch element i // s, so each timestep is repeated s times
    # consecutively; repeat(s) would tile them and pair rows with the wrong timestep.
    diffusion_eps = self.clarifier(wav_noised, timesteps.repeat_interleave(s), diffusion_emb).reshape(b, s, 2, self.K)
    # Recombine diffusion outputs across the sequence into a single prediction.
    return self.recombine(diffusion_eps, gate_outputs, alignments, padding_added)
class WaveTacotron2(nn.Module):
    """Text encoder plus diffusion ``WavDecoder``.

    Symbol ids are embedded and run through the Tacotron2 ``Encoder``. ``WavDecoder`` then
    predicts diffusion outputs for the waveform while attending over the encoder outputs.
    """

    def __init__(self, hparams):
        super().__init__()  # nn.Module must be initialised before submodules are assigned.
        self.mask_padding = hparams.mask_padding
        self.fp16_run = hparams.fp16_run
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.embedding = nn.Embedding(hparams.n_symbols, hparams.symbols_embedding_dim)
        # Glorot/Xavier std; a uniform distribution on [-val, val] has std val / sqrt(3).
        std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
        val = sqrt(3.0) * std  # uniform bounds for std
        self.embedding.weight.data.uniform_(-val, val)
        self.encoder = Encoder(hparams)
        self.decoder = WavDecoder(hparams.encoder_embedding_dim)

    def parse_output(self, outputs, output_lengths=None):
        """Mask padded samples and re-add the channel axis to the noise prediction.

        Args:
            outputs: ``[eps_pred, gate_outputs, alignments]`` from the decoder, with
                ``eps_pred`` shaped (b, C, n) and ``gate_outputs`` shaped (b, n).
            output_lengths: (b,) valid sample counts, or None to skip masking.

        Returns:
            The same list, mutated in place:
            - When ``self.mask_padding`` is set and lengths are given, ``eps_pred`` is zeroed
              and the gate energies are set to 1e3 at positions marked by
              ``~get_mask_from_lengths`` (presumably the padding).
            - ``eps_pred`` is always unsqueezed to (b, 1, C, n).
        """
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths, outputs[0].shape[-1])
            # (b, n) mask broadcast over every eps channel.
            outputs[0].data.masked_fill_(mask.unsqueeze(1), 0.0)
            outputs[1].data.masked_fill_(mask, 1e3)  # gate energies
        # Re-add channel dimension regardless of masking so the output shape is stable.
        outputs[0] = outputs[0].unsqueeze(1)
        return outputs

    def forward(self, wavs_diffused, wavs_corrected, timesteps, text_inputs, text_lengths, output_lengths):
        """Encode ``text_inputs`` and run the diffusion decoder over the waveforms.

        Args:
            wavs_diffused: (b, 1, n) noised waveform.
            wavs_corrected: (b, 1, n) clean waveform used for teacher forcing.
            timesteps: (b,) diffusion timesteps.
            text_inputs: (b, t) symbol ids.
            text_lengths: (b,) valid lengths of ``text_inputs``.
            output_lengths: (b,) valid waveform lengths, used to mask padded samples.
        """
        # Squeeze the channel dimension out of the input wavs - only single-channel audio is handled.
        wavs_diffused = wavs_diffused.squeeze(dim=1)
        wavs_corrected = wavs_corrected.squeeze(dim=1)
        text_lengths, output_lengths = text_lengths.data, output_lengths.data
        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
        encoder_outputs = checkpoint(self.encoder, embedded_inputs, text_lengths)
        eps_pred, gate_outputs, alignments = self.decoder(
            wavs_diffused, wavs_corrected, timesteps, encoder_outputs, memory_lengths=text_lengths)
        return self.parse_output([eps_pred, gate_outputs, alignments], output_lengths)
def register_diffusion_wavetron(opt_net, opt):
    """Build a ``WaveTacotron2`` from the default ``create_hparams()`` settings.

    ``opt_net`` and ``opt`` are accepted (presumably to match the trainer's model-factory
    signature) but are not read.

    NOTE(review): ``register_model`` is imported by this file but not applied in the visible
    code; confirm whether this factory was meant to be decorated with it.
    """
    return WaveTacotron2(munchify(create_hparams()))
if __name__ == '__main__':
    # Smoke test: one training forward pass on random data, printing each output's shape.
    tron = register_diffusion_wavetron({}, {})
    out = tron(wavs_diffused=torch.randn(2, 1, 22000),
               wavs_corrected=torch.randn(2, 1, 22000),
               timesteps=torch.LongTensor([555, 543]),
               text_inputs=torch.randint(high=24, size=(2, 12)),
               text_lengths=torch.tensor([12, 12]),
               # Required by WaveTacotron2.forward; the full length means nothing is masked.
               output_lengths=torch.tensor([22000, 22000]))
    print([o.shape for o in out])