forked from mrq/DL-Art-School
Add support for a gaussian-diffusion-based wave tacotron
This commit is contained in:
parent
97d7cbbc34
commit
96e90e7047
|
@ -23,6 +23,8 @@ class TextMelLoader(torch.utils.data.Dataset):
|
|||
self.max_wav_value = hparams.max_wav_value
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.load_mel_from_disk = hparams.load_mel_from_disk
|
||||
self.return_wavs = hparams.return_wavs
|
||||
assert not (self.load_mel_from_disk and self.return_wavs)
|
||||
self.stft = layers.TacotronSTFT(
|
||||
hparams.filter_length, hparams.hop_length, hparams.win_length,
|
||||
hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
|
||||
|
@ -47,6 +49,9 @@ class TextMelLoader(torch.utils.data.Dataset):
|
|||
audio_norm = audio / self.max_wav_value
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
|
||||
if self.return_wavs:
|
||||
melspec = audio_norm
|
||||
else:
|
||||
melspec = self.stft.mel_spectrogram(audio_norm)
|
||||
melspec = torch.squeeze(melspec, 0)
|
||||
else:
|
||||
|
@ -124,13 +129,18 @@ if __name__ == '__main__':
|
|||
params = {
|
||||
'mode': 'nv_tacotron',
|
||||
'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
|
||||
|
||||
'phase': 'train',
|
||||
'n_workers': 0,
|
||||
'batch_size': 2,
|
||||
'return_wavs': True,
|
||||
}
|
||||
from data import create_dataset
|
||||
ds = create_dataset(params)
|
||||
from data import create_dataset, create_dataloader
|
||||
|
||||
ds, c = create_dataset(params, return_collate=True)
|
||||
dl = create_dataloader(ds, params, collate_fn=c)
|
||||
i = 0
|
||||
m = []
|
||||
for b in ds:
|
||||
for b in dl:
|
||||
m.append(b)
|
||||
i += 1
|
||||
if i > 9999:
|
||||
|
|
|
@ -817,6 +817,81 @@ class GaussianDiffusion:
|
|||
|
||||
return terms
|
||||
|
||||
def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None):
|
||||
"""
|
||||
Compute training losses for a single timestep.
|
||||
|
||||
:param model: the model to evaluate loss on.
|
||||
:param x_start: the [N x C x ...] tensor of inputs.
|
||||
:param t: a batch of timestep indices.
|
||||
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
||||
pass to the model. This can be used for conditioning.
|
||||
:param noise: if specified, the specific Gaussian noise to try to remove.
|
||||
:return: a dict with the key "loss" containing a tensor of shape [N].
|
||||
Some mean or variance settings may also have other keys.
|
||||
"""
|
||||
if model_kwargs is None:
|
||||
model_kwargs = {}
|
||||
if noise is None:
|
||||
noise = th.randn_like(x_start)
|
||||
x_t = self.q_sample(x_start, t, noise=noise)
|
||||
x_tn1 = self.q_sample(x_start, t-1, noise=noise)
|
||||
|
||||
terms = {}
|
||||
|
||||
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
||||
assert False # not currently supported for this type of diffusion.
|
||||
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
||||
model_outputs = model(x_t, x_tn1, self._scale_timesteps(t), **model_kwargs)
|
||||
terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
|
||||
model_output = terms[gd_out_key]
|
||||
if self.model_var_type in [
|
||||
ModelVarType.LEARNED,
|
||||
ModelVarType.LEARNED_RANGE,
|
||||
]:
|
||||
B, C = x_t.shape[:2]
|
||||
assert model_output.shape == (B, C, 2, *x_t.shape[2:])
|
||||
model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1]
|
||||
# Learn the variance using the variational bound, but don't let
|
||||
# it affect our mean prediction.
|
||||
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
|
||||
terms["vb"] = self._vb_terms_bpd(
|
||||
model=lambda *args, r=frozen_out: r,
|
||||
x_start=x_start,
|
||||
x_t=x_t,
|
||||
t=t,
|
||||
clip_denoised=False,
|
||||
)["output"]
|
||||
if self.loss_type == LossType.RESCALED_MSE:
|
||||
# Divide by 1000 for equivalence with initial implementation.
|
||||
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
||||
terms["vb"] *= self.num_timesteps / 1000.0
|
||||
|
||||
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
|
||||
target = self.q_posterior_mean_variance(
|
||||
x_start=x_start, x_t=x_t, t=t
|
||||
)[0]
|
||||
x_start_pred = torch.zeros(x_start) # Not supported.
|
||||
elif self.model_mean_type == ModelMeanType.START_X:
|
||||
target = x_start
|
||||
x_start_pred = model_output
|
||||
elif self.model_mean_type == ModelMeanType.EPSILON:
|
||||
target = noise
|
||||
x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
|
||||
else:
|
||||
raise NotImplementedError(self.model_mean_type)
|
||||
assert model_output.shape == target.shape == x_start.shape
|
||||
terms["mse"] = mean_flat((target - model_output) ** 2)
|
||||
terms["x_start_predicted"] = x_start_pred
|
||||
if "vb" in terms:
|
||||
terms["loss"] = terms["mse"] + terms["vb"]
|
||||
else:
|
||||
terms["loss"] = terms["mse"]
|
||||
else:
|
||||
raise NotImplementedError(self.loss_type)
|
||||
|
||||
return terms
|
||||
|
||||
def _prior_bpd(self, x_start):
|
||||
"""
|
||||
Get the prior KL term for the variational lower-bound, measured in
|
||||
|
|
|
@ -95,16 +95,22 @@ class SpacedDiffusion(GaussianDiffusion):
|
|||
): # pylint: disable=signature-differs
|
||||
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
||||
|
||||
def autoregressive_training_losses(
|
||||
self, model, *args, **kwargs
|
||||
): # pylint: disable=signature-differs
|
||||
return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs)
|
||||
|
||||
def condition_mean(self, cond_fn, *args, **kwargs):
|
||||
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
||||
|
||||
def condition_score(self, cond_fn, *args, **kwargs):
|
||||
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
|
||||
|
||||
def _wrap_model(self, model):
|
||||
if isinstance(model, _WrappedModel):
|
||||
def _wrap_model(self, model, autoregressive=False):
|
||||
if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel):
|
||||
return model
|
||||
return _WrappedModel(
|
||||
mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel
|
||||
return mod(
|
||||
model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
|
||||
)
|
||||
|
||||
|
@ -126,3 +132,18 @@ class _WrappedModel:
|
|||
if self.rescale_timesteps:
|
||||
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
||||
return self.model(x, new_ts, **kwargs)
|
||||
|
||||
|
||||
class _WrappedAutoregressiveModel:
|
||||
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
|
||||
self.model = model
|
||||
self.timestep_map = timestep_map
|
||||
self.rescale_timesteps = rescale_timesteps
|
||||
self.original_num_steps = original_num_steps
|
||||
|
||||
def __call__(self, x, x0, ts, **kwargs):
|
||||
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
||||
new_ts = map_tensor[ts]
|
||||
if self.rescale_timesteps:
|
||||
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
||||
return self.model(x, x0, new_ts, **kwargs)
|
|
@ -441,6 +441,7 @@ class UNetModel(nn.Module):
|
|||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_raw_y_as_embedding=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -471,6 +472,8 @@ class UNetModel(nn.Module):
|
|||
|
||||
if self.num_classes is not None:
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
self.use_raw_y_as_embedding = use_raw_y_as_embedding
|
||||
assert (self.num_classes is not None) != use_raw_y_as_embedding # These are mutually-exclusive.
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
|
@ -630,16 +633,14 @@ class UNetModel(nn.Module):
|
|||
:param y: an [N] Tensor of labels, if class-conditional.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
|
||||
hs = []
|
||||
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape == (x.shape[0],)
|
||||
emb = emb + self.label_emb(y)
|
||||
if self.use_raw_y_as_embedding:
|
||||
emb = emb + y
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
|
|
|
@ -32,7 +32,7 @@ class LocationLayer(nn.Module):
|
|||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
|
||||
attention_location_n_filters, attention_location_kernel_size):
|
||||
attention_location_n_filters=32, attention_location_kernel_size=31):
|
||||
super(Attention, self).__init__()
|
||||
self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
|
||||
bias=False, w_init_gain='tanh')
|
||||
|
@ -528,6 +528,9 @@ def register_nv_tacotron2(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
tron = register_nv_tacotron2({}, {})
|
||||
inputs = torch.randint(high=24, size=(1,12)), torch.tensor([12]), torch.randn((1,80,749)), 800, torch.tensor([749])
|
||||
out = tron(inputs)
|
||||
inputs = torch.randint(high=24, size=(1,12)), \
|
||||
torch.tensor([12]), \
|
||||
torch.randn((1,80,749)), \
|
||||
torch.tensor([749])
|
||||
out = tron(*inputs)
|
||||
print(out)
|
230
codes/models/tacotron2/wave_tacotron.py
Normal file
230
codes/models/tacotron2/wave_tacotron.py
Normal file
|
@ -0,0 +1,230 @@
|
|||
from math import sqrt
|
||||
import torch
|
||||
from munch import munchify
|
||||
from torch.autograd import Variable
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from models.diffusion.unet_diffusion import UNetModel
|
||||
from models.tacotron2.layers import ConvNorm, LinearNorm
|
||||
from models.tacotron2.hparams import create_hparams
|
||||
from models.tacotron2.tacotron2 import Prenet, Attention, Encoder
|
||||
from trainer.networks import register_model
|
||||
from models.tacotron2.taco_utils import get_mask_from_lengths
|
||||
from utils.util import opt_get, checkpoint
|
||||
|
||||
|
||||
class WavDecoder(nn.Module):
|
||||
def __init__(self, dec_channels, K_ms=40, sample_rate=24000, dropout_probability=.1):
|
||||
super().__init__()
|
||||
self.dec_channels = dec_channels
|
||||
self.K = int(sample_rate * (K_ms/1000)) # 960 with the defaults
|
||||
self.clarifier = UNetModel(image_size=self.K,
|
||||
in_channels=1,
|
||||
model_channels=dec_channels // 4, # This is a requirement to enable to load the embedding produced by the decoder into the unet model.
|
||||
out_channels=2, # 2 channels: eps_pred and variance_pred
|
||||
num_res_blocks=2,
|
||||
attention_resolutions=(16,32),
|
||||
dims=1,
|
||||
dropout=.1,
|
||||
channel_mult=(1,1,1,2,4,8),
|
||||
use_raw_y_as_embedding=True)
|
||||
assert self.K % 64 == 0 # Otherwise the UNetModel breaks.
|
||||
self.pre_rnn = Prenet(self.K, [dec_channels, dec_channels])
|
||||
self.attention_rnn = nn.LSTMCell(dec_channels*2, dec_channels)
|
||||
self.attention_layer = Attention(dec_channels, dec_channels, dec_channels)
|
||||
self.decoder_rnn = nn.LSTMCell(dec_channels*2, dec_channels, 1)
|
||||
self.linear_projection = LinearNorm(dec_channels*2, self.dec_channels)
|
||||
self.gate_layer = LinearNorm(self.dec_channels*2, 1, bias=True, w_init_gain='sigmoid')
|
||||
self.dropout_probability = dropout_probability
|
||||
|
||||
def chunk_wav(self, wav):
|
||||
wavs = list(torch.split(wav, self.K, dim=-1))
|
||||
# Pad the last chunk as needed.
|
||||
padding_needed = self.K - wavs[-1].shape[-1]
|
||||
if padding_needed > 0:
|
||||
wavs[-1] = F.pad(wavs[-1], (0,padding_needed))
|
||||
|
||||
wavs = torch.stack(wavs, dim=1) # wavs.shape = (b,s,K) where s=decoder sequence length
|
||||
return wavs, padding_needed
|
||||
|
||||
def prepare_decoder_inputs(self, inp):
|
||||
# inp.shape = (b,s,K) chunked waveform.
|
||||
b,s,K = inp.shape
|
||||
first_frame = torch.zeros(b,1,K).to(inp.device)
|
||||
x = torch.cat([first_frame, inp[:,:-1]], dim=1) # It is now aligned for teacher forcing.
|
||||
return x
|
||||
|
||||
def initialize_decoder_states(self, memory, mask):
|
||||
""" Initializes attention rnn states, decoder rnn states, attention
|
||||
weights, attention cumulative weights, attention context, stores memory
|
||||
and stores processed memory
|
||||
PARAMS
|
||||
------
|
||||
memory: Encoder outputs
|
||||
mask: Mask for padded data if training, expects None for inference
|
||||
"""
|
||||
B = memory.size(0)
|
||||
MAX_TIME = memory.size(1)
|
||||
|
||||
self.attention_hidden = Variable(memory.data.new(B, self.dec_channels).zero_())
|
||||
self.attention_cell = Variable(memory.data.new(B, self.dec_channels).zero_())
|
||||
|
||||
self.decoder_hidden = Variable(memory.data.new(B, self.dec_channels).zero_())
|
||||
self.decoder_cell = Variable(memory.data.new(B, self.dec_channels).zero_())
|
||||
|
||||
self.attention_weights = Variable(memory.data.new(B, MAX_TIME).zero_())
|
||||
self.attention_weights_cum = Variable(memory.data.new(B, MAX_TIME).zero_())
|
||||
self.attention_context = Variable(memory.data.new(B, self.dec_channels).zero_())
|
||||
|
||||
self.memory = memory
|
||||
self.processed_memory = checkpoint(self.attention_layer.memory_layer, memory)
|
||||
self.mask = mask
|
||||
|
||||
def teardown_states(self):
|
||||
self.attention_hidden = None
|
||||
self.attention_cell = None
|
||||
self.decoder_hidden = None
|
||||
self.decoder_cell = None
|
||||
self.attention_weights = None
|
||||
self.attention_weights_cum = None
|
||||
self.attention_context = None
|
||||
self.memory = None
|
||||
self.processed_memory = None
|
||||
|
||||
def produce_context(self, decoder_input):
|
||||
""" Produces a context and a stop token prediction using the built-in RNN.
|
||||
PARAMS
|
||||
------
|
||||
decoder_input: prior diffusion step that has been resolved.
|
||||
|
||||
RETURNS
|
||||
-------
|
||||
mel_output:
|
||||
gate_output: gate output energies
|
||||
attention_weights:
|
||||
"""
|
||||
cell_input = torch.cat((decoder_input, self.attention_context), -1)
|
||||
self.attention_hidden, self.attention_cell = self.attention_rnn(cell_input, (self.attention_hidden, self.attention_cell))
|
||||
self.attention_hidden = F.dropout(self.attention_hidden, self.dropout_probability, self.training)
|
||||
|
||||
attention_weights_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1)
|
||||
self.attention_context, self.attention_weights = checkpoint(self.attention_layer, self.attention_hidden, self.memory,
|
||||
self.processed_memory, attention_weights_cat, self.mask)
|
||||
|
||||
self.attention_weights_cum += self.attention_weights
|
||||
decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1)
|
||||
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(decoder_input, (self.decoder_hidden, self.decoder_cell))
|
||||
self.decoder_hidden = F.dropout(self.decoder_hidden, self.dropout_probability, self.training)
|
||||
|
||||
decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1)
|
||||
decoder_output = checkpoint(self.linear_projection, decoder_hidden_attention_context)
|
||||
|
||||
gate_prediction = self.gate_layer(decoder_hidden_attention_context)
|
||||
return decoder_output, gate_prediction, self.attention_weights
|
||||
|
||||
def recombine(self, diffusion_eps, gate_outputs, alignments, padding_added):
|
||||
# (T_out, B) -> (B, T_out)
|
||||
alignments = torch.stack(alignments, dim=1).repeat(1, self.K, 1)
|
||||
# (T_out, B) -> (B, T_out)
|
||||
gate_outputs = torch.stack(gate_outputs, dim=1).repeat(1, self.K)
|
||||
|
||||
b,s,_,K = diffusion_eps.shape
|
||||
# (B, S, 2, K) -> (B, 2, S*K)
|
||||
diffusion_eps = diffusion_eps.permute(0,2,1,3).reshape(b, 2, s*K)
|
||||
|
||||
return diffusion_eps[:,:,:-padding_added], gate_outputs[:,:-padding_added], alignments[:,:-padding_added]
|
||||
|
||||
def forward(self, wav, wav_corrected, timesteps, text_enc, memory_lengths):
|
||||
'''
|
||||
Performs a training forward pass with the given data.
|
||||
:param wav: (b,n) diffused waveform tensor on the interval [-1,1]
|
||||
:param wav_corrected: (b,n) waveform tensor that has had one step of diffusion correction over <wav>
|
||||
:param text_enc: (b,e) embedding post-encoder with e=self.dec_channels
|
||||
'''
|
||||
|
||||
# Start by splitting up the provided waveforms into discrete segments.
|
||||
wavs, padding_added = self.chunk_wav(wav)
|
||||
wavs_corrected, _ = self.chunk_wav(wav_corrected)
|
||||
wavs_corrected = self.prepare_decoder_inputs(wavs_corrected)
|
||||
wavs_corrected = checkpoint(self.pre_rnn, wavs_corrected)
|
||||
|
||||
self.initialize_decoder_states(text_enc, mask=~get_mask_from_lengths(memory_lengths))
|
||||
decoder_contexts, gate_outputs, alignments = [], [], []
|
||||
while len(decoder_contexts) < wavs_corrected.size(1):
|
||||
decoder_input = wavs_corrected[:, len(decoder_contexts)]
|
||||
dec_context, gate_output, attention_weights = self.produce_context(decoder_input)
|
||||
decoder_contexts += [dec_context.squeeze(1)]
|
||||
gate_outputs += [gate_output.squeeze(1)]
|
||||
alignments += [attention_weights]
|
||||
self.teardown_states()
|
||||
|
||||
# diffusion_inputs and wavs needs to have the sequence and batch dimensions combined, and needs a channel dimension
|
||||
diffusion_emb = torch.stack(decoder_contexts, dim=1)
|
||||
b,s,c = diffusion_emb.shape
|
||||
diffusion_emb = diffusion_emb.reshape(b*s,c)
|
||||
wavs = wavs.reshape(b*s,1,self.K)
|
||||
diffusion_eps = self.clarifier(wavs, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K)
|
||||
# Recombine diffusion outputs across the sequence into a single prediction.
|
||||
diffusion_eps, gate_outputs, alignments = self.recombine(diffusion_eps, gate_outputs, alignments, padding_added)
|
||||
return diffusion_eps, gate_outputs, alignments
|
||||
|
||||
|
||||
class WaveTacotron2(nn.Module):
|
||||
def __init__(self, hparams):
|
||||
super().__init__()
|
||||
self.mask_padding = hparams.mask_padding
|
||||
self.fp16_run = hparams.fp16_run
|
||||
self.n_mel_channels = hparams.n_mel_channels
|
||||
self.n_frames_per_step = hparams.n_frames_per_step
|
||||
self.embedding = nn.Embedding(
|
||||
hparams.n_symbols, hparams.symbols_embedding_dim)
|
||||
std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
|
||||
val = sqrt(3.0) * std # uniform bounds for std
|
||||
self.embedding.weight.data.uniform_(-val, val)
|
||||
self.encoder = Encoder(hparams)
|
||||
self.decoder = WavDecoder(hparams.encoder_embedding_dim)
|
||||
|
||||
def parse_output(self, outputs, output_lengths=None):
|
||||
if self.mask_padding and output_lengths is not None:
|
||||
mask_fill = outputs[0].shape[-1]
|
||||
mask = ~get_mask_from_lengths(output_lengths, mask_fill)
|
||||
mask = mask.expand(mask.size(0), 2, mask.size(1))
|
||||
|
||||
outputs[0].data.masked_fill_(mask, 0.0)
|
||||
outputs[0] = outputs[0].unsqueeze(1) # Re-add channel dimension.
|
||||
outputs[1].data.masked_fill_(mask[:,0], 1e3) # gate energies
|
||||
|
||||
return outputs
|
||||
|
||||
def forward(self, wavs_diffused, wavs_corrected, timesteps, text_inputs, text_lengths, output_lengths):
|
||||
# Squeeze the channel dimension out of the input wavs - we only handle single-channel audio here.
|
||||
wavs_diffused = wavs_diffused.squeeze(dim=1)
|
||||
wavs_corrected = wavs_corrected.squeeze(dim=1)
|
||||
|
||||
text_lengths, output_lengths = text_lengths.data, output_lengths.data
|
||||
embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
|
||||
encoder_outputs = checkpoint(self.encoder, embedded_inputs, text_lengths)
|
||||
eps_pred, gate_outputs, alignments = self.decoder(
|
||||
wavs_diffused, wavs_corrected, timesteps, encoder_outputs, memory_lengths=text_lengths)
|
||||
|
||||
return self.parse_output([eps_pred, gate_outputs, alignments], output_lengths)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_diffusion_wavetron(opt_net, opt):
|
||||
hparams = create_hparams()
|
||||
hparams.update(opt_net)
|
||||
hparams = munchify(hparams)
|
||||
return WaveTacotron2(hparams)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tron = register_diffusion_wavetron({}, {})
|
||||
out = tron(wavs_diffused=torch.randn(2, 1, 22000),
|
||||
wavs_corrected=torch.randn(2, 1, 22000),
|
||||
timesteps=torch.LongTensor([555, 543]),
|
||||
text_inputs=torch.randint(high=24, size=(2,12)),
|
||||
text_lengths=torch.tensor([12, 12]),
|
||||
output_lengths=torch.tensor([21995]))
|
||||
print([o.shape for o in out])
|
|
@ -300,7 +300,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_xform_audio_lj.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wave_tacotron_diffusion_lj.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -10,6 +10,15 @@ from utils.util import opt_get
|
|||
from utils.weight_scheduler import get_scheduler_for_opt
|
||||
|
||||
|
||||
class SqueezeInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.dim = opt['dim']
|
||||
|
||||
def forward(self, state):
|
||||
return {self.output: state[self.input].squeeze(dim=self.dim)}
|
||||
|
||||
|
||||
# Uses a generator to synthesize an image from [in] and injects the results into [out]
|
||||
# Note that results are *not* detached.
|
||||
class GeneratorInjector(Injector):
|
||||
|
|
|
@ -35,6 +35,39 @@ class GaussianDiffusionInjector(Injector):
|
|||
self.output_x_start_key: diffusion_outputs['x_start_predicted']}
|
||||
|
||||
|
||||
class AutoregressiveGaussianDiffusionInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.generator = opt['generator']
|
||||
self.output_variational_bounds_key = opt['out_key_vb_loss']
|
||||
self.output_x_start_key = opt['out_key_x_start']
|
||||
opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
|
||||
opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'],
|
||||
[opt['beta_schedule']['num_diffusion_timesteps']])
|
||||
self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
|
||||
self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
|
||||
self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
|
||||
self.model_output_keys = opt['model_output_keys']
|
||||
self.model_eps_pred_key = opt['prediction_key']
|
||||
|
||||
def forward(self, state):
|
||||
gen = self.env['generators'][self.opt['generator']]
|
||||
hq = state[self.input]
|
||||
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
|
||||
t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
|
||||
diffusion_outputs = self.diffusion.autoregressive_training_losses(gen, hq, t, self.model_output_keys,
|
||||
self.model_eps_pred_key,
|
||||
model_kwargs=model_inputs)
|
||||
if isinstance(self.schedule_sampler, LossAwareSampler):
|
||||
self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
|
||||
outputs = {k: diffusion_outputs[k] for k in self.model_output_keys}
|
||||
outputs.update({self.output: diffusion_outputs['mse'],
|
||||
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
||||
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
# Performs inference using a network trained to predict a reverse diffusion process, which nets a image.
|
||||
class GaussianDiffusionInferenceInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
|
|
|
@ -121,7 +121,13 @@ class CrossEntropy(ConfigurableLoss):
|
|||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.opt = opt
|
||||
self.subtype = opt_get(opt, ['subtype'], 'ce')
|
||||
if self.subtype == 'ce':
|
||||
self.ce = nn.CrossEntropyLoss()
|
||||
elif self.subtype == 'bce':
|
||||
self.ce = nn.BCEWithLogitsLoss()
|
||||
else:
|
||||
assert False
|
||||
|
||||
def forward(self, _, state):
|
||||
logits = state[self.opt['logits']]
|
||||
|
@ -135,8 +141,14 @@ class CrossEntropy(ConfigurableLoss):
|
|||
logits = logits * mask
|
||||
if self.opt['swap_channels']:
|
||||
logits = logits.permute(0,2,3,1).contiguous()
|
||||
if self.subtype == 'bce':
|
||||
logits = logits.reshape(-1, 1)
|
||||
labels = labels.reshape(-1, 1)
|
||||
else:
|
||||
logits = logits.view(-1, logits.size(-1))
|
||||
labels = labels.view(-1)
|
||||
assert labels.max()+1 <= logits.shape[-1]
|
||||
return self.ce(logits.view(-1, logits.size(-1)), labels.view(-1))
|
||||
return self.ce(logits, labels)
|
||||
|
||||
|
||||
class PixLoss(ConfigurableLoss):
|
||||
|
|
Loading…
Reference in New Issue
Block a user