More work on wave-diffusion

James Betker 2021-07-27 05:36:17 -06:00
parent 49e3b310ea
commit 398185e109
4 changed files with 49 additions and 28 deletions

View File

@@ -395,11 +395,11 @@ class ConvGnLelu(nn.Module):
 ''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvGnSilu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, convnd=nn.Conv2d):
         super(ConvGnSilu, self).__init__()
         padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
         assert kernel_size in padding_map.keys()
-        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
+        self.conv = convnd(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
         if norm:
             self.gn = nn.GroupNorm(num_groups, filters_out)
         else:
@@ -411,7 +411,7 @@ class ConvGnSilu(nn.Module):

         # Init params.
         for m in self.modules():
-            if isinstance(m, nn.Conv2d):
+            if isinstance(m, convnd):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
                 m.weight.data *= weight_init_factor
                 if m.bias is not None:
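
Note: the new convnd argument means the same conv/GroupNorm/SiLU block can now be instantiated for 1D (waveform) features. A minimal usage sketch, assuming ConvGnSilu is imported from models.arch_util as it is in the WavDecoder change below:

import torch
from torch import nn
from models.arch_util import ConvGnSilu

# Default behaviour is unchanged: a 2D conv block with GroupNorm + SiLU.
block2d = ConvGnSilu(64, 128, kernel_size=3)
y2d = block2d(torch.randn(4, 64, 32, 32))    # -> (4, 128, 32, 32)

# New in this commit: the same block built around nn.Conv1d.
block1d = ConvGnSilu(64, 128, kernel_size=5, stride=4, convnd=nn.Conv1d)
y1d = block1d(torch.randn(4, 64, 1024))      # -> (4, 128, 256); kernel 5 maps to padding 2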

View File

@@ -835,14 +835,11 @@ class GaussianDiffusion:
         if noise is None:
             noise = th.randn_like(x_start)
         x_t = self.q_sample(x_start, t, noise=noise)
-        x_tn1 = self.q_sample(x_start, t-1, noise=noise)
         terms = {}
         if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
             assert False # not currently supported for this type of diffusion.
         elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
-            model_outputs = model(x_t, x_tn1, self._scale_timesteps(t), **model_kwargs)
+            model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
             terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
             model_output = terms[gd_out_key]
             if self.model_var_type in [
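
Note: the model is no longer given a separately noised x_{t-1}; it is conditioned on the clean x_start directly. For reference, q_sample implements the standard DDPM forward process, so the deleted x_tn1 line was producing the same noising at t-1 with the same noise tensor. A self-contained sketch of that computation (alphas_cumprod stands in for the diffusion schedule; names are illustrative):

import torch as th

def q_sample_sketch(x_start, t, noise, alphas_cumprod):
    # x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps  (standard DDPM forward process)
    a_bar = alphas_cumprod[t].view(-1, *([1] * (x_start.dim() - 1)))
    return a_bar.sqrt() * x_start + (1 - a_bar).sqrt() * noise

# Before this commit, x_tn1 = q_sample_sketch(x_start, t - 1, noise, alphas_cumprod) was fed
# to the model alongside x_t; the model now receives x_start in that slot instead.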

View File

@@ -47,7 +47,7 @@ class AttentionPool2d(nn.Module):
         b, c, *_spatial = x.shape
         x = x.reshape(b, c, -1)  # NC(HW)
         x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
-        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
+        x = x + self.positional_embedding[None, :, :x.shape[-1]].to(x.dtype)  # NC(HW+1)
         x = self.qkv_proj(x)
         x = self.attention(x)
         x = self.c_proj(x)
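
Note: the positional embedding in AttentionPool2d is allocated for spacial_dim**2 + 1 positions, but when the module is reused over 1D features (as in the WavDecoder change below) the flattened input is much shorter, so it is now sliced to the actual length. A rough sketch, assuming the upstream guided-diffusion constructor signature (spacial_dim, embed_dim, num_heads_channels):

import torch
from models.diffusion.unet_diffusion import AttentionPool2d

pool = AttentionPool2d(5, 256, 64)       # embedding sized for 5*5 + 1 = 26 positions
feats_1d = torch.randn(4, 256, 5)        # (b, c, length) rather than (b, c, h, w)
pooled = pool(feats_1d)                  # (4, 256): the slice keeps only the 5 + 1 positions used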
@@ -98,7 +98,12 @@ class Upsample(nn.Module):
         self.use_conv = use_conv
         self.dims = dims
         if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            ksize = 3
+            pad = 1
+            if dims == 1:
+                ksize = 5
+                pad = 2
+            self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad)

     def forward(self, x):
         assert x.shape[1] == self.channels
@@ -106,6 +111,8 @@ class Upsample(nn.Module):
             x = F.interpolate(
                 x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
             )
+        elif self.dims == 1:
+            x = F.interpolate(x, scale_factor=4, mode="nearest")
         else:
             x = F.interpolate(x, scale_factor=2, mode="nearest")
         if self.use_conv:
@@ -129,10 +136,19 @@ class Downsample(nn.Module):
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
-        stride = 2 if dims != 3 else (1, 2, 2)
+        ksize = 3
+        pad = 1
+        if dims == 1:
+            stride = 4
+            ksize = 5
+            pad = 2
+        elif dims == 2:
+            stride = 2
+        else:
+            stride = (1,2,2)
         if use_conv:
             self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
+                dims, self.channels, self.out_channels, ksize, stride=stride, padding=pad
             )
         else:
             assert self.channels == self.out_channels
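
Note: for dims=1 the two modules now resample by a factor of 4 (kernel 5, stride 4, padding 2 on the way down; nearest-neighbour 4x interpolation on the way up), while the 2D/3D paths keep their factor of 2. A quick shape check, assuming the usual guided-diffusion constructor signature (channels, use_conv, dims, out_channels=None) and the module path used by the imports in the last file of this commit:

import torch
from models.diffusion.unet_diffusion import Upsample, Downsample

x = torch.randn(2, 64, 1024)                  # (b, c, n) 1D waveform-style features
down = Downsample(64, use_conv=True, dims=1)  # conv: kernel 5, stride 4, padding 2
up = Upsample(64, use_conv=True, dims=1)      # nearest 4x, then a kernel-5 conv

assert down(x).shape == (2, 64, 256)          # 1024 / 4
assert up(down(x)).shape == (2, 64, 1024)     # restored to the original length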

View File

@@ -3,9 +3,10 @@ import torch
 from munch import munchify
 from torch.autograd import Variable
 from torch import nn
-from torch.nn import functional as F
-from models.diffusion.unet_diffusion import UNetModel
+from torch.nn import functional as F, Flatten
+from models.arch_util import ConvGnSilu
+from models.diffusion.unet_diffusion import UNetModel, AttentionPool2d
 from models.tacotron2.layers import ConvNorm, LinearNorm
 from models.tacotron2.hparams import create_hparams
 from models.tacotron2.tacotron2 import Prenet, Attention, Encoder
@@ -14,6 +15,7 @@ from models.tacotron2.taco_utils import get_mask_from_lengths
 from utils.util import opt_get, checkpoint

 class WavDecoder(nn.Module):
     def __init__(self, dec_channels, K_ms=40, sample_rate=8000, dropout_probability=.1):
         super().__init__()
@@ -24,13 +26,18 @@ class WavDecoder(nn.Module):
                                    model_channels=dec_channels // 4, # This is a requirement to enable to load the embedding produced by the decoder into the unet model.
                                    out_channels=2, # 2 channels: eps_pred and variance_pred
                                    num_res_blocks=2,
-                                   attention_resolutions=(16,32),
+                                   attention_resolutions=(8,),
                                    dims=1,
                                    dropout=.1,
-                                   channel_mult=(1,1,1,2,4,8),
+                                   channel_mult=(1,2,4,8),
                                    use_raw_y_as_embedding=True)
         assert self.K % 64 == 0 # Otherwise the UNetModel breaks.
-        self.pre_rnn = Prenet(self.K, [dec_channels, dec_channels])
+        self.pre_rnn = nn.Sequential(ConvGnSilu(1,32,kernel_size=5,convnd=nn.Conv1d),
+                                     ConvGnSilu(32,64,kernel_size=5,stride=4,convnd=nn.Conv1d),
+                                     ConvGnSilu(64,128,kernel_size=5,stride=4,convnd=nn.Conv1d),
+                                     ConvGnSilu(128,256,kernel_size=5,stride=4,convnd=nn.Conv1d),
+                                     ConvGnSilu(256,dec_channels,kernel_size=1,convnd=nn.Conv1d),
+                                     AttentionPool2d(self.K//64,dec_channels,dec_channels//4))
         self.attention_rnn = nn.LSTMCell(dec_channels*2, dec_channels)
         self.attention_layer = Attention(dec_channels, dec_channels, dec_channels)
         self.decoder_rnn = nn.LSTMCell(dec_channels*2, dec_channels, 1)
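
Note: the Prenet over flat K-sample chunks is replaced by a strided 1D conv stack plus an attention pool, so each K-sample chunk is reduced 64x by three stride-4 convolutions and then pooled into a single dec_channels vector. A sketch of the intended shape flow with illustrative sizes (K=320 for 40 ms at 8 kHz, dec_channels=512; `decoder` stands for a constructed WavDecoder):

import torch

b, s, K, dec_channels = 4, 10, 320, 512       # 320 % 64 == 0, satisfying the assert above
chunks = torch.randn(b * s, 1, K)             # one K-sample chunk per decoder step

# (b*s, 1, 320) -> ConvGnSilu k5            -> (b*s, 32, 320)
#               -> ConvGnSilu k5, stride 4  -> (b*s, 64, 80)
#               -> ConvGnSilu k5, stride 4  -> (b*s, 128, 20)
#               -> ConvGnSilu k5, stride 4  -> (b*s, 256, 5)    # 5 == K // 64
#               -> ConvGnSilu k1            -> (b*s, 512, 5)
#               -> AttentionPool2d          -> (b*s, 512)
emb = decoder.pre_rnn(chunks)                 # (b*s, dec_channels)
emb = emb.reshape(b, s, dec_channels)         # unfolded per sequence, as forward() does below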
@@ -47,7 +54,7 @@ class WavDecoder(nn.Module):
         wavs = torch.stack(wavs, dim=1) # wavs.shape = (b,s,K) where s=decoder sequence length
         return wavs, padding_needed

     def prepare_decoder_inputs(self, inp):
         # inp.shape = (b,s,K) chunked waveform.
         b,s,K = inp.shape
@@ -135,24 +142,25 @@ class WavDecoder(nn.Module):
         return diffusion_eps[:,:,:-padding_added], gate_outputs[:,:-padding_added], alignments[:,:-padding_added]

-    def forward(self, wav, wav_corrected, timesteps, text_enc, memory_lengths):
+    def forward(self, wav_noised, wav_real, timesteps, text_enc, memory_lengths):
         '''
         Performs a training forward pass with the given data.
-        :param wav: (b,n) diffused waveform tensor on the interval [-1,1]
-        :param wav_corrected: (b,n) waveform tensor that has had one step of diffusion correction over <wav>
+        :param wav_noised: (b,n) diffused waveform tensor on the interval [-1,1]
+        :param wav_real: (b,n) actual waveform tensor
         :param text_enc: (b,e) embedding post-encoder with e=self.dec_channels
         '''
         # Start by splitting up the provided waveforms into discrete segments.
-        wavs, padding_added = self.chunk_wav(wav)
-        wavs_corrected, _ = self.chunk_wav(wav_corrected)
-        wavs_corrected = self.prepare_decoder_inputs(wavs_corrected)
-        wavs_corrected = checkpoint(self.pre_rnn, wavs_corrected)
+        wav_noised, padding_added = self.chunk_wav(wav_noised)
+        wav_real, _ = self.chunk_wav(wav_real)
+        wav_real = self.prepare_decoder_inputs(wav_real)
+        b,s,K = wav_real.shape
+        wav_real = checkpoint(self.pre_rnn, wav_real.reshape(b*s,1,K)).reshape(b,s,self.dec_channels)
         self.initialize_decoder_states(text_enc, mask=~get_mask_from_lengths(memory_lengths))
         decoder_contexts, gate_outputs, alignments = [], [], []
-        while len(decoder_contexts) < wavs_corrected.size(1):
-            decoder_input = wavs_corrected[:, len(decoder_contexts)]
+        while len(decoder_contexts) < wav_real.size(1):
+            decoder_input = wav_real[:, len(decoder_contexts)]
             dec_context, gate_output, attention_weights = self.produce_context(decoder_input)
             decoder_contexts += [dec_context.squeeze(1)]
             gate_outputs += [gate_output.squeeze(1)]
@@ -163,8 +171,8 @@ class WavDecoder(nn.Module):
         diffusion_emb = torch.stack(decoder_contexts, dim=1)
         b,s,c = diffusion_emb.shape
         diffusion_emb = diffusion_emb.reshape(b*s,c)
-        wavs = wavs.reshape(b*s,1,self.K)
-        diffusion_eps = self.clarifier(wavs, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K)
+        wav_noised = wav_noised.reshape(b*s,1,self.K)
+        diffusion_eps = self.clarifier(wav_noised, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K)
         # Recombine diffusion outputs across the sequence into a single prediction.
         diffusion_eps, gate_outputs, alignments = self.recombine(diffusion_eps, gate_outputs, alignments, padding_added)
         return diffusion_eps, gate_outputs, alignments
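
Note: the renamed forward() arguments make the data flow explicit: the clean waveform (wav_real) is chunked and run through the new pre_rnn to drive the decoder RNN, while the noised waveform (wav_noised) is chunked and handed to the per-chunk UNet ("clarifier") for eps/variance prediction. The final call batches the decoder sequence into the batch dimension; a small sketch of that reshape pattern with illustrative sizes:

import torch

b, s, c, K = 4, 10, 512, 320
diffusion_emb = torch.randn(b, s, c).reshape(b * s, c)   # one decoder context per chunk
wav_noised = torch.randn(b, s, K).reshape(b * s, 1, K)   # one noised chunk per decoder step
timesteps = torch.randint(0, 1000, (b,)).repeat(s)       # (b*s,), mirroring timesteps.repeat(s)

# clarifier is the dims=1 UNetModel above with out_channels=2 (eps_pred and variance_pred);
# its per-chunk output is (b*s, 2, K), unfolded back to (b, s, 2, K) before recombine():
# diffusion_eps = decoder.clarifier(wav_noised, timesteps, diffusion_emb).reshape(b, s, 2, K)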