# DL-Art-School/codes/models/tacotron2/wave_tacotron.py
# (230 lines, 11 KiB, Python)

from math import sqrt
import torch
from munch import munchify
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from models.diffusion.unet_diffusion import UNetModel
from models.tacotron2.layers import ConvNorm, LinearNorm
from models.tacotron2.hparams import create_hparams
from models.tacotron2.tacotron2 import Prenet, Attention, Encoder
from trainer.networks import register_model
from models.tacotron2.taco_utils import get_mask_from_lengths
from utils.util import opt_get, checkpoint
class WavDecoder(nn.Module):
    """Tacotron2-style autoregressive decoder whose per-chunk outputs condition a 1-D diffusion UNet.

    The waveform is split into chunks of ``K`` samples. For every chunk:
    - an attention LSTM over the text encoder outputs, plus a decoder LSTM (teacher-forced on the previous
      *corrected* chunk), produce a context embedding and a gate prediction;
    - the ``clarifier`` UNet then consumes each diffused chunk together with its context embedding and predicts
      two channels per sample (eps_pred and variance_pred, per the ``out_channels`` comment below).
    """

    def __init__(self, dec_channels, K_ms=40, sample_rate=24000, dropout_probability=.1):
        """
        :param dec_channels: width of the attention/decoder LSTMs and of the context embedding handed to the UNet.
        :param K_ms: chunk length in milliseconds.
        :param sample_rate: audio sample rate used to convert K_ms into a sample count.
        :param dropout_probability: dropout applied to the attention and decoder LSTM hidden states.
        :raises ValueError: if the resulting chunk size K is not a multiple of 64.
        """
        super().__init__()
        self.dec_channels = dec_channels
        self.K = int(sample_rate * (K_ms / 1000))  # Samples per chunk; 960 with the defaults.
        # Validate before building the UNet: the UNetModel breaks unless K is a multiple of 64.
        # A bare assert would also be stripped under `python -O`.
        if self.K % 64 != 0:
            raise ValueError(f'Chunk size K={self.K} samples must be a multiple of 64 for the UNetModel.')
        self.clarifier = UNetModel(image_size=self.K,
                                   in_channels=1,
                                   # Required so the embedding produced by the decoder can be loaded into the UNet.
                                   model_channels=dec_channels // 4,
                                   out_channels=2,  # 2 channels: eps_pred and variance_pred
                                   num_res_blocks=2,
                                   attention_resolutions=(16, 32),
                                   dims=1,
                                   dropout=.1,
                                   channel_mult=(1, 1, 1, 2, 4, 8),
                                   use_raw_y_as_embedding=True)
        self.pre_rnn = Prenet(self.K, [dec_channels, dec_channels])
        # Input = pre-net output (dec_channels) concatenated with the attention context (dec_channels).
        self.attention_rnn = nn.LSTMCell(dec_channels * 2, dec_channels)
        self.attention_layer = Attention(dec_channels, dec_channels, dec_channels)
        # Input = attention hidden state concatenated with the attention context.
        self.decoder_rnn = nn.LSTMCell(dec_channels * 2, dec_channels, bias=True)
        self.linear_projection = LinearNorm(dec_channels * 2, self.dec_channels)
        self.gate_layer = LinearNorm(self.dec_channels * 2, 1, bias=True, w_init_gain='sigmoid')
        self.dropout_probability = dropout_probability

    def chunk_wav(self, wav):
        """Split ``wav`` along its last dimension into K-sample chunks, zero-padding the final chunk.

        :param wav: (b,n) waveform.
        :return: tuple of
            - the (b,s,K) chunks, where s = ceil(n / K);
            - the number of zero samples appended to the final chunk (0 when n is a multiple of K).
        """
        padding_needed = (-wav.shape[-1]) % self.K
        if padding_needed > 0:
            wav = F.pad(wav, (0, padding_needed))
        # The padded length is an exact multiple of K, so chunking is just a reshape of the last dimension.
        return wav.reshape(*wav.shape[:-1], -1, self.K), padding_needed

    def prepare_decoder_inputs(self, inp):
        """Shift the chunk sequence right by one step (prepending a silent chunk) for teacher forcing.

        After the shift, decoder step t sees chunk t-1.

        :param inp: (b,s,K) chunked waveform.
        :return: (b,s,K) tensor with the same dtype and device as ``inp``.
        """
        # new_zeros keeps inp's dtype/device, so fp16 inputs are not mixed with float32 zeros.
        first_frame = inp.new_zeros(inp.shape[0], 1, inp.shape[2])
        return torch.cat([first_frame, inp[:, :-1]], dim=1)

    def initialize_decoder_states(self, memory, mask):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs, (B, T_text, dec_channels)
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)
        # new_zeros yields detached zeros with memory's dtype/device, matching the legacy
        # `Variable(memory.data.new(...).zero_())` construction.
        self.attention_hidden = memory.new_zeros(B, self.dec_channels)
        self.attention_cell = memory.new_zeros(B, self.dec_channels)
        self.decoder_hidden = memory.new_zeros(B, self.dec_channels)
        self.decoder_cell = memory.new_zeros(B, self.dec_channels)
        self.attention_weights = memory.new_zeros(B, MAX_TIME)
        self.attention_weights_cum = memory.new_zeros(B, MAX_TIME)
        self.attention_context = memory.new_zeros(B, self.dec_channels)
        self.memory = memory
        self.processed_memory = checkpoint(self.attention_layer.memory_layer, memory)
        self.mask = mask

    def teardown_states(self):
        """Drop every per-sequence tensor set by initialize_decoder_states so the graph and memory can be freed."""
        self.attention_hidden = None
        self.attention_cell = None
        self.decoder_hidden = None
        self.decoder_cell = None
        self.attention_weights = None
        self.attention_weights_cum = None
        self.attention_context = None
        self.memory = None
        self.processed_memory = None
        self.mask = None

    def produce_context(self, decoder_input):
        """Advance the attention and decoder LSTMs by one chunk.

        PARAMS
        ------
        decoder_input: (B, dec_channels) pre-net embedding of the previous (corrected) chunk.

        RETURNS
        -------
        decoder_output: (B, dec_channels) context embedding used to condition the clarifier UNet.
        gate_prediction: (B, 1) gate output energy.
        attention_weights: this step's attention weights over the encoder time steps.
        """
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_hidden = F.dropout(self.attention_hidden, self.dropout_probability, self.training)

        # The attention layer receives both the previous and the cumulative alignment, stacked along dim 1.
        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1)
        self.attention_context, self.attention_weights = checkpoint(
            self.attention_layer, self.attention_hidden, self.memory,
            self.processed_memory, attention_weights_cat, self.mask)
        self.attention_weights_cum += self.attention_weights

        decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))
        self.decoder_hidden = F.dropout(self.decoder_hidden, self.dropout_probability, self.training)

        decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1)
        decoder_output = checkpoint(self.linear_projection, decoder_hidden_attention_context)
        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights

    def recombine(self, diffusion_eps, gate_outputs, alignments, padding_added):
        """Stitch per-chunk predictions back into per-sample sequences and strip the chunking padding.

        :param diffusion_eps: (B, S, C, K) clarifier output per chunk.
        :param gate_outputs: list of S tensors of shape (B,), one gate value per chunk.
        :param alignments: list of S tensors of shape (B, T_text), one alignment per chunk.
        :param padding_added: zero samples appended by chunk_wav (may be 0).
        :return: tuple of
            - (B, C, N) eps;
            - (B, N) gates;
            - (B, N, T_text) alignments;
            where N = S*K - padding_added.
        """
        b, s, c, k = diffusion_eps.shape
        valid_len = s * k - padding_added  # Explicit end index: `[:-0]` would wrongly yield an empty tensor.
        # Every sample within a chunk shares that chunk's gate/alignment, so each step is stretched over K
        # consecutive samples (repeat_interleave). A plain `repeat` would tile the whole sequence instead.
        alignments = torch.stack(alignments, dim=1).repeat_interleave(self.K, dim=1)  # (B, S*K, T_text)
        gate_outputs = torch.stack(gate_outputs, dim=1).repeat_interleave(self.K, dim=1)  # (B, S*K)
        # (B, S, C, K) -> (B, C, S*K)
        diffusion_eps = diffusion_eps.permute(0, 2, 1, 3).reshape(b, c, s * k)
        return diffusion_eps[:, :, :valid_len], gate_outputs[:, :valid_len], alignments[:, :valid_len]

    def forward(self, wav, wav_corrected, timesteps, text_enc, memory_lengths):
        '''
        Performs a training forward pass with the given data.
        :param wav: (b,n) diffused waveform tensor on the interval [-1,1]
        :param wav_corrected: (b,n) waveform tensor that has had one step of diffusion correction over <wav>
        :param timesteps: (b,) diffusion timestep for each batch element
        :param text_enc: (b,t,e) encoder outputs with e=self.dec_channels
        :param memory_lengths: (b,) number of valid encoder steps per batch element
        :return: tuple of
            - (b,2,n) eps predictions;
            - (b,n) gate energies;
            - (b,n,t) alignments.
        '''
        # Start by splitting up the provided waveforms into discrete segments.
        wavs, padding_added = self.chunk_wav(wav)
        wavs_corrected, _ = self.chunk_wav(wav_corrected)
        wavs_corrected = self.prepare_decoder_inputs(wavs_corrected)
        wavs_corrected = checkpoint(self.pre_rnn, wavs_corrected)

        self.initialize_decoder_states(text_enc, mask=~get_mask_from_lengths(memory_lengths))
        decoder_contexts, gate_outputs, alignments = [], [], []
        try:
            for step in range(wavs_corrected.size(1)):
                dec_context, gate_output, attention_weights = self.produce_context(wavs_corrected[:, step])
                decoder_contexts.append(dec_context.squeeze(1))
                gate_outputs.append(gate_output.squeeze(1))
                alignments.append(attention_weights)
        finally:
            # Always release the recurrent state, even if a step raises.
            self.teardown_states()

        # The clarifier runs on every chunk independently: flatten (b, s) into one batch dimension, batch-major.
        diffusion_emb = torch.stack(decoder_contexts, dim=1)
        b, s, c = diffusion_emb.shape
        diffusion_emb = diffusion_emb.reshape(b * s, c)
        wavs = wavs.reshape(b * s, 1, self.K)
        # Row i of the flattened batch belongs to sample i // s, so each timestep is repeated s times consecutively.
        chunk_timesteps = timesteps.repeat_interleave(s)
        diffusion_eps = self.clarifier(wavs, chunk_timesteps, diffusion_emb).reshape(b, s, 2, self.K)

        # Recombine diffusion outputs across the sequence into a single prediction.
        return self.recombine(diffusion_eps, gate_outputs, alignments, padding_added)
class WaveTacotron2(nn.Module):
    """Tacotron2 text encoder feeding a WavDecoder that predicts diffusion outputs directly on raw waveform chunks."""

    def __init__(self, hparams):
        """
        :param hparams: attribute-accessible hparams.
            Reads mask_padding, fp16_run, n_mel_channels, n_frames_per_step, n_symbols, symbols_embedding_dim
            and encoder_embedding_dim; the Encoder may read more.
        """
        super().__init__()
        self.mask_padding = hparams.mask_padding
        self.fp16_run = hparams.fp16_run
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.embedding = nn.Embedding(hparams.n_symbols, hparams.symbols_embedding_dim)
        # Xavier/Glorot-uniform init: std = sqrt(2 / (fan_in + fan_out)), uniform bound = sqrt(3) * std.
        std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
        val = sqrt(3.0) * std
        self.embedding.weight.data.uniform_(-val, val)
        self.encoder = Encoder(hparams)
        self.decoder = WavDecoder(hparams.encoder_embedding_dim)

    def parse_output(self, outputs, output_lengths=None):
        """Mask predictions beyond each sample's valid length and re-add the audio channel dimension.

        :param outputs: [eps (B,2,N), gate energies (B,N), alignments]; modified in place and returned.
        :param output_lengths: (B,) valid sample counts, or None to skip masking.
        :return: ``outputs`` with eps reshaped to (B,1,2,N).
            Zeroed past each length, and gate energies set to 1e3 there, when masking applies.
        """
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths, outputs[0].shape[-1])  # (B, N), True past each length
            # Broadcast the per-sample mask over both eps channels. (The former `mask.expand(B, 2, N)` aligned
            # the batch dim with the channel dim, mixing up masks across batch entries.)
            outputs[0].data.masked_fill_(mask.unsqueeze(1), 0.0)
            outputs[1].data.masked_fill_(mask, 1e3)  # gate energies
        # Re-add the channel dimension squeezed out in forward(), regardless of whether masking ran,
        # so the output rank does not depend on mask_padding/output_lengths.
        outputs[0] = outputs[0].unsqueeze(1)
        return outputs

    def forward(self, wavs_diffused, wavs_corrected, timesteps, text_inputs, text_lengths, output_lengths):
        """Training forward pass.

        :param wavs_diffused: (B,1,N) diffused waveforms.
        :param wavs_corrected: (B,1,N) waveforms one diffusion-correction step past ``wavs_diffused``.
        :param timesteps: (B,) diffusion timesteps.
        :param text_inputs: (B,T) symbol ids.
        :param text_lengths: (B,) valid text lengths.
        :param output_lengths: (B,) valid waveform lengths, used for masking.
        :return: [eps predictions, gate energies, alignments] as produced by parse_output.
        """
        # Squeeze the channel dimension out of the input wavs - we only handle single-channel audio here.
        wavs_diffused = wavs_diffused.squeeze(dim=1)
        wavs_corrected = wavs_corrected.squeeze(dim=1)
        text_lengths, output_lengths = text_lengths.data, output_lengths.data
        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
        encoder_outputs = checkpoint(self.encoder, embedded_inputs, text_lengths)
        eps_pred, gate_outputs, alignments = self.decoder(
            wavs_diffused, wavs_corrected, timesteps, encoder_outputs, memory_lengths=text_lengths)
        return self.parse_output([eps_pred, gate_outputs, alignments], output_lengths)
@register_model
def register_diffusion_wavetron(opt_net, opt):
    """Registry factory: build a WaveTacotron2 from the default hparams overridden by ``opt_net``.

    :param opt_net: mapping of hparam overrides, merged into the defaults via ``update``.
    :param opt: global options; not used by this factory.
    """
    merged_hparams = create_hparams()
    merged_hparams.update(opt_net)
    return WaveTacotron2(munchify(merged_hparams))
if __name__ == '__main__':
    # Smoke test: a batch of 2 random waveforms whose length (22000) is not a multiple of the chunk size,
    # paired with 12-token texts.
    tron = register_diffusion_wavetron({}, {})
    out = tron(wavs_diffused=torch.randn(2, 1, 22000),
               wavs_corrected=torch.randn(2, 1, 22000),
               timesteps=torch.LongTensor([555, 543]),
               text_inputs=torch.randint(high=24, size=(2, 12)),
               text_lengths=torch.tensor([12, 12]),
               output_lengths=torch.tensor([21995, 21995]))  # One valid length per batch element.
    print([o.shape for o in out])