More work on wave-diffusion
parent 49e3b310ea
commit 398185e109
@@ -395,11 +395,11 @@ class ConvGnLelu(nn.Module):
 ''' Convenience class with Conv->GN->SiLU. Includes weight initialization and auto-padding for standard
 kernel sizes. '''
 class ConvGnSilu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, convnd=nn.Conv2d):
         super(ConvGnSilu, self).__init__()
         padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
         assert kernel_size in padding_map.keys()
-        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
+        self.conv = convnd(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
         if norm:
             self.gn = nn.GroupNorm(num_groups, filters_out)
         else:
@@ -411,7 +411,7 @@ class ConvGnSilu(nn.Module):

         # Init params.
         for m in self.modules():
-            if isinstance(m, nn.Conv2d):
+            if isinstance(m, convnd):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
                 m.weight.data *= weight_init_factor
                 if m.bias is not None:
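The new convnd argument generalizes ConvGnSilu from 2-D images to 1-D waveforms while keeping nn.Conv2d as the default. A minimal usage sketch, assuming the class's forward applies conv -> GroupNorm -> SiLU as its name suggests (shapes are illustrative, not from this commit):

    import torch
    import torch.nn as nn
    from models.arch_util import ConvGnSilu

    # 2-D usage: identical to the old behavior, convnd defaults to nn.Conv2d.
    block2d = ConvGnSilu(3, 64, kernel_size=3)
    print(block2d(torch.randn(1, 3, 32, 32)).shape)   # -> (1, 64, 32, 32)

    # 1-D usage enabled by this commit; GroupNorm and SiLU both accept (N, C, L).
    block1d = ConvGnSilu(1, 32, kernel_size=5, convnd=nn.Conv1d)
    print(block1d(torch.randn(1, 1, 320)).shape)      # -> (1, 32, 320)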
@@ -835,14 +835,11 @@ class GaussianDiffusion:
         if noise is None:
             noise = th.randn_like(x_start)
         x_t = self.q_sample(x_start, t, noise=noise)
-        x_tn1 = self.q_sample(x_start, t-1, noise=noise)

         terms = {}

         if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
             assert False  # not currently supported for this type of diffusion.
         elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
-            model_outputs = model(x_t, x_tn1, self._scale_timesteps(t), **model_kwargs)
+            model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
             terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
             model_output = terms[gd_out_key]
             if self.model_var_type in [
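For reference, q_sample draws x_t from the forward process q(x_t | x_0); with this change the model is conditioned on the clean x_start rather than a re-noised x_{t-1} sample. A minimal sketch of the standard DDPM-style q_sample (variable names assumed from typical gaussian_diffusion implementations, not copied from this repo):

    import torch as th

    def q_sample(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None):
        # x_t ~ q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I)
        if noise is None:
            noise = th.randn_like(x_start)
        shape = (-1,) + (1,) * (x_start.dim() - 1)  # broadcast per-timestep scalars
        a = sqrt_alphas_cumprod[t].view(shape)
        b = sqrt_one_minus_alphas_cumprod[t].view(shape)
        return a * x_start + b * noise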
@@ -47,7 +47,7 @@ class AttentionPool2d(nn.Module):
         b, c, *_spatial = x.shape
         x = x.reshape(b, c, -1)  # NC(HW)
         x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
-        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
+        x = x + self.positional_embedding[None, :, :x.shape[-1]].to(x.dtype)  # NC(HW+1)
         x = self.qkv_proj(x)
         x = self.attention(x)
         x = self.c_proj(x)
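Slicing the positional embedding to :x.shape[-1] lets AttentionPool2d accept fewer tokens than the spacial_dim**2 + 1 it was initialized for, which the 1-D pre_rnn below relies on. A toy illustration of the broadcast (sizes are hypothetical):

    import torch as th

    pos_emb = th.randn(128, 17)              # (embed_dim, spacial_dim**2 + 1) as built at init
    x = th.randn(4, 128, 6)                  # only 6 tokens: 5 positions + the mean token
    x = x + pos_emb[None, :, :x.shape[-1]]   # slice instead of assuming the full HW+1
    print(x.shape)                           # -> (4, 128, 6)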
@@ -98,7 +98,12 @@ class Upsample(nn.Module):
         self.use_conv = use_conv
         self.dims = dims
         if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            ksize = 3
+            pad = 1
+            if dims == 1:
+                ksize = 5
+                pad = 2
+            self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad)

     def forward(self, x):
         assert x.shape[1] == self.channels
@@ -106,6 +111,8 @@ class Upsample(nn.Module):
             x = F.interpolate(
                 x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
             )
+        elif self.dims == 1:
+            x = F.interpolate(x, scale_factor=4, mode="nearest")
         else:
             x = F.interpolate(x, scale_factor=2, mode="nearest")
         if self.use_conv:
@@ -129,10 +136,19 @@ class Downsample(nn.Module):
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
-        stride = 2 if dims != 3 else (1, 2, 2)
+        ksize = 3
+        pad = 1
+        if dims == 1:
+            stride = 4
+            ksize = 5
+            pad = 2
+        elif dims == 2:
+            stride = 2
+        else:
+            stride = (1,2,2)
         if use_conv:
             self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
+                dims, self.channels, self.out_channels, ksize, stride=stride, padding=pad
             )
         else:
             assert self.channels == self.out_channels
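With these changes the 1-D paths of Downsample and Upsample are length-inverses: stride-4 convolution (kernel 5, padding 2) on the way down, nearest-neighbor interpolation by 4 on the way up. A quick shape check using plain torch ops to mirror what conv_nd builds for dims=1:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 16, 256)                            # (b, c, length)
    down = torch.nn.Conv1d(16, 16, kernel_size=5, stride=4, padding=2)
    y = down(x)
    print(y.shape)                                         # -> (1, 16, 64): length / 4
    z = F.interpolate(y, scale_factor=4, mode="nearest")
    print(z.shape)                                         # -> (1, 16, 256): restored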
@@ -3,9 +3,10 @@ import torch
 from munch import munchify
 from torch.autograd import Variable
 from torch import nn
-from torch.nn import functional as F
+from torch.nn import functional as F, Flatten

-from models.diffusion.unet_diffusion import UNetModel
+from models.arch_util import ConvGnSilu
+from models.diffusion.unet_diffusion import UNetModel, AttentionPool2d
 from models.tacotron2.layers import ConvNorm, LinearNorm
 from models.tacotron2.hparams import create_hparams
 from models.tacotron2.tacotron2 import Prenet, Attention, Encoder
@@ -14,6 +15,7 @@ from models.tacotron2.taco_utils import get_mask_from_lengths
 from utils.util import opt_get, checkpoint


+
 class WavDecoder(nn.Module):
     def __init__(self, dec_channels, K_ms=40, sample_rate=8000, dropout_probability=.1):
         super().__init__()
@@ -24,13 +26,18 @@ class WavDecoder(nn.Module):
             model_channels=dec_channels // 4,  # Required so the embedding produced by the decoder can be loaded into the unet model.
             out_channels=2,  # 2 channels: eps_pred and variance_pred
             num_res_blocks=2,
-            attention_resolutions=(16,32),
+            attention_resolutions=(8,),
             dims=1,
             dropout=.1,
-            channel_mult=(1,1,1,2,4,8),
+            channel_mult=(1,2,4,8),
             use_raw_y_as_embedding=True)
         assert self.K % 64 == 0  # Otherwise the UNetModel breaks.
-        self.pre_rnn = Prenet(self.K, [dec_channels, dec_channels])
+        self.pre_rnn = nn.Sequential(ConvGnSilu(1,32,kernel_size=5,convnd=nn.Conv1d),
+                                     ConvGnSilu(32,64,kernel_size=5,stride=4,convnd=nn.Conv1d),
+                                     ConvGnSilu(64,128,kernel_size=5,stride=4,convnd=nn.Conv1d),
+                                     ConvGnSilu(128,256,kernel_size=5,stride=4,convnd=nn.Conv1d),
+                                     ConvGnSilu(256,dec_channels,kernel_size=1,convnd=nn.Conv1d),
+                                     AttentionPool2d(self.K//64,dec_channels,dec_channels//4))
         self.attention_rnn = nn.LSTMCell(dec_channels*2, dec_channels)
         self.attention_layer = Attention(dec_channels, dec_channels, dec_channels)
         self.decoder_rnn = nn.LSTMCell(dec_channels*2, dec_channels, 1)
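The replacement pre_rnn is a 1-D conv stack whose three stride-4 layers shrink each K-sample chunk by 4*4*4 = 64 (hence the self.K % 64 == 0 assert), after which AttentionPool2d collapses the remaining K//64 positions into one dec_channels-dim embedding per chunk; the positional-embedding slicing above is what makes that pooling legal. A shape walkthrough with hypothetical sizes, assuming ConvGnSilu's forward is conv -> GroupNorm -> SiLU:

    import torch
    from torch import nn
    from models.arch_util import ConvGnSilu
    from models.diffusion.unet_diffusion import AttentionPool2d

    K, dec_channels = 320, 512                    # K must be divisible by 64
    pre_rnn = nn.Sequential(
        ConvGnSilu(1, 32, kernel_size=5, convnd=nn.Conv1d),               # (N, 32, K)
        ConvGnSilu(32, 64, kernel_size=5, stride=4, convnd=nn.Conv1d),    # (N, 64, K/4)
        ConvGnSilu(64, 128, kernel_size=5, stride=4, convnd=nn.Conv1d),   # (N, 128, K/16)
        ConvGnSilu(128, 256, kernel_size=5, stride=4, convnd=nn.Conv1d),  # (N, 256, K/64)
        ConvGnSilu(256, dec_channels, kernel_size=1, convnd=nn.Conv1d),
        AttentionPool2d(K // 64, dec_channels, dec_channels // 4))
    out = pre_rnn(torch.randn(4, 1, K))           # (b*s, 1, K) chunked waveform
    print(out.shape)                              # -> (4, 512): one embedding per chunk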
@@ -47,7 +54,7 @@ class WavDecoder(nn.Module):
         wavs = torch.stack(wavs, dim=1)  # wavs.shape = (b,s,K) where s=decoder sequence length
         return wavs, padding_needed


     def prepare_decoder_inputs(self, inp):
         # inp.shape = (b,s,K) chunked waveform.
         b,s,K = inp.shape
@@ -135,24 +142,25 @@ class WavDecoder(nn.Module):
         return diffusion_eps[:,:,:-padding_added], gate_outputs[:,:-padding_added], alignments[:,:-padding_added]

-    def forward(self, wav, wav_corrected, timesteps, text_enc, memory_lengths):
+    def forward(self, wav_noised, wav_real, timesteps, text_enc, memory_lengths):
         '''
         Performs a training forward pass with the given data.
-        :param wav: (b,n) diffused waveform tensor on the interval [-1,1]
-        :param wav_corrected: (b,n) waveform tensor that has had one step of diffusion correction over <wav>
+        :param wav_noised: (b,n) diffused waveform tensor on the interval [-1,1]
+        :param wav_real: (b,n) actual waveform tensor
         :param text_enc: (b,e) embedding post-encoder with e=self.dec_channels
         '''

         # Start by splitting up the provided waveforms into discrete segments.
-        wavs, padding_added = self.chunk_wav(wav)
-        wavs_corrected, _ = self.chunk_wav(wav_corrected)
-        wavs_corrected = self.prepare_decoder_inputs(wavs_corrected)
-        wavs_corrected = checkpoint(self.pre_rnn, wavs_corrected)
+        wav_noised, padding_added = self.chunk_wav(wav_noised)
+        wav_real, _ = self.chunk_wav(wav_real)
+        wav_real = self.prepare_decoder_inputs(wav_real)
+        b,s,K = wav_real.shape
+        wav_real = checkpoint(self.pre_rnn, wav_real.reshape(b*s,1,K)).reshape(b,s,self.dec_channels)

         self.initialize_decoder_states(text_enc, mask=~get_mask_from_lengths(memory_lengths))
         decoder_contexts, gate_outputs, alignments = [], [], []
-        while len(decoder_contexts) < wavs_corrected.size(1):
-            decoder_input = wavs_corrected[:, len(decoder_contexts)]
+        while len(decoder_contexts) < wav_real.size(1):
+            decoder_input = wav_real[:, len(decoder_contexts)]
             dec_context, gate_output, attention_weights = self.produce_context(decoder_input)
             decoder_contexts += [dec_context.squeeze(1)]
             gate_outputs += [gate_output.squeeze(1)]
@@ -163,8 +171,8 @@ class WavDecoder(nn.Module):
         diffusion_emb = torch.stack(decoder_contexts, dim=1)
         b,s,c = diffusion_emb.shape
         diffusion_emb = diffusion_emb.reshape(b*s,c)
-        wavs = wavs.reshape(b*s,1,self.K)
-        diffusion_eps = self.clarifier(wavs, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K)
+        wav_noised = wav_noised.reshape(b*s,1,self.K)
+        diffusion_eps = self.clarifier(wav_noised, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K)
         # Recombine diffusion outputs across the sequence into a single prediction.
         diffusion_eps, gate_outputs, alignments = self.recombine(diffusion_eps, gate_outputs, alignments, padding_added)
         return diffusion_eps, gate_outputs, alignments
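One subtlety worth noting: reshape(b*s, 1, self.K) flattens batch-major (all s chunks of sample 0 first), while timesteps.repeat(s) tiles the (b,)-shaped timestep vector sequence-major, so the two orderings only coincide when b == 1. If each chunk is meant to carry its own sample's timestep, repeat_interleave produces the matching order — a toy comparison (values are illustrative, not from this commit):

    import torch

    b, s = 2, 3
    timesteps = torch.tensor([10, 20])         # one timestep per batch element
    print(timesteps.repeat(s))                 # tensor([10, 20, 10, 20, 10, 20])
    print(timesteps.repeat_interleave(s))      # tensor([10, 10, 10, 20, 20, 20])
    # After reshape(b*s, ...), flattened index i belongs to sample i // s,
    # which lines up with the repeat_interleave ordering.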