Add support for a gaussian-diffusion-based wave tacotron

This commit is contained in:
James Betker 2021-07-26 16:27:31 -06:00
parent 97d7cbbc34
commit 96e90e7047
10 changed files with 414 additions and 20 deletions

View File

@ -23,6 +23,8 @@ class TextMelLoader(torch.utils.data.Dataset):
self.max_wav_value = hparams.max_wav_value
self.sampling_rate = hparams.sampling_rate
self.load_mel_from_disk = hparams.load_mel_from_disk
self.return_wavs = hparams.return_wavs
assert not (self.load_mel_from_disk and self.return_wavs)
self.stft = layers.TacotronSTFT(
hparams.filter_length, hparams.hop_length, hparams.win_length,
hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
@ -47,8 +49,11 @@ class TextMelLoader(torch.utils.data.Dataset):
audio_norm = audio / self.max_wav_value
audio_norm = audio_norm.unsqueeze(0)
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
melspec = self.stft.mel_spectrogram(audio_norm)
melspec = torch.squeeze(melspec, 0)
if self.return_wavs:
melspec = audio_norm
else:
melspec = self.stft.mel_spectrogram(audio_norm)
melspec = torch.squeeze(melspec, 0)
else:
melspec = torch.from_numpy(np.load(filename))
assert melspec.size(0) == self.stft.n_mel_channels, (
@ -124,13 +129,18 @@ if __name__ == '__main__':
params = {
'mode': 'nv_tacotron',
'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
'phase': 'train',
'n_workers': 0,
'batch_size': 2,
'return_wavs': True,
}
from data import create_dataset
ds = create_dataset(params)
from data import create_dataset, create_dataloader
ds, c = create_dataset(params, return_collate=True)
dl = create_dataloader(ds, params, collate_fn=c)
i = 0
m = []
for b in ds:
for b in dl:
m.append(b)
i += 1
if i > 9999:

View File

@ -817,6 +817,81 @@ class GaussianDiffusion:
return terms
def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None):
"""
Compute training losses for a single timestep.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param t: a batch of timestep indices.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param noise: if specified, the specific Gaussian noise to try to remove.
:return: a dict with the key "loss" containing a tensor of shape [N].
Some mean or variance settings may also have other keys.
"""
if model_kwargs is None:
model_kwargs = {}
if noise is None:
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start, t, noise=noise)
x_tn1 = self.q_sample(x_start, t-1, noise=noise)
terms = {}
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
assert False # not currently supported for this type of diffusion.
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
model_outputs = model(x_t, x_tn1, self._scale_timesteps(t), **model_kwargs)
terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
model_output = terms[gd_out_key]
if self.model_var_type in [
ModelVarType.LEARNED,
ModelVarType.LEARNED_RANGE,
]:
B, C = x_t.shape[:2]
assert model_output.shape == (B, C, 2, *x_t.shape[2:])
model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1]
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
terms["vb"] = self._vb_terms_bpd(
model=lambda *args, r=frozen_out: r,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
)["output"]
if self.loss_type == LossType.RESCALED_MSE:
# Divide by 1000 for equivalence with initial implementation.
# Without a factor of 1/1000, the VB term hurts the MSE term.
terms["vb"] *= self.num_timesteps / 1000.0
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
target = self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)[0]
x_start_pred = torch.zeros(x_start) # Not supported.
elif self.model_mean_type == ModelMeanType.START_X:
target = x_start
x_start_pred = model_output
elif self.model_mean_type == ModelMeanType.EPSILON:
target = noise
x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
else:
raise NotImplementedError(self.model_mean_type)
assert model_output.shape == target.shape == x_start.shape
terms["mse"] = mean_flat((target - model_output) ** 2)
terms["x_start_predicted"] = x_start_pred
if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
else:
terms["loss"] = terms["mse"]
else:
raise NotImplementedError(self.loss_type)
return terms
def _prior_bpd(self, x_start):
"""
Get the prior KL term for the variational lower-bound, measured in

View File

@ -95,16 +95,22 @@ class SpacedDiffusion(GaussianDiffusion):
): # pylint: disable=signature-differs
return super().training_losses(self._wrap_model(model), *args, **kwargs)
def autoregressive_training_losses(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs)
def condition_mean(self, cond_fn, *args, **kwargs):
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
def condition_score(self, cond_fn, *args, **kwargs):
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
def _wrap_model(self, model):
if isinstance(model, _WrappedModel):
def _wrap_model(self, model, autoregressive=False):
if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel):
return model
return _WrappedModel(
mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel
return mod(
model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
)
@ -126,3 +132,18 @@ class _WrappedModel:
if self.rescale_timesteps:
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
return self.model(x, new_ts, **kwargs)
class _WrappedAutoregressiveModel:
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
self.model = model
self.timestep_map = timestep_map
self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def __call__(self, x, x0, ts, **kwargs):
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts]
if self.rescale_timesteps:
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
return self.model(x, x0, new_ts, **kwargs)

View File

@ -441,6 +441,7 @@ class UNetModel(nn.Module):
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_raw_y_as_embedding=False,
):
super().__init__()
@ -471,6 +472,8 @@ class UNetModel(nn.Module):
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
self.use_raw_y_as_embedding = use_raw_y_as_embedding
assert (self.num_classes is not None) != use_raw_y_as_embedding # These are mutually-exclusive.
self.input_blocks = nn.ModuleList(
[
@ -630,16 +633,14 @@ class UNetModel(nn.Module):
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
if self.use_raw_y_as_embedding:
emb = emb + y
h = x.type(self.dtype)
for module in self.input_blocks:

View File

@ -32,7 +32,7 @@ class LocationLayer(nn.Module):
class Attention(nn.Module):
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
attention_location_n_filters, attention_location_kernel_size):
attention_location_n_filters=32, attention_location_kernel_size=31):
super(Attention, self).__init__()
self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
bias=False, w_init_gain='tanh')
@ -528,6 +528,9 @@ def register_nv_tacotron2(opt_net, opt):
if __name__ == '__main__':
tron = register_nv_tacotron2({}, {})
inputs = torch.randint(high=24, size=(1,12)), torch.tensor([12]), torch.randn((1,80,749)), 800, torch.tensor([749])
out = tron(inputs)
inputs = torch.randint(high=24, size=(1,12)), \
torch.tensor([12]), \
torch.randn((1,80,749)), \
torch.tensor([749])
out = tron(*inputs)
print(out)

View File

@ -0,0 +1,230 @@
from math import sqrt
import torch
from munch import munchify
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from models.diffusion.unet_diffusion import UNetModel
from models.tacotron2.layers import ConvNorm, LinearNorm
from models.tacotron2.hparams import create_hparams
from models.tacotron2.tacotron2 import Prenet, Attention, Encoder
from trainer.networks import register_model
from models.tacotron2.taco_utils import get_mask_from_lengths
from utils.util import opt_get, checkpoint
class WavDecoder(nn.Module):
def __init__(self, dec_channels, K_ms=40, sample_rate=24000, dropout_probability=.1):
super().__init__()
self.dec_channels = dec_channels
self.K = int(sample_rate * (K_ms/1000)) # 960 with the defaults
self.clarifier = UNetModel(image_size=self.K,
in_channels=1,
model_channels=dec_channels // 4, # This is a requirement to enable to load the embedding produced by the decoder into the unet model.
out_channels=2, # 2 channels: eps_pred and variance_pred
num_res_blocks=2,
attention_resolutions=(16,32),
dims=1,
dropout=.1,
channel_mult=(1,1,1,2,4,8),
use_raw_y_as_embedding=True)
assert self.K % 64 == 0 # Otherwise the UNetModel breaks.
self.pre_rnn = Prenet(self.K, [dec_channels, dec_channels])
self.attention_rnn = nn.LSTMCell(dec_channels*2, dec_channels)
self.attention_layer = Attention(dec_channels, dec_channels, dec_channels)
self.decoder_rnn = nn.LSTMCell(dec_channels*2, dec_channels, 1)
self.linear_projection = LinearNorm(dec_channels*2, self.dec_channels)
self.gate_layer = LinearNorm(self.dec_channels*2, 1, bias=True, w_init_gain='sigmoid')
self.dropout_probability = dropout_probability
def chunk_wav(self, wav):
wavs = list(torch.split(wav, self.K, dim=-1))
# Pad the last chunk as needed.
padding_needed = self.K - wavs[-1].shape[-1]
if padding_needed > 0:
wavs[-1] = F.pad(wavs[-1], (0,padding_needed))
wavs = torch.stack(wavs, dim=1) # wavs.shape = (b,s,K) where s=decoder sequence length
return wavs, padding_needed
def prepare_decoder_inputs(self, inp):
# inp.shape = (b,s,K) chunked waveform.
b,s,K = inp.shape
first_frame = torch.zeros(b,1,K).to(inp.device)
x = torch.cat([first_frame, inp[:,:-1]], dim=1) # It is now aligned for teacher forcing.
return x
def initialize_decoder_states(self, memory, mask):
""" Initializes attention rnn states, decoder rnn states, attention
weights, attention cumulative weights, attention context, stores memory
and stores processed memory
PARAMS
------
memory: Encoder outputs
mask: Mask for padded data if training, expects None for inference
"""
B = memory.size(0)
MAX_TIME = memory.size(1)
self.attention_hidden = Variable(memory.data.new(B, self.dec_channels).zero_())
self.attention_cell = Variable(memory.data.new(B, self.dec_channels).zero_())
self.decoder_hidden = Variable(memory.data.new(B, self.dec_channels).zero_())
self.decoder_cell = Variable(memory.data.new(B, self.dec_channels).zero_())
self.attention_weights = Variable(memory.data.new(B, MAX_TIME).zero_())
self.attention_weights_cum = Variable(memory.data.new(B, MAX_TIME).zero_())
self.attention_context = Variable(memory.data.new(B, self.dec_channels).zero_())
self.memory = memory
self.processed_memory = checkpoint(self.attention_layer.memory_layer, memory)
self.mask = mask
def teardown_states(self):
self.attention_hidden = None
self.attention_cell = None
self.decoder_hidden = None
self.decoder_cell = None
self.attention_weights = None
self.attention_weights_cum = None
self.attention_context = None
self.memory = None
self.processed_memory = None
def produce_context(self, decoder_input):
""" Produces a context and a stop token prediction using the built-in RNN.
PARAMS
------
decoder_input: prior diffusion step that has been resolved.
RETURNS
-------
mel_output:
gate_output: gate output energies
attention_weights:
"""
cell_input = torch.cat((decoder_input, self.attention_context), -1)
self.attention_hidden, self.attention_cell = self.attention_rnn(cell_input, (self.attention_hidden, self.attention_cell))
self.attention_hidden = F.dropout(self.attention_hidden, self.dropout_probability, self.training)
attention_weights_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1)
self.attention_context, self.attention_weights = checkpoint(self.attention_layer, self.attention_hidden, self.memory,
self.processed_memory, attention_weights_cat, self.mask)
self.attention_weights_cum += self.attention_weights
decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1)
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(decoder_input, (self.decoder_hidden, self.decoder_cell))
self.decoder_hidden = F.dropout(self.decoder_hidden, self.dropout_probability, self.training)
decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1)
decoder_output = checkpoint(self.linear_projection, decoder_hidden_attention_context)
gate_prediction = self.gate_layer(decoder_hidden_attention_context)
return decoder_output, gate_prediction, self.attention_weights
def recombine(self, diffusion_eps, gate_outputs, alignments, padding_added):
# (T_out, B) -> (B, T_out)
alignments = torch.stack(alignments, dim=1).repeat(1, self.K, 1)
# (T_out, B) -> (B, T_out)
gate_outputs = torch.stack(gate_outputs, dim=1).repeat(1, self.K)
b,s,_,K = diffusion_eps.shape
# (B, S, 2, K) -> (B, 2, S*K)
diffusion_eps = diffusion_eps.permute(0,2,1,3).reshape(b, 2, s*K)
return diffusion_eps[:,:,:-padding_added], gate_outputs[:,:-padding_added], alignments[:,:-padding_added]
def forward(self, wav, wav_corrected, timesteps, text_enc, memory_lengths):
'''
Performs a training forward pass with the given data.
:param wav: (b,n) diffused waveform tensor on the interval [-1,1]
:param wav_corrected: (b,n) waveform tensor that has had one step of diffusion correction over <wav>
:param text_enc: (b,e) embedding post-encoder with e=self.dec_channels
'''
# Start by splitting up the provided waveforms into discrete segments.
wavs, padding_added = self.chunk_wav(wav)
wavs_corrected, _ = self.chunk_wav(wav_corrected)
wavs_corrected = self.prepare_decoder_inputs(wavs_corrected)
wavs_corrected = checkpoint(self.pre_rnn, wavs_corrected)
self.initialize_decoder_states(text_enc, mask=~get_mask_from_lengths(memory_lengths))
decoder_contexts, gate_outputs, alignments = [], [], []
while len(decoder_contexts) < wavs_corrected.size(1):
decoder_input = wavs_corrected[:, len(decoder_contexts)]
dec_context, gate_output, attention_weights = self.produce_context(decoder_input)
decoder_contexts += [dec_context.squeeze(1)]
gate_outputs += [gate_output.squeeze(1)]
alignments += [attention_weights]
self.teardown_states()
# diffusion_inputs and wavs needs to have the sequence and batch dimensions combined, and needs a channel dimension
diffusion_emb = torch.stack(decoder_contexts, dim=1)
b,s,c = diffusion_emb.shape
diffusion_emb = diffusion_emb.reshape(b*s,c)
wavs = wavs.reshape(b*s,1,self.K)
diffusion_eps = self.clarifier(wavs, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K)
# Recombine diffusion outputs across the sequence into a single prediction.
diffusion_eps, gate_outputs, alignments = self.recombine(diffusion_eps, gate_outputs, alignments, padding_added)
return diffusion_eps, gate_outputs, alignments
class WaveTacotron2(nn.Module):
def __init__(self, hparams):
super().__init__()
self.mask_padding = hparams.mask_padding
self.fp16_run = hparams.fp16_run
self.n_mel_channels = hparams.n_mel_channels
self.n_frames_per_step = hparams.n_frames_per_step
self.embedding = nn.Embedding(
hparams.n_symbols, hparams.symbols_embedding_dim)
std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
val = sqrt(3.0) * std # uniform bounds for std
self.embedding.weight.data.uniform_(-val, val)
self.encoder = Encoder(hparams)
self.decoder = WavDecoder(hparams.encoder_embedding_dim)
def parse_output(self, outputs, output_lengths=None):
if self.mask_padding and output_lengths is not None:
mask_fill = outputs[0].shape[-1]
mask = ~get_mask_from_lengths(output_lengths, mask_fill)
mask = mask.expand(mask.size(0), 2, mask.size(1))
outputs[0].data.masked_fill_(mask, 0.0)
outputs[0] = outputs[0].unsqueeze(1) # Re-add channel dimension.
outputs[1].data.masked_fill_(mask[:,0], 1e3) # gate energies
return outputs
def forward(self, wavs_diffused, wavs_corrected, timesteps, text_inputs, text_lengths, output_lengths):
# Squeeze the channel dimension out of the input wavs - we only handle single-channel audio here.
wavs_diffused = wavs_diffused.squeeze(dim=1)
wavs_corrected = wavs_corrected.squeeze(dim=1)
text_lengths, output_lengths = text_lengths.data, output_lengths.data
embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
encoder_outputs = checkpoint(self.encoder, embedded_inputs, text_lengths)
eps_pred, gate_outputs, alignments = self.decoder(
wavs_diffused, wavs_corrected, timesteps, encoder_outputs, memory_lengths=text_lengths)
return self.parse_output([eps_pred, gate_outputs, alignments], output_lengths)
@register_model
def register_diffusion_wavetron(opt_net, opt):
hparams = create_hparams()
hparams.update(opt_net)
hparams = munchify(hparams)
return WaveTacotron2(hparams)
if __name__ == '__main__':
tron = register_diffusion_wavetron({}, {})
out = tron(wavs_diffused=torch.randn(2, 1, 22000),
wavs_corrected=torch.randn(2, 1, 22000),
timesteps=torch.LongTensor([555, 543]),
text_inputs=torch.randint(high=24, size=(2,12)),
text_lengths=torch.tensor([12, 12]),
output_lengths=torch.tensor([21995]))
print([o.shape for o in out])

View File

@ -300,7 +300,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_xform_audio_lj.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wave_tacotron_diffusion_lj.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -10,6 +10,15 @@ from utils.util import opt_get
from utils.weight_scheduler import get_scheduler_for_opt
class SqueezeInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
self.dim = opt['dim']
def forward(self, state):
return {self.output: state[self.input].squeeze(dim=self.dim)}
# Uses a generator to synthesize an image from [in] and injects the results into [out]
# Note that results are *not* detached.
class GeneratorInjector(Injector):

View File

@ -35,6 +35,39 @@ class GaussianDiffusionInjector(Injector):
self.output_x_start_key: diffusion_outputs['x_start_predicted']}
class AutoregressiveGaussianDiffusionInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
self.generator = opt['generator']
self.output_variational_bounds_key = opt['out_key_vb_loss']
self.output_x_start_key = opt['out_key_x_start']
opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'],
[opt['beta_schedule']['num_diffusion_timesteps']])
self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
self.model_output_keys = opt['model_output_keys']
self.model_eps_pred_key = opt['prediction_key']
def forward(self, state):
gen = self.env['generators'][self.opt['generator']]
hq = state[self.input]
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
diffusion_outputs = self.diffusion.autoregressive_training_losses(gen, hq, t, self.model_output_keys,
self.model_eps_pred_key,
model_kwargs=model_inputs)
if isinstance(self.schedule_sampler, LossAwareSampler):
self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
outputs = {k: diffusion_outputs[k] for k in self.model_output_keys}
outputs.update({self.output: diffusion_outputs['mse'],
self.output_variational_bounds_key: diffusion_outputs['vb'],
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
return outputs
# Performs inference using a network trained to predict a reverse diffusion process, which nets a image.
class GaussianDiffusionInferenceInjector(Injector):
def __init__(self, opt, env):

View File

@ -121,7 +121,13 @@ class CrossEntropy(ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.opt = opt
self.ce = nn.CrossEntropyLoss()
self.subtype = opt_get(opt, ['subtype'], 'ce')
if self.subtype == 'ce':
self.ce = nn.CrossEntropyLoss()
elif self.subtype == 'bce':
self.ce = nn.BCEWithLogitsLoss()
else:
assert False
def forward(self, _, state):
logits = state[self.opt['logits']]
@ -135,8 +141,14 @@ class CrossEntropy(ConfigurableLoss):
logits = logits * mask
if self.opt['swap_channels']:
logits = logits.permute(0,2,3,1).contiguous()
assert labels.max()+1 <= logits.shape[-1]
return self.ce(logits.view(-1, logits.size(-1)), labels.view(-1))
if self.subtype == 'bce':
logits = logits.reshape(-1, 1)
labels = labels.reshape(-1, 1)
else:
logits = logits.view(-1, logits.size(-1))
labels = labels.view(-1)
assert labels.max()+1 <= logits.shape[-1]
return self.ce(logits, labels)
class PixLoss(ConfigurableLoss):