tacotron2, ready for prime time!

This commit is contained in:
James Betker 2021-07-08 22:13:44 -06:00
parent 86fd3ad7fd
commit 1ff434218e
11 changed files with 108 additions and 37 deletions

View File

@ -67,8 +67,8 @@ def create_dataset(dataset_opt, return_collate=False):
from data.audio.nv_tacotron_dataset import TextMelCollate as C
from models.tacotron2.hparams import create_hparams
default_params = create_hparams()
dataset_opt.update(default_params)
dataset_opt = munchify(dataset_opt)
default_params.update(dataset_opt)
dataset_opt = munchify(default_params)
collate = C(dataset_opt.n_frames_per_step)
else:
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))

View File

@ -111,8 +111,13 @@ class TextMelCollate():
gate_padded[i, mel.size(1)-1:] = 1
output_lengths[i] = mel.size(1)
return text_padded, input_lengths, mel_padded, gate_padded, \
output_lengths
return {
'padded_text': text_padded,
'input_lengths': input_lengths,
'padded_mel': mel_padded,
'padded_gate': gate_padded,
'output_lengths': output_lengths
}
if __name__ == '__main__':

View File

View File

@ -1,9 +1,34 @@
from torch import nn
from trainer.losses import ConfigurableLoss
class Tacotron2Loss(nn.Module):
class Tacotron2Loss(ConfigurableLoss):
def __init__(self, opt_loss, env):
super().__init__(opt_loss, env)
self.mel_target_key = opt_loss['mel_target_key']
self.mel_output_key = opt_loss['mel_output_key']
self.mel_output_postnet_key = opt_loss['mel_output_postnet_key']
self.gate_target_key = opt_loss['gate_target_key']
self.gate_output_key = opt_loss['gate_output_key']
def forward(self, _, state):
mel_target, gate_target = state[self.mel_target_key], state[self.gate_target_key]
mel_target.requires_grad = False
gate_target.requires_grad = False
gate_target = gate_target.view(-1, 1)
mel_out, mel_out_postnet, gate_out = state[self.mel_output_key], state[self.mel_output_postnet_key], state[self.gate_output_key]
gate_out = gate_out.view(-1, 1)
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
nn.MSELoss()(mel_out_postnet, mel_target)
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
return mel_loss + gate_loss
class Tacotron2LossRaw(nn.Module):
def __init__(self):
super(Tacotron2Loss, self).__init__()
super().__init__()
def forward(self, model_output, targets):
mel_target, gate_target = targets[0], targets[1]
@ -16,4 +41,4 @@ class Tacotron2Loss(nn.Module):
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
nn.MSELoss()(mel_out_postnet, mel_target)
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
return mel_loss + gate_loss
return mel_loss + gate_loss

View File

@ -3,9 +3,10 @@ from scipy.io.wavfile import read
import torch
def get_mask_from_lengths(lengths):
max_len = torch.max(lengths).item()
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len, device=lengths.device))
def get_mask_from_lengths(lengths, max_len=None):
if max_len is None:
max_len = torch.max(lengths).item()
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).to(lengths.device)
mask = (ids < lengths.unsqueeze(1)).bool()
return mask

View File

@ -4,11 +4,11 @@ from munch import munchify
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from layers import ConvNorm, LinearNorm
from models.tacotron2.layers import ConvNorm, LinearNorm
from models.tacotron2.hparams import create_hparams
from trainer.networks import register_model
from taco_utils import to_gpu, get_mask_from_lengths
from utils.util import opt_get
from models.tacotron2.taco_utils import get_mask_from_lengths
from utils.util import opt_get, checkpoint
class LocationLayer(nn.Module):
@ -74,7 +74,7 @@ class Attention(nn.Module):
attention_hidden_state: attention rnn last output
memory: encoder outputs
processed_memory: processed encoder outputs
attention_weights_cat: previous and cummulative attention weights
attention_weights_cat: previous and cumulative attention weights
mask: binary mask for padded data
"""
alignment = self.get_alignment_energies(
@ -408,8 +408,7 @@ class Decoder(nn.Module):
mel_outputs, gate_outputs, alignments = [], [], []
while len(mel_outputs) < decoder_inputs.size(0) - 1:
decoder_input = decoder_inputs[len(mel_outputs)]
mel_output, gate_output, attention_weights = self.decode(
decoder_input)
mel_output, gate_output, attention_weights = self.decode(decoder_input)
mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output.squeeze(1)]
alignments += [attention_weights]
@ -474,23 +473,10 @@ class Tacotron2(nn.Module):
self.decoder = Decoder(hparams)
self.postnet = Postnet(hparams)
def parse_batch(self, batch):
text_padded, input_lengths, mel_padded, gate_padded, \
output_lengths = batch
text_padded = to_gpu(text_padded).long()
input_lengths = to_gpu(input_lengths).long()
max_len = torch.max(input_lengths.data).item()
mel_padded = to_gpu(mel_padded).float()
gate_padded = to_gpu(gate_padded).float()
output_lengths = to_gpu(output_lengths).long()
return (
(text_padded, input_lengths, mel_padded, max_len, output_lengths),
(mel_padded, gate_padded))
def parse_output(self, outputs, output_lengths=None):
if self.mask_padding and output_lengths is not None:
mask = ~get_mask_from_lengths(output_lengths)
mask_fill = outputs[0].shape[-1]
mask = ~get_mask_from_lengths(output_lengths, mask_fill)
mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
mask = mask.permute(1, 0, 2)
@ -500,8 +486,7 @@ class Tacotron2(nn.Module):
return outputs
def forward(self, inputs):
text_inputs, text_lengths, mels, max_len, output_lengths = inputs
def forward(self, text_inputs, text_lengths, mels, output_lengths):
text_lengths, output_lengths = text_lengths.data, output_lengths.data
embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
@ -535,9 +520,8 @@ class Tacotron2(nn.Module):
@register_model
def register_nv_tacotron2(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {})
hparams = create_hparams()
hparams.update(kw)
hparams.update(opt_net)
hparams = munchify(hparams)
return Tacotron2(hparams)

View File

@ -300,7 +300,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_latent_unet_diffusion_sm.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tacotron2_lj.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -275,6 +275,10 @@ class ExtensibleTrainer(BaseModel):
# Record visual outputs for usage in debugging and testing.
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
def fix_image(img):
if opt_get(self.opt, ['logger', 'is_mel_spectrogram'], False):
img = img.unsqueeze(dim=1)
# Normalize so spectrogram is easier to view.
img = (img - img.mean()) / img.std()
if img.shape[1] > 3:
img = img[:, :3, :, :]
if opt_get(self.opt, ['logger', 'reverse_n1_to_1'], False):

View File

@ -0,0 +1,44 @@
import torch
import trainer.eval.evaluator as evaluator
from data import create_dataset
from data.audio.nv_tacotron_dataset import TextMelCollate
from models.tacotron2.loss import Tacotron2LossRaw
from torch.utils.data import DataLoader
from tqdm import tqdm
# Evaluates the performance of a MEL spectrogram predictor.
class MelEvaluator(evaluator.Evaluator):
def __init__(self, model, opt_eval, env):
super().__init__(model, opt_eval, env, uses_all_ddp=True)
self.batch_sz = opt_eval['batch_size']
self.dataset = create_dataset(opt_eval['dataset'])
assert self.batch_sz is not None
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=1, collate_fn=TextMelCollate(n_frames_per_step=1))
self.criterion = Tacotron2LossRaw()
def perform_eval(self):
counter = 0
total_error = 0
self.model.eval()
for batch in tqdm(self.dataloader):
model_params = {
'text_inputs': 'padded_text',
'text_lengths': 'input_lengths',
'mels': 'padded_mel',
'output_lengths': 'output_lengths',
}
params = {k: batch[v].to(self.env['device']) for k, v in model_params.items()}
with torch.no_grad():
pred = self.model(**params)
targets = ['padded_mel', 'padded_gate']
targets = [batch[t].to(self.env['device']) for t in targets]
total_error += self.criterion(pred, targets).item()
counter += 1
self.model.train()
return {"validation-score": total_error / counter}

View File

@ -58,7 +58,7 @@ def create_loss(opt_loss, env):
return SwitchTransformersLoadBalancingLoss(opt_loss, env)
elif type == 'nv_tacotron2_loss':
from models.tacotron2.loss import Tacotron2Loss
return Tacotron2Loss()
return Tacotron2Loss(opt_loss, env)
else:
raise NotImplementedError

View File

@ -27,6 +27,7 @@ class ConfigurableStep(Module):
self.scaler = GradScaler(enabled=self.opt['fp16'])
self.grads_generated = False
self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999
self.clip_grad_eps = opt_get(opt_step, ['clip_grad_eps'], None)
# This is a half-measure that can be used between anomaly_detection and running a potentially problematic
# trainer bare. With this turned on, the optimizer will not step() if a nan grad is detected. If a model trips
@ -267,6 +268,13 @@ class ConfigurableStep(Module):
else:
self.nan_counter = 0
if self.clip_grad_eps is not None:
for pg in opt.param_groups:
grad_norm = torch.nn.utils.clip_grad_norm_(pg['params'], self.clip_grad_eps)
if torch.isnan(grad_norm):
nan_found = True
self.nan_counter += 1
if not nan_found:
self.scaler.step(opt)
self.scaler.update()