More work on wave-diffusion
parent 49e3b310ea
commit 398185e109
@@ -395,11 +395,11 @@ class ConvGnLelu(nn.Module):
 ''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvGnSilu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, convnd=nn.Conv2d):
         super(ConvGnSilu, self).__init__()
         padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
         assert kernel_size in padding_map.keys()
-        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
+        self.conv = convnd(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
         if norm:
             self.gn = nn.GroupNorm(num_groups, filters_out)
         else:
@@ -411,7 +411,7 @@ class ConvGnSilu(nn.Module):

         # Init params.
         for m in self.modules():
-            if isinstance(m, nn.Conv2d):
+            if isinstance(m, convnd):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
                 m.weight.data *= weight_init_factor
                 if m.bias is not None:
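With the new convnd argument, ConvGnSilu can be instantiated for 1D (audio) tensors as well as 2D feature maps. A minimal usage sketch; the shapes and channel counts below are illustrative, not taken from the repository:

    import torch
    from torch import nn
    from models.arch_util import ConvGnSilu

    # kernel_size=5 maps to padding=2 via padding_map, so the time dimension is preserved.
    block = ConvGnSilu(1, 32, kernel_size=5, convnd=nn.Conv1d)
    y = block(torch.randn(4, 1, 1024))   # (batch, channels, time) -> (4, 32, 1024)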
@@ -835,14 +835,11 @@ class GaussianDiffusion:
         if noise is None:
             noise = th.randn_like(x_start)
         x_t = self.q_sample(x_start, t, noise=noise)
-        x_tn1 = self.q_sample(x_start, t-1, noise=noise)

         terms = {}

         if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
             assert False # not currently supported for this type of diffusion.
         elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
-            model_outputs = model(x_t, x_tn1, self._scale_timesteps(t), **model_kwargs)
+            model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
             terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
             model_output = terms[gd_out_key]
             if self.model_var_type in [
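For reference, x_t above is the standard DDPM forward sample; with x_tn1 removed, the wrapped model is conditioned on the clean x_start directly. A sketch of the computation q_sample performs (standard diffusion math, shown only to make the conditioning concrete; the helper name is illustrative):

    import torch as th

    def q_sample_sketch(x_start, sqrt_alphas_cumprod_t, sqrt_one_minus_alphas_cumprod_t, noise=None):
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
        noise = th.randn_like(x_start) if noise is None else noise
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise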
@@ -47,7 +47,7 @@ class AttentionPool2d(nn.Module):
         b, c, *_spatial = x.shape
         x = x.reshape(b, c, -1)  # NC(HW)
         x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
-        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
+        x = x + self.positional_embedding[None, :, :x.shape[-1]].to(x.dtype)  # NC(HW+1)
         x = self.qkv_proj(x)
         x = self.attention(x)
         x = self.c_proj(x)
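Slicing the positional embedding to x.shape[-1] lets inputs with fewer positions than the embedding was built for pass through instead of failing on the broadcast. A small sketch with illustrative shapes:

    import torch as th

    pe = th.randn(64, 17)              # embedding sized for 16 spatial positions + 1 mean token
    x = th.randn(2, 64, 6)             # this call only has 5 positions + 1 mean token
    x = x + pe[None, :, :x.shape[-1]]  # truncated to (1, 64, 6); broadcasts cleanly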
@@ -98,7 +98,12 @@ class Upsample(nn.Module):
         self.use_conv = use_conv
         self.dims = dims
         if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            ksize = 3
+            pad = 1
+            if dims == 1:
+                ksize = 5
+                pad = 2
+            self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad)

     def forward(self, x):
         assert x.shape[1] == self.channels
@@ -106,6 +111,8 @@ class Upsample(nn.Module):
             x = F.interpolate(
                 x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
             )
+        elif self.dims == 1:
+            x = F.interpolate(x, scale_factor=4, mode="nearest")
         else:
             x = F.interpolate(x, scale_factor=2, mode="nearest")
         if self.use_conv:
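In the 1D branch, Upsample now scales the time axis by 4x (instead of 2x) and, when use_conv is set, follows it with a wider kernel. A rough shape check, assuming conv_nd(1, ...) builds an nn.Conv1d as in this codebase:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 16, 250)                            # (batch, channels, time)
    up = F.interpolate(x, scale_factor=4, mode="nearest")  # -> (2, 16, 1000)
    conv = torch.nn.Conv1d(16, 16, kernel_size=5, padding=2)
    y = conv(up)                                           # kernel 5 / pad 2 keeps the length: (2, 16, 1000)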
@@ -129,10 +136,19 @@ class Downsample(nn.Module):
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
-        stride = 2 if dims != 3 else (1, 2, 2)
+        ksize = 3
+        pad = 1
+        if dims == 1:
+            stride = 4
+            ksize = 5
+            pad = 2
+        elif dims == 2:
+            stride = 2
+        else:
+            stride = (1,2,2)
         if use_conv:
             self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
+                dims, self.channels, self.out_channels, ksize, stride=stride, padding=pad
             )
         else:
             assert self.channels == self.out_channels
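Downsample mirrors this: for dims == 1 the time axis is reduced by 4x with a stride-4, kernel-5, pad-2 convolution. A shape check under the same assumption about conv_nd:

    import torch

    op = torch.nn.Conv1d(16, 16, kernel_size=5, stride=4, padding=2)
    x = torch.randn(2, 16, 1000)
    y = op(x)   # floor((1000 + 2*2 - 5) / 4) + 1 = 250 -> (2, 16, 250), a clean 4x reduction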
@@ -3,9 +3,10 @@ import torch
 from munch import munchify
 from torch.autograd import Variable
 from torch import nn
-from torch.nn import functional as F
+from torch.nn import functional as F, Flatten

-from models.diffusion.unet_diffusion import UNetModel
+from models.arch_util import ConvGnSilu
+from models.diffusion.unet_diffusion import UNetModel, AttentionPool2d
 from models.tacotron2.layers import ConvNorm, LinearNorm
 from models.tacotron2.hparams import create_hparams
 from models.tacotron2.tacotron2 import Prenet, Attention, Encoder
@@ -14,6 +15,7 @@ from models.tacotron2.taco_utils import get_mask_from_lengths
 from utils.util import opt_get, checkpoint


+
 class WavDecoder(nn.Module):
     def __init__(self, dec_channels, K_ms=40, sample_rate=8000, dropout_probability=.1):
         super().__init__()
@@ -24,13 +26,18 @@ class WavDecoder(nn.Module):
                                    model_channels=dec_channels // 4, # This is a requirement to enable to load the embedding produced by the decoder into the unet model.
                                    out_channels=2, # 2 channels: eps_pred and variance_pred
                                    num_res_blocks=2,
-                                   attention_resolutions=(16,32),
+                                   attention_resolutions=(8,),
                                    dims=1,
                                    dropout=.1,
-                                   channel_mult=(1,1,1,2,4,8),
+                                   channel_mult=(1,2,4,8),
                                    use_raw_y_as_embedding=True)
         assert self.K % 64 == 0 # Otherwise the UNetModel breaks.
-        self.pre_rnn = Prenet(self.K, [dec_channels, dec_channels])
+        self.pre_rnn = nn.Sequential(ConvGnSilu(1,32,kernel_size=5,convnd=nn.Conv1d),
+                                     ConvGnSilu(32,64,kernel_size=5,stride=4,convnd=nn.Conv1d),
+                                     ConvGnSilu(64,128,kernel_size=5,stride=4,convnd=nn.Conv1d),
+                                     ConvGnSilu(128,256,kernel_size=5,stride=4,convnd=nn.Conv1d),
+                                     ConvGnSilu(256,dec_channels,kernel_size=1,convnd=nn.Conv1d),
+                                     AttentionPool2d(self.K//64,dec_channels,dec_channels//4))
         self.attention_rnn = nn.LSTMCell(dec_channels*2, dec_channels)
         self.attention_layer = Attention(dec_channels, dec_channels, dec_channels)
         self.decoder_rnn = nn.LSTMCell(dec_channels*2, dec_channels, 1)
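The new pre_rnn embeds each K-sample waveform chunk independently: three stride-4 ConvGnSilu layers reduce K positions to K/64 (hence the % 64 assert), and AttentionPool2d collapses those positions into a single dec_channels vector. A standalone sketch with illustrative sizes (with the default K_ms=40 at 8 kHz this works out to K = 320; dec_channels=512 is just an example):

    import torch
    from torch import nn
    from models.arch_util import ConvGnSilu
    from models.diffusion.unet_diffusion import AttentionPool2d

    K, dec_channels = 320, 512
    pre_rnn = nn.Sequential(ConvGnSilu(1, 32, kernel_size=5, convnd=nn.Conv1d),
                            ConvGnSilu(32, 64, kernel_size=5, stride=4, convnd=nn.Conv1d),
                            ConvGnSilu(64, 128, kernel_size=5, stride=4, convnd=nn.Conv1d),
                            ConvGnSilu(128, 256, kernel_size=5, stride=4, convnd=nn.Conv1d),
                            ConvGnSilu(256, dec_channels, kernel_size=1, convnd=nn.Conv1d),
                            AttentionPool2d(K // 64, dec_channels, dec_channels // 4))
    chunk = torch.randn(8, 1, K)    # (b*s, 1, K) waveform chunks
    emb = pre_rnn(chunk)            # -> (8, dec_channels), one embedding per chunk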
@@ -135,24 +142,25 @@ class WavDecoder(nn.Module):

         return diffusion_eps[:,:,:-padding_added], gate_outputs[:,:-padding_added], alignments[:,:-padding_added]

-    def forward(self, wav, wav_corrected, timesteps, text_enc, memory_lengths):
+    def forward(self, wav_noised, wav_real, timesteps, text_enc, memory_lengths):
         '''
         Performs a training forward pass with the given data.
-        :param wav: (b,n) diffused waveform tensor on the interval [-1,1]
-        :param wav_corrected: (b,n) waveform tensor that has had one step of diffusion correction over <wav>
+        :param wav_noised: (b,n) diffused waveform tensor on the interval [-1,1]
+        :param wav_real: (b,n) actual waveform tensor
         :param text_enc: (b,e) embedding post-encoder with e=self.dec_channels
         '''

         # Start by splitting up the provided waveforms into discrete segments.
-        wavs, padding_added = self.chunk_wav(wav)
-        wavs_corrected, _ = self.chunk_wav(wav_corrected)
-        wavs_corrected = self.prepare_decoder_inputs(wavs_corrected)
-        wavs_corrected = checkpoint(self.pre_rnn, wavs_corrected)
+        wav_noised, padding_added = self.chunk_wav(wav_noised)
+        wav_real, _ = self.chunk_wav(wav_real)
+        wav_real = self.prepare_decoder_inputs(wav_real)
+        b,s,K = wav_real.shape
+        wav_real = checkpoint(self.pre_rnn, wav_real.reshape(b*s,1,K)).reshape(b,s,self.dec_channels)

         self.initialize_decoder_states(text_enc, mask=~get_mask_from_lengths(memory_lengths))
         decoder_contexts, gate_outputs, alignments = [], [], []
-        while len(decoder_contexts) < wavs_corrected.size(1):
-            decoder_input = wavs_corrected[:, len(decoder_contexts)]
+        while len(decoder_contexts) < wav_real.size(1):
+            decoder_input = wav_real[:, len(decoder_contexts)]
             dec_context, gate_output, attention_weights = self.produce_context(decoder_input)
             decoder_contexts += [dec_context.squeeze(1)]
             gate_outputs += [gate_output.squeeze(1)]
@@ -163,8 +171,8 @@ class WavDecoder(nn.Module):
         diffusion_emb = torch.stack(decoder_contexts, dim=1)
         b,s,c = diffusion_emb.shape
         diffusion_emb = diffusion_emb.reshape(b*s,c)
-        wavs = wavs.reshape(b*s,1,self.K)
-        diffusion_eps = self.clarifier(wavs, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K)
+        wav_noised = wav_noised.reshape(b*s,1,self.K)
+        diffusion_eps = self.clarifier(wav_noised, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K)
         # Recombine diffusion outputs across the sequence into a single prediction.
         diffusion_eps, gate_outputs, alignments = self.recombine(diffusion_eps, gate_outputs, alignments, padding_added)
         return diffusion_eps, gate_outputs, alignments
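For orientation, a sketch of the tensor flow through the updated forward pass, using the same illustrative sizes as above (the clarifier call is left commented because it needs the full module):

    import torch

    b, s, K, dec_channels = 2, 4, 320, 512
    wav_noised = torch.randn(b, s, K)                # chunked, noised waveform from chunk_wav
    diffusion_emb = torch.randn(b, s, dec_channels)  # stacked per-chunk decoder contexts
    flat_wav = wav_noised.reshape(b * s, 1, K)
    flat_emb = diffusion_emb.reshape(b * s, dec_channels)
    # diffusion_eps = self.clarifier(flat_wav, timesteps.repeat(s), flat_emb).reshape(b, s, 2, K)
    # recombine() then stitches the s chunks back together and strips the padding added by chunk_wav.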