diff --git a/codes/models/arch_util.py b/codes/models/arch_util.py
index 73a3b63d..997b925f 100644
--- a/codes/models/arch_util.py
+++ b/codes/models/arch_util.py
@@ -395,11 +395,11 @@ class ConvGnLelu(nn.Module):
 
 ''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard kernel sizes. '''
 class ConvGnSilu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, convnd=nn.Conv2d):
         super(ConvGnSilu, self).__init__()
         padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
         assert kernel_size in padding_map.keys()
-        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
+        self.conv = convnd(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
         if norm:
             self.gn = nn.GroupNorm(num_groups, filters_out)
         else:
@@ -411,7 +411,7 @@ class ConvGnSilu(nn.Module):
 
         # Init params.
         for m in self.modules():
-            if isinstance(m, nn.Conv2d):
+            if isinstance(m, convnd):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
                 m.weight.data *= weight_init_factor
                 if m.bias is not None:
diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py
index 0f2cb8c5..ac2f3d80 100644
--- a/codes/models/diffusion/gaussian_diffusion.py
+++ b/codes/models/diffusion/gaussian_diffusion.py
@@ -835,14 +835,11 @@ class GaussianDiffusion:
         if noise is None:
             noise = th.randn_like(x_start)
         x_t = self.q_sample(x_start, t, noise=noise)
-        x_tn1 = self.q_sample(x_start, t-1, noise=noise)
-
         terms = {}
-
         if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
             assert False # not currently supported for this type of diffusion.
         elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
-            model_outputs = model(x_t, x_tn1, self._scale_timesteps(t), **model_kwargs)
+            model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
             terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
             model_output = terms[gd_out_key]
             if self.model_var_type in [
diff --git a/codes/models/diffusion/unet_diffusion.py b/codes/models/diffusion/unet_diffusion.py
index 55c5dca8..83f6c5f5 100644
--- a/codes/models/diffusion/unet_diffusion.py
+++ b/codes/models/diffusion/unet_diffusion.py
@@ -47,7 +47,7 @@ class AttentionPool2d(nn.Module):
         b, c, *_spatial = x.shape
         x = x.reshape(b, c, -1) # NC(HW)
         x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
-        x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+        x = x + self.positional_embedding[None, :, :x.shape[-1]].to(x.dtype) # NC(HW+1)
         x = self.qkv_proj(x)
         x = self.attention(x)
         x = self.c_proj(x)
@@ -98,7 +98,12 @@ class Upsample(nn.Module):
         self.use_conv = use_conv
         self.dims = dims
         if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            ksize = 3
+            pad = 1
+            if dims == 1:
+                ksize = 5
+                pad = 2
+            self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad)
 
     def forward(self, x):
         assert x.shape[1] == self.channels
@@ -106,6 +111,8 @@ class Upsample(nn.Module):
             x = F.interpolate(
                 x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
             )
+        elif self.dims == 1:
+            x = F.interpolate(x, scale_factor=4, mode="nearest")
         else:
             x = F.interpolate(x, scale_factor=2, mode="nearest")
         if self.use_conv:
@@ -129,10 +136,19 @@ class Downsample(nn.Module):
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
-        stride = 2 if dims != 3 else (1, 2, 2)
+        ksize = 3
+        pad = 1
+        if dims == 1:
+            stride = 4
+            ksize = 5
+            pad = 2
+        elif dims == 2:
+            stride = 2
+        else:
+            stride = (1,2,2)
         if use_conv:
             self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
+                dims, self.channels, self.out_channels, ksize, stride=stride, padding=pad
             )
         else:
             assert self.channels == self.out_channels
diff --git a/codes/models/tacotron2/wave_tacotron.py b/codes/models/tacotron2/wave_tacotron.py
index 99bb4e28..7a227529 100644
--- a/codes/models/tacotron2/wave_tacotron.py
+++ b/codes/models/tacotron2/wave_tacotron.py
@@ -3,9 +3,10 @@ import torch
 from munch import munchify
 from torch.autograd import Variable
 from torch import nn
-from torch.nn import functional as F
+from torch.nn import functional as F, Flatten
 
-from models.diffusion.unet_diffusion import UNetModel
+from models.arch_util import ConvGnSilu
+from models.diffusion.unet_diffusion import UNetModel, AttentionPool2d
 from models.tacotron2.layers import ConvNorm, LinearNorm
 from models.tacotron2.hparams import create_hparams
 from models.tacotron2.tacotron2 import Prenet, Attention, Encoder
@@ -14,6 +15,7 @@ from models.tacotron2.taco_utils import get_mask_from_lengths
 from utils.util import opt_get, checkpoint
 
 
+
 class WavDecoder(nn.Module):
     def __init__(self, dec_channels, K_ms=40, sample_rate=8000, dropout_probability=.1):
         super().__init__()
@@ -24,13 +26,18 @@ class WavDecoder(nn.Module):
                                    model_channels=dec_channels // 4, # This is a requirement to enable to load the embedding produced by the decoder into the unet model.
                                    out_channels=2, # 2 channels: eps_pred and variance_pred
                                    num_res_blocks=2,
-                                   attention_resolutions=(16,32),
+                                   attention_resolutions=(8,),
                                    dims=1,
                                    dropout=.1,
-                                   channel_mult=(1,1,1,2,4,8),
+                                   channel_mult=(1,2,4,8),
                                    use_raw_y_as_embedding=True)
         assert self.K % 64 == 0 # Otherwise the UNetModel breaks.
-        self.pre_rnn = Prenet(self.K, [dec_channels, dec_channels])
+        self.pre_rnn = nn.Sequential(ConvGnSilu(1,32,kernel_size=5,convnd=nn.Conv1d),
+                                     ConvGnSilu(32,64,kernel_size=5,stride=4,convnd=nn.Conv1d),
+                                     ConvGnSilu(64,128,kernel_size=5,stride=4,convnd=nn.Conv1d),
+                                     ConvGnSilu(128,256,kernel_size=5,stride=4,convnd=nn.Conv1d),
+                                     ConvGnSilu(256,dec_channels,kernel_size=1,convnd=nn.Conv1d),
+                                     AttentionPool2d(self.K//64,dec_channels,dec_channels//4))
         self.attention_rnn = nn.LSTMCell(dec_channels*2, dec_channels)
         self.attention_layer = Attention(dec_channels, dec_channels, dec_channels)
         self.decoder_rnn = nn.LSTMCell(dec_channels*2, dec_channels, 1)
@@ -47,7 +54,7 @@ class WavDecoder(nn.Module):
 
         wavs = torch.stack(wavs, dim=1) # wavs.shape = (b,s,K) where s=decoder sequence length
         return wavs, padding_needed
-    
+
     def prepare_decoder_inputs(self, inp):
         # inp.shape = (b,s,K) chunked waveform.
         b,s,K = inp.shape
@@ -135,24 +142,25 @@ class WavDecoder(nn.Module):
 
         return diffusion_eps[:,:,:-padding_added], gate_outputs[:,:-padding_added], alignments[:,:-padding_added]
 
-    def forward(self, wav, wav_corrected, timesteps, text_enc, memory_lengths):
+    def forward(self, wav_noised, wav_real, timesteps, text_enc, memory_lengths):
         '''
         Performs a training forward pass with the given data.
-        :param wav: (b,n) diffused waveform tensor on the interval [-1,1]
-        :param wav_corrected: (b,n) waveform tensor that has had one step of diffusion correction over
+        :param wav_noised: (b,n) diffused waveform tensor on the interval [-1,1]
+        :param wav_real: (b,n) actual waveform tensor
        :param text_enc: (b,e) embedding post-encoder with e=self.dec_channels
         '''
 
         # Start by splitting up the provided waveforms into discrete segments.
-        wavs, padding_added = self.chunk_wav(wav)
-        wavs_corrected, _ = self.chunk_wav(wav_corrected)
-        wavs_corrected = self.prepare_decoder_inputs(wavs_corrected)
-        wavs_corrected = checkpoint(self.pre_rnn, wavs_corrected)
+        wav_noised, padding_added = self.chunk_wav(wav_noised)
+        wav_real, _ = self.chunk_wav(wav_real)
+        wav_real = self.prepare_decoder_inputs(wav_real)
+        b,s,K = wav_real.shape
+        wav_real = checkpoint(self.pre_rnn, wav_real.reshape(b*s,1,K)).reshape(b,s,self.dec_channels)
 
         self.initialize_decoder_states(text_enc, mask=~get_mask_from_lengths(memory_lengths))
         decoder_contexts, gate_outputs, alignments = [], [], []
-        while len(decoder_contexts) < wavs_corrected.size(1):
-            decoder_input = wavs_corrected[:, len(decoder_contexts)]
+        while len(decoder_contexts) < wav_real.size(1):
+            decoder_input = wav_real[:, len(decoder_contexts)]
             dec_context, gate_output, attention_weights = self.produce_context(decoder_input)
             decoder_contexts += [dec_context.squeeze(1)]
             gate_outputs += [gate_output.squeeze(1)]
@@ -163,8 +171,8 @@ class WavDecoder(nn.Module):
         diffusion_emb = torch.stack(decoder_contexts, dim=1)
         b,s,c = diffusion_emb.shape
         diffusion_emb = diffusion_emb.reshape(b*s,c)
-        wavs = wavs.reshape(b*s,1,self.K)
-        diffusion_eps = self.clarifier(wavs, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K)
+        wav_noised = wav_noised.reshape(b*s,1,self.K)
+        diffusion_eps = self.clarifier(wav_noised, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K)
         # Recombine diffusion outputs across the sequence into a single prediction.
         diffusion_eps, gate_outputs, alignments = self.recombine(diffusion_eps, gate_outputs, alignments, padding_added)
         return diffusion_eps, gate_outputs, alignments
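
A few notes on the patch above (commentary, not part of the diff).

The headline change is the new `pre_rnn`: instead of a `Prenet` over already-corrected waveform chunks, each real waveform chunk of `K` samples is encoded by a stack of 1D `ConvGnSilu` blocks (three stride-4 stages, hence the `self.K % 64 == 0` assert) followed by `AttentionPool2d` over the remaining `K//64` positions. A minimal shape-check sketch of that path, assuming the repo modules are importable; `dec_channels=512` is an illustrative choice, not a repo default, while `K = 320` follows from the default `K_ms=40` at `sample_rate=8000`:

```python
# Sketch only: mirrors the pre_rnn stack added to WavDecoder and runs dummy data
# through it to confirm the shape flow (b*s, 1, K) -> (b*s, dec_channels).
import torch
from torch import nn

from models.arch_util import ConvGnSilu
from models.diffusion.unet_diffusion import AttentionPool2d

dec_channels = 512      # assumed for illustration
K = 40 * 8000 // 1000   # 40ms chunks at 8kHz -> 320 samples, divisible by 64

pre_rnn = nn.Sequential(
    ConvGnSilu(1, 32, kernel_size=5, convnd=nn.Conv1d),               # (., 32, 320)
    ConvGnSilu(32, 64, kernel_size=5, stride=4, convnd=nn.Conv1d),    # (., 64, 80)
    ConvGnSilu(64, 128, kernel_size=5, stride=4, convnd=nn.Conv1d),   # (., 128, 20)
    ConvGnSilu(128, 256, kernel_size=5, stride=4, convnd=nn.Conv1d),  # (., 256, 5)
    ConvGnSilu(256, dec_channels, kernel_size=1, convnd=nn.Conv1d),   # (., 512, 5)
    AttentionPool2d(K // 64, dec_channels, dec_channels // 4),        # (., 512)
)

b, s = 2, 7                        # batch size and decoder sequence length
chunks = torch.randn(b * s, 1, K)  # chunked real waveform segments
emb = pre_rnn(chunks).reshape(b, s, dec_channels)
print(emb.shape)                   # torch.Size([2, 7, 512])
```

In the real forward pass this is wrapped in `checkpoint(...)`, exactly as the diff shows.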
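The `dims == 1` branches in `Upsample`/`Downsample` switch 1D levels from 2x to 4x resolution changes; kernel 5 with padding 2 is what keeps a stride-4 convolution an exact 4x reduction when the length is a multiple of 4. A standalone check of that arithmetic (pure torch, not importing the repo classes):

```python
# L_out = floor((L + 2*pad - ksize) / stride) + 1
#       = floor((320 + 4 - 5) / 4) + 1 = 80, i.e. exactly L / 4.
import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(1, 16, 320)  # (batch, channels, samples)

down = nn.Conv1d(16, 16, kernel_size=5, stride=4, padding=2)
y = down(x)
print(y.shape)               # torch.Size([1, 16, 80])

# The matching Upsample branch restores the length with nearest interpolation.
up = F.interpolate(y, scale_factor=4, mode="nearest")
print(up.shape)              # torch.Size([1, 16, 320])
```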
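On the `AttentionPool2d` change: upstream, the positional embedding is allocated for `spacial_dim**2 + 1` tokens (a 2D feature map plus the prepended mean token), but the 1D `pre_rnn` path only feeds it `K//64 + 1` tokens, so the embedding is now sliced to the actual sequence length. A rough illustration of the indexing, assuming the upstream guided-diffusion layout where `positional_embedding` has shape `(embed_dim, spacial_dim**2 + 1)`:

```python
# Why the slice is needed: the embedding covers spacial_dim**2 + 1 positions,
# while a 1D input only produces K//64 + 1 tokens.
import torch

embed_dim, spacial_dim = 512, 5                           # spacial_dim = K // 64
pos_emb = torch.randn(embed_dim, spacial_dim ** 2 + 1)    # (512, 26)

x = torch.randn(4, embed_dim, spacial_dim)                # NC(L), L = 5 tokens
x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(L+1) = (4, 512, 6)

# Old: x + pos_emb[None, :, :]   -> broadcast error (26 vs 6).
x = x + pos_emb[None, :, :x.shape[-1]]  # New: slice to the positions actually used.
print(x.shape)                          # torch.Size([4, 512, 6])
```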