diff --git a/codes/data/__init__.py b/codes/data/__init__.py
index c33279b7..0cd5f8af 100644
--- a/codes/data/__init__.py
+++ b/codes/data/__init__.py
@@ -67,8 +67,8 @@ def create_dataset(dataset_opt, return_collate=False):
         from data.audio.nv_tacotron_dataset import TextMelCollate as C
         from models.tacotron2.hparams import create_hparams
         default_params = create_hparams()
-        dataset_opt.update(default_params)
-        dataset_opt = munchify(dataset_opt)
+        default_params.update(dataset_opt)
+        dataset_opt = munchify(default_params)
         collate = C(dataset_opt.n_frames_per_step)
     else:
         raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
diff --git a/codes/data/audio/nv_tacotron_dataset.py b/codes/data/audio/nv_tacotron_dataset.py
index e5847911..2545d78b 100644
--- a/codes/data/audio/nv_tacotron_dataset.py
+++ b/codes/data/audio/nv_tacotron_dataset.py
@@ -111,8 +111,13 @@
             gate_padded[i, mel.size(1)-1:] = 1
             output_lengths[i] = mel.size(1)
 
-        return text_padded, input_lengths, mel_padded, gate_padded, \
-            output_lengths
+        return {
+            'padded_text': text_padded,
+            'input_lengths': input_lengths,
+            'padded_mel': mel_padded,
+            'padded_gate': gate_padded,
+            'output_lengths': output_lengths
+        }
 
 
 if __name__ == '__main__':
diff --git a/codes/models/tacotron2/__init__.py b/codes/models/tacotron2/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/codes/models/tacotron2/loss.py b/codes/models/tacotron2/loss.py
index 3938934f..3109e317 100644
--- a/codes/models/tacotron2/loss.py
+++ b/codes/models/tacotron2/loss.py
@@ -1,9 +1,34 @@
 from torch import nn
+from trainer.losses import ConfigurableLoss
 
 
-class Tacotron2Loss(nn.Module):
+
+class Tacotron2Loss(ConfigurableLoss):
+    def __init__(self, opt_loss, env):
+        super().__init__(opt_loss, env)
+        self.mel_target_key = opt_loss['mel_target_key']
+        self.mel_output_key = opt_loss['mel_output_key']
+        self.mel_output_postnet_key = opt_loss['mel_output_postnet_key']
+        self.gate_target_key = opt_loss['gate_target_key']
+        self.gate_output_key = opt_loss['gate_output_key']
+
+    def forward(self, _, state):
+        mel_target, gate_target = state[self.mel_target_key], state[self.gate_target_key]
+        mel_target.requires_grad = False
+        gate_target.requires_grad = False
+        gate_target = gate_target.view(-1, 1)
+
+        mel_out, mel_out_postnet, gate_out = state[self.mel_output_key], state[self.mel_output_postnet_key], state[self.gate_output_key]
+        gate_out = gate_out.view(-1, 1)
+        mel_loss = nn.MSELoss()(mel_out, mel_target) + \
+            nn.MSELoss()(mel_out_postnet, mel_target)
+        gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
+        return mel_loss + gate_loss
+
+
+class Tacotron2LossRaw(nn.Module):
     def __init__(self):
-        super(Tacotron2Loss, self).__init__()
+        super().__init__()
 
     def forward(self, model_output, targets):
         mel_target, gate_target = targets[0], targets[1]
@@ -16,4 +41,4 @@ class Tacotron2Loss(nn.Module):
         mel_loss = nn.MSELoss()(mel_out, mel_target) + \
             nn.MSELoss()(mel_out_postnet, mel_target)
         gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
-        return mel_loss + gate_loss
\ No newline at end of file
+        return mel_loss + gate_loss
diff --git a/codes/models/tacotron2/taco_utils.py b/codes/models/tacotron2/taco_utils.py
index 5c8bb6f6..12ff9783 100644
--- a/codes/models/tacotron2/taco_utils.py
+++ b/codes/models/tacotron2/taco_utils.py
@@ -3,9 +3,10 @@ from scipy.io.wavfile import read
 import torch
 
 
-def get_mask_from_lengths(lengths):
-    max_len = torch.max(lengths).item()
-    ids = torch.arange(0, max_len, out=torch.LongTensor(max_len, device=lengths.device))
+def get_mask_from_lengths(lengths, max_len=None):
+    if max_len is None:
+        max_len = torch.max(lengths).item()
+    ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).to(lengths.device)
     mask = (ids < lengths.unsqueeze(1)).bool()
     return mask
 
diff --git a/codes/models/tacotron2/tacotron2.py b/codes/models/tacotron2/tacotron2.py
index 5e78883d..bed404b2 100644
--- a/codes/models/tacotron2/tacotron2.py
+++ b/codes/models/tacotron2/tacotron2.py
@@ -4,11 +4,11 @@ from munch import munchify
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as F
-from layers import ConvNorm, LinearNorm
+from models.tacotron2.layers import ConvNorm, LinearNorm
 from models.tacotron2.hparams import create_hparams
 from trainer.networks import register_model
-from taco_utils import to_gpu, get_mask_from_lengths
-from utils.util import opt_get
+from models.tacotron2.taco_utils import get_mask_from_lengths
+from utils.util import opt_get, checkpoint
 
 
 class LocationLayer(nn.Module):
@@ -74,7 +74,7 @@ class Attention(nn.Module):
         attention_hidden_state: attention rnn last output
         memory: encoder outputs
         processed_memory: processed encoder outputs
-        attention_weights_cat: previous and cummulative attention weights
+        attention_weights_cat: previous and cumulative attention weights
         mask: binary mask for padded data
         """
         alignment = self.get_alignment_energies(
@@ -408,8 +408,7 @@ class Decoder(nn.Module):
         mel_outputs, gate_outputs, alignments = [], [], []
         while len(mel_outputs) < decoder_inputs.size(0) - 1:
             decoder_input = decoder_inputs[len(mel_outputs)]
-            mel_output, gate_output, attention_weights = self.decode(
-                decoder_input)
+            mel_output, gate_output, attention_weights = self.decode(decoder_input)
             mel_outputs += [mel_output.squeeze(1)]
             gate_outputs += [gate_output.squeeze(1)]
             alignments += [attention_weights]
@@ -474,23 +473,10 @@ class Tacotron2(nn.Module):
         self.decoder = Decoder(hparams)
         self.postnet = Postnet(hparams)
 
-    def parse_batch(self, batch):
-        text_padded, input_lengths, mel_padded, gate_padded, \
-            output_lengths = batch
-        text_padded = to_gpu(text_padded).long()
-        input_lengths = to_gpu(input_lengths).long()
-        max_len = torch.max(input_lengths.data).item()
-        mel_padded = to_gpu(mel_padded).float()
-        gate_padded = to_gpu(gate_padded).float()
-        output_lengths = to_gpu(output_lengths).long()
-
-        return (
-            (text_padded, input_lengths, mel_padded, max_len, output_lengths),
-            (mel_padded, gate_padded))
-
     def parse_output(self, outputs, output_lengths=None):
         if self.mask_padding and output_lengths is not None:
-            mask = ~get_mask_from_lengths(output_lengths)
+            mask_fill = outputs[0].shape[-1]
+            mask = ~get_mask_from_lengths(output_lengths, mask_fill)
             mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
             mask = mask.permute(1, 0, 2)
 
@@ -500,8 +486,7 @@ class Tacotron2(nn.Module):
 
         return outputs
 
-    def forward(self, inputs):
-        text_inputs, text_lengths, mels, max_len, output_lengths = inputs
+    def forward(self, text_inputs, text_lengths, mels, output_lengths):
         text_lengths, output_lengths = text_lengths.data, output_lengths.data
 
         embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
@@ -535,9 +520,8 @@
 
 @register_model
 def register_nv_tacotron2(opt_net, opt):
-    kw = opt_get(opt_net, ['kwargs'], {})
     hparams = create_hparams()
-    hparams.update(kw)
+    hparams.update(opt_net)
     hparams = munchify(hparams)
     return Tacotron2(hparams)
 
diff --git a/codes/train.py b/codes/train.py
index a0f3b5f3..a2202343 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -300,7 +300,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_latent_unet_diffusion_sm.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tacotron2_lj.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 7da22bdd..d88a447e 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -275,6 +275,10 @@ class ExtensibleTrainer(BaseModel):
         # Record visual outputs for usage in debugging and testing.
         if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
             def fix_image(img):
+                if opt_get(self.opt, ['logger', 'is_mel_spectrogram'], False):
+                    img = img.unsqueeze(dim=1)
+                    # Normalize so spectrogram is easier to view.
+                    img = (img - img.mean()) / img.std()
                 if img.shape[1] > 3:
                     img = img[:, :3, :, :]
                 if opt_get(self.opt, ['logger', 'reverse_n1_to_1'], False):
diff --git a/codes/trainer/eval/mel_evaluator.py b/codes/trainer/eval/mel_evaluator.py
new file mode 100644
index 00000000..0044c0bf
--- /dev/null
+++ b/codes/trainer/eval/mel_evaluator.py
@@ -0,0 +1,44 @@
+import torch
+
+import trainer.eval.evaluator as evaluator
+
+from data import create_dataset
+from data.audio.nv_tacotron_dataset import TextMelCollate
+from models.tacotron2.loss import Tacotron2LossRaw
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+
+# Evaluates the performance of a MEL spectrogram predictor.
+class MelEvaluator(evaluator.Evaluator):
+    def __init__(self, model, opt_eval, env):
+        super().__init__(model, opt_eval, env, uses_all_ddp=True)
+        self.batch_sz = opt_eval['batch_size']
+        self.dataset = create_dataset(opt_eval['dataset'])
+        assert self.batch_sz is not None
+        self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=1, collate_fn=TextMelCollate(n_frames_per_step=1))
+        self.criterion = Tacotron2LossRaw()
+
+    def perform_eval(self):
+        counter = 0
+        total_error = 0
+        self.model.eval()
+        for batch in tqdm(self.dataloader):
+            model_params = {
+                'text_inputs': 'padded_text',
+                'text_lengths': 'input_lengths',
+                'mels': 'padded_mel',
+                'output_lengths': 'output_lengths',
+            }
+            params = {k: batch[v].to(self.env['device']) for k, v in model_params.items()}
+            with torch.no_grad():
+                pred = self.model(**params)
+
+            targets = ['padded_mel', 'padded_gate']
+            targets = [batch[t].to(self.env['device']) for t in targets]
+            total_error += self.criterion(pred, targets).item()
+            counter += 1
+        self.model.train()
+
+        return {"validation-score": total_error / counter}
+
diff --git a/codes/trainer/losses.py b/codes/trainer/losses.py
index 47094e6f..7bd34136 100644
--- a/codes/trainer/losses.py
+++ b/codes/trainer/losses.py
@@ -58,7 +58,7 @@ def create_loss(opt_loss, env):
         return SwitchTransformersLoadBalancingLoss(opt_loss, env)
     elif type == 'nv_tacotron2_loss':
         from models.tacotron2.loss import Tacotron2Loss
-        return Tacotron2Loss()
+        return Tacotron2Loss(opt_loss, env)
     else:
         raise NotImplementedError
 
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index cc50b9c2..e269e1c3 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -27,6 +27,7 @@ class ConfigurableStep(Module):
         self.scaler = GradScaler(enabled=self.opt['fp16'])
         self.grads_generated = False
         self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999
+        self.clip_grad_eps = opt_get(opt_step, ['clip_grad_eps'], None)
 
         # This is a half-measure that can be used between anomaly_detection and running a potentially problematic
         # trainer bare. With this turned on, the optimizer will not step() if a nan grad is detected. If a model trips
@@ -267,6 +268,13 @@ class ConfigurableStep(Module):
                 else:
                     self.nan_counter = 0
 
+                if self.clip_grad_eps is not None:
+                    for pg in opt.param_groups:
+                        grad_norm = torch.nn.utils.clip_grad_norm_(pg['params'], self.clip_grad_eps)
+                        if torch.isnan(grad_norm):
+                            nan_found = True
+                            self.nan_counter += 1
+
                 if not nan_found:
                     self.scaler.step(opt)
                     self.scaler.update()
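
A note on how the new config-driven loss is wired: `create_loss` now forwards `(opt_loss, env)`, and `Tacotron2Loss` resolves both its targets and its predictions from the trainer's state dict using key names supplied in the loss options. Below is a minimal sketch of that contract. The dataset-side keys (`padded_mel`, `padded_gate`) come from the `TextMelCollate` dict above; the output-key values (`mel_outputs`, `mel_outputs_postnet`, `gate_outputs`) and the empty `env` are illustrative assumptions, since the real names depend on the step's injector configuration.

import torch

from models.tacotron2.loss import Tacotron2Loss

# Key names are what Tacotron2Loss.__init__ reads; the values on the right
# stand in for whatever the step's injectors publish into the state dict.
opt_loss = {
    'mel_target_key': 'padded_mel',                    # from TextMelCollate
    'gate_target_key': 'padded_gate',                  # from TextMelCollate
    'mel_output_key': 'mel_outputs',                   # assumed injector output name
    'mel_output_postnet_key': 'mel_outputs_postnet',   # assumed injector output name
    'gate_output_key': 'gate_outputs',                 # assumed injector output name
}
loss = Tacotron2Loss(opt_loss, env={})  # env assumed unused by this loss's forward()

# Fake trainer state with Tacotron2-shaped tensors:
# (B, n_mels, T) mel spectrograms and (B, T) gate logits.
B, n_mels, T = 4, 80, 100
state = {
    'padded_mel': torch.randn(B, n_mels, T),
    'padded_gate': torch.zeros(B, T),
    'mel_outputs': torch.randn(B, n_mels, T),
    'mel_outputs_postnet': torch.randn(B, n_mels, T),
    'gate_outputs': torch.randn(B, T),
}
print(loss(None, state))  # MSE(mel) + MSE(postnet) + BCEWithLogits(gate)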