More work on wave-diffusion

This commit is contained in:
James Betker 2021-07-27 05:36:17 -06:00
parent 49e3b310ea
commit 398185e109
4 changed files with 49 additions and 28 deletions

View File

@ -395,11 +395,11 @@ class ConvGnLelu(nn.Module):
''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
kernel sizes. '''
class ConvGnSilu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, convnd=nn.Conv2d):
super(ConvGnSilu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys()
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
self.conv = convnd(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
if norm:
self.gn = nn.GroupNorm(num_groups, filters_out)
else:
@ -411,7 +411,7 @@ class ConvGnSilu(nn.Module):
# Init params.
for m in self.modules():
if isinstance(m, nn.Conv2d):
if isinstance(m, convnd):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
m.weight.data *= weight_init_factor
if m.bias is not None:

View File

@ -835,14 +835,11 @@ class GaussianDiffusion:
if noise is None:
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start, t, noise=noise)
x_tn1 = self.q_sample(x_start, t-1, noise=noise)
terms = {}
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
assert False # not currently supported for this type of diffusion.
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
model_outputs = model(x_t, x_tn1, self._scale_timesteps(t), **model_kwargs)
model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
model_output = terms[gd_out_key]
if self.model_var_type in [

View File

@ -47,7 +47,7 @@ class AttentionPool2d(nn.Module):
b, c, *_spatial = x.shape
x = x.reshape(b, c, -1) # NC(HW)
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
x = x + self.positional_embedding[None, :, :x.shape[-1]].to(x.dtype) # NC(HW+1)
x = self.qkv_proj(x)
x = self.attention(x)
x = self.c_proj(x)
@ -98,7 +98,12 @@ class Upsample(nn.Module):
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
ksize = 3
pad = 1
if dims == 1:
ksize = 5
pad = 2
self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad)
def forward(self, x):
assert x.shape[1] == self.channels
@ -106,6 +111,8 @@ class Upsample(nn.Module):
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
elif self.dims == 1:
x = F.interpolate(x, scale_factor=4, mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
@ -129,10 +136,19 @@ class Downsample(nn.Module):
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
ksize = 3
pad = 1
if dims == 1:
stride = 4
ksize = 5
pad = 2
elif dims == 2:
stride = 2
else:
stride = (1,2,2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=1
dims, self.channels, self.out_channels, ksize, stride=stride, padding=pad
)
else:
assert self.channels == self.out_channels

View File

@ -3,9 +3,10 @@ import torch
from munch import munchify
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from torch.nn import functional as F, Flatten
from models.diffusion.unet_diffusion import UNetModel
from models.arch_util import ConvGnSilu
from models.diffusion.unet_diffusion import UNetModel, AttentionPool2d
from models.tacotron2.layers import ConvNorm, LinearNorm
from models.tacotron2.hparams import create_hparams
from models.tacotron2.tacotron2 import Prenet, Attention, Encoder
@ -14,6 +15,7 @@ from models.tacotron2.taco_utils import get_mask_from_lengths
from utils.util import opt_get, checkpoint
class WavDecoder(nn.Module):
def __init__(self, dec_channels, K_ms=40, sample_rate=8000, dropout_probability=.1):
super().__init__()
@ -24,13 +26,18 @@ class WavDecoder(nn.Module):
model_channels=dec_channels // 4, # This is a requirement to enable to load the embedding produced by the decoder into the unet model.
out_channels=2, # 2 channels: eps_pred and variance_pred
num_res_blocks=2,
attention_resolutions=(16,32),
attention_resolutions=(8,),
dims=1,
dropout=.1,
channel_mult=(1,1,1,2,4,8),
channel_mult=(1,2,4,8),
use_raw_y_as_embedding=True)
assert self.K % 64 == 0 # Otherwise the UNetModel breaks.
self.pre_rnn = Prenet(self.K, [dec_channels, dec_channels])
self.pre_rnn = nn.Sequential(ConvGnSilu(1,32,kernel_size=5,convnd=nn.Conv1d),
ConvGnSilu(32,64,kernel_size=5,stride=4,convnd=nn.Conv1d),
ConvGnSilu(64,128,kernel_size=5,stride=4,convnd=nn.Conv1d),
ConvGnSilu(128,256,kernel_size=5,stride=4,convnd=nn.Conv1d),
ConvGnSilu(256,dec_channels,kernel_size=1,convnd=nn.Conv1d),
AttentionPool2d(self.K//64,dec_channels,dec_channels//4))
self.attention_rnn = nn.LSTMCell(dec_channels*2, dec_channels)
self.attention_layer = Attention(dec_channels, dec_channels, dec_channels)
self.decoder_rnn = nn.LSTMCell(dec_channels*2, dec_channels, 1)
@ -47,7 +54,7 @@ class WavDecoder(nn.Module):
wavs = torch.stack(wavs, dim=1) # wavs.shape = (b,s,K) where s=decoder sequence length
return wavs, padding_needed
def prepare_decoder_inputs(self, inp):
# inp.shape = (b,s,K) chunked waveform.
b,s,K = inp.shape
@ -135,24 +142,25 @@ class WavDecoder(nn.Module):
return diffusion_eps[:,:,:-padding_added], gate_outputs[:,:-padding_added], alignments[:,:-padding_added]
def forward(self, wav, wav_corrected, timesteps, text_enc, memory_lengths):
def forward(self, wav_noised, wav_real, timesteps, text_enc, memory_lengths):
'''
Performs a training forward pass with the given data.
:param wav: (b,n) diffused waveform tensor on the interval [-1,1]
:param wav_corrected: (b,n) waveform tensor that has had one step of diffusion correction over <wav>
:param wav_noised: (b,n) diffused waveform tensor on the interval [-1,1]
:param wav_real: (b,n) actual waveform tensor
:param text_enc: (b,e) embedding post-encoder with e=self.dec_channels
'''
# Start by splitting up the provided waveforms into discrete segments.
wavs, padding_added = self.chunk_wav(wav)
wavs_corrected, _ = self.chunk_wav(wav_corrected)
wavs_corrected = self.prepare_decoder_inputs(wavs_corrected)
wavs_corrected = checkpoint(self.pre_rnn, wavs_corrected)
wav_noised, padding_added = self.chunk_wav(wav_noised)
wav_real, _ = self.chunk_wav(wav_real)
wav_real = self.prepare_decoder_inputs(wav_real)
b,s,K = wav_real.shape
wav_real = checkpoint(self.pre_rnn, wav_real.reshape(b*s,1,K)).reshape(b,s,self.dec_channels)
self.initialize_decoder_states(text_enc, mask=~get_mask_from_lengths(memory_lengths))
decoder_contexts, gate_outputs, alignments = [], [], []
while len(decoder_contexts) < wavs_corrected.size(1):
decoder_input = wavs_corrected[:, len(decoder_contexts)]
while len(decoder_contexts) < wav_real.size(1):
decoder_input = wav_real[:, len(decoder_contexts)]
dec_context, gate_output, attention_weights = self.produce_context(decoder_input)
decoder_contexts += [dec_context.squeeze(1)]
gate_outputs += [gate_output.squeeze(1)]
@ -163,8 +171,8 @@ class WavDecoder(nn.Module):
diffusion_emb = torch.stack(decoder_contexts, dim=1)
b,s,c = diffusion_emb.shape
diffusion_emb = diffusion_emb.reshape(b*s,c)
wavs = wavs.reshape(b*s,1,self.K)
diffusion_eps = self.clarifier(wavs, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K)
wav_noised = wav_noised.reshape(b*s,1,self.K)
diffusion_eps = self.clarifier(wav_noised, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K)
# Recombine diffusion outputs across the sequence into a single prediction.
diffusion_eps, gate_outputs, alignments = self.recombine(diffusion_eps, gate_outputs, alignments, padding_added)
return diffusion_eps, gate_outputs, alignments