tacotron2, ready for prime time!
This commit is contained in:
parent
86fd3ad7fd
commit
1ff434218e
|
@ -67,8 +67,8 @@ def create_dataset(dataset_opt, return_collate=False):
|
|||
from data.audio.nv_tacotron_dataset import TextMelCollate as C
|
||||
from models.tacotron2.hparams import create_hparams
|
||||
default_params = create_hparams()
|
||||
dataset_opt.update(default_params)
|
||||
dataset_opt = munchify(dataset_opt)
|
||||
default_params.update(dataset_opt)
|
||||
dataset_opt = munchify(default_params)
|
||||
collate = C(dataset_opt.n_frames_per_step)
|
||||
else:
|
||||
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
||||
|
|
|
@ -111,8 +111,13 @@ class TextMelCollate():
|
|||
gate_padded[i, mel.size(1)-1:] = 1
|
||||
output_lengths[i] = mel.size(1)
|
||||
|
||||
return text_padded, input_lengths, mel_padded, gate_padded, \
|
||||
output_lengths
|
||||
return {
|
||||
'padded_text': text_padded,
|
||||
'input_lengths': input_lengths,
|
||||
'padded_mel': mel_padded,
|
||||
'padded_gate': gate_padded,
|
||||
'output_lengths': output_lengths
|
||||
}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
0
codes/models/tacotron2/__init__.py
Normal file
0
codes/models/tacotron2/__init__.py
Normal file
|
@ -1,9 +1,34 @@
|
|||
from torch import nn
|
||||
|
||||
from trainer.losses import ConfigurableLoss
|
||||
|
||||
class Tacotron2Loss(nn.Module):
|
||||
|
||||
class Tacotron2Loss(ConfigurableLoss):
|
||||
def __init__(self, opt_loss, env):
|
||||
super().__init__(opt_loss, env)
|
||||
self.mel_target_key = opt_loss['mel_target_key']
|
||||
self.mel_output_key = opt_loss['mel_output_key']
|
||||
self.mel_output_postnet_key = opt_loss['mel_output_postnet_key']
|
||||
self.gate_target_key = opt_loss['gate_target_key']
|
||||
self.gate_output_key = opt_loss['gate_output_key']
|
||||
|
||||
def forward(self, _, state):
|
||||
mel_target, gate_target = state[self.mel_target_key], state[self.gate_target_key]
|
||||
mel_target.requires_grad = False
|
||||
gate_target.requires_grad = False
|
||||
gate_target = gate_target.view(-1, 1)
|
||||
|
||||
mel_out, mel_out_postnet, gate_out = state[self.mel_output_key], state[self.mel_output_postnet_key], state[self.gate_output_key]
|
||||
gate_out = gate_out.view(-1, 1)
|
||||
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
|
||||
nn.MSELoss()(mel_out_postnet, mel_target)
|
||||
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
|
||||
return mel_loss + gate_loss
|
||||
|
||||
|
||||
class Tacotron2LossRaw(nn.Module):
|
||||
def __init__(self):
|
||||
super(Tacotron2Loss, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
def forward(self, model_output, targets):
|
||||
mel_target, gate_target = targets[0], targets[1]
|
||||
|
@ -16,4 +41,4 @@ class Tacotron2Loss(nn.Module):
|
|||
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
|
||||
nn.MSELoss()(mel_out_postnet, mel_target)
|
||||
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
|
||||
return mel_loss + gate_loss
|
||||
return mel_loss + gate_loss
|
||||
|
|
|
@ -3,9 +3,10 @@ from scipy.io.wavfile import read
|
|||
import torch
|
||||
|
||||
|
||||
def get_mask_from_lengths(lengths):
|
||||
max_len = torch.max(lengths).item()
|
||||
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len, device=lengths.device))
|
||||
def get_mask_from_lengths(lengths, max_len=None):
|
||||
if max_len is None:
|
||||
max_len = torch.max(lengths).item()
|
||||
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).to(lengths.device)
|
||||
mask = (ids < lengths.unsqueeze(1)).bool()
|
||||
return mask
|
||||
|
||||
|
|
|
@ -4,11 +4,11 @@ from munch import munchify
|
|||
from torch.autograd import Variable
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from layers import ConvNorm, LinearNorm
|
||||
from models.tacotron2.layers import ConvNorm, LinearNorm
|
||||
from models.tacotron2.hparams import create_hparams
|
||||
from trainer.networks import register_model
|
||||
from taco_utils import to_gpu, get_mask_from_lengths
|
||||
from utils.util import opt_get
|
||||
from models.tacotron2.taco_utils import get_mask_from_lengths
|
||||
from utils.util import opt_get, checkpoint
|
||||
|
||||
|
||||
class LocationLayer(nn.Module):
|
||||
|
@ -74,7 +74,7 @@ class Attention(nn.Module):
|
|||
attention_hidden_state: attention rnn last output
|
||||
memory: encoder outputs
|
||||
processed_memory: processed encoder outputs
|
||||
attention_weights_cat: previous and cummulative attention weights
|
||||
attention_weights_cat: previous and cumulative attention weights
|
||||
mask: binary mask for padded data
|
||||
"""
|
||||
alignment = self.get_alignment_energies(
|
||||
|
@ -408,8 +408,7 @@ class Decoder(nn.Module):
|
|||
mel_outputs, gate_outputs, alignments = [], [], []
|
||||
while len(mel_outputs) < decoder_inputs.size(0) - 1:
|
||||
decoder_input = decoder_inputs[len(mel_outputs)]
|
||||
mel_output, gate_output, attention_weights = self.decode(
|
||||
decoder_input)
|
||||
mel_output, gate_output, attention_weights = self.decode(decoder_input)
|
||||
mel_outputs += [mel_output.squeeze(1)]
|
||||
gate_outputs += [gate_output.squeeze(1)]
|
||||
alignments += [attention_weights]
|
||||
|
@ -474,23 +473,10 @@ class Tacotron2(nn.Module):
|
|||
self.decoder = Decoder(hparams)
|
||||
self.postnet = Postnet(hparams)
|
||||
|
||||
def parse_batch(self, batch):
|
||||
text_padded, input_lengths, mel_padded, gate_padded, \
|
||||
output_lengths = batch
|
||||
text_padded = to_gpu(text_padded).long()
|
||||
input_lengths = to_gpu(input_lengths).long()
|
||||
max_len = torch.max(input_lengths.data).item()
|
||||
mel_padded = to_gpu(mel_padded).float()
|
||||
gate_padded = to_gpu(gate_padded).float()
|
||||
output_lengths = to_gpu(output_lengths).long()
|
||||
|
||||
return (
|
||||
(text_padded, input_lengths, mel_padded, max_len, output_lengths),
|
||||
(mel_padded, gate_padded))
|
||||
|
||||
def parse_output(self, outputs, output_lengths=None):
|
||||
if self.mask_padding and output_lengths is not None:
|
||||
mask = ~get_mask_from_lengths(output_lengths)
|
||||
mask_fill = outputs[0].shape[-1]
|
||||
mask = ~get_mask_from_lengths(output_lengths, mask_fill)
|
||||
mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
|
||||
mask = mask.permute(1, 0, 2)
|
||||
|
||||
|
@ -500,8 +486,7 @@ class Tacotron2(nn.Module):
|
|||
|
||||
return outputs
|
||||
|
||||
def forward(self, inputs):
|
||||
text_inputs, text_lengths, mels, max_len, output_lengths = inputs
|
||||
def forward(self, text_inputs, text_lengths, mels, output_lengths):
|
||||
text_lengths, output_lengths = text_lengths.data, output_lengths.data
|
||||
|
||||
embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
|
||||
|
@ -535,9 +520,8 @@ class Tacotron2(nn.Module):
|
|||
|
||||
@register_model
|
||||
def register_nv_tacotron2(opt_net, opt):
|
||||
kw = opt_get(opt_net, ['kwargs'], {})
|
||||
hparams = create_hparams()
|
||||
hparams.update(kw)
|
||||
hparams.update(opt_net)
|
||||
hparams = munchify(hparams)
|
||||
return Tacotron2(hparams)
|
||||
|
||||
|
|
|
@ -300,7 +300,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_latent_unet_diffusion_sm.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tacotron2_lj.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -275,6 +275,10 @@ class ExtensibleTrainer(BaseModel):
|
|||
# Record visual outputs for usage in debugging and testing.
|
||||
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
|
||||
def fix_image(img):
|
||||
if opt_get(self.opt, ['logger', 'is_mel_spectrogram'], False):
|
||||
img = img.unsqueeze(dim=1)
|
||||
# Normalize so spectrogram is easier to view.
|
||||
img = (img - img.mean()) / img.std()
|
||||
if img.shape[1] > 3:
|
||||
img = img[:, :3, :, :]
|
||||
if opt_get(self.opt, ['logger', 'reverse_n1_to_1'], False):
|
||||
|
|
44
codes/trainer/eval/mel_evaluator.py
Normal file
44
codes/trainer/eval/mel_evaluator.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
import torch
|
||||
|
||||
import trainer.eval.evaluator as evaluator
|
||||
|
||||
from data import create_dataset
|
||||
from data.audio.nv_tacotron_dataset import TextMelCollate
|
||||
from models.tacotron2.loss import Tacotron2LossRaw
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
# Evaluates the performance of a MEL spectrogram predictor.
|
||||
class MelEvaluator(evaluator.Evaluator):
|
||||
def __init__(self, model, opt_eval, env):
|
||||
super().__init__(model, opt_eval, env, uses_all_ddp=True)
|
||||
self.batch_sz = opt_eval['batch_size']
|
||||
self.dataset = create_dataset(opt_eval['dataset'])
|
||||
assert self.batch_sz is not None
|
||||
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=1, collate_fn=TextMelCollate(n_frames_per_step=1))
|
||||
self.criterion = Tacotron2LossRaw()
|
||||
|
||||
def perform_eval(self):
|
||||
counter = 0
|
||||
total_error = 0
|
||||
self.model.eval()
|
||||
for batch in tqdm(self.dataloader):
|
||||
model_params = {
|
||||
'text_inputs': 'padded_text',
|
||||
'text_lengths': 'input_lengths',
|
||||
'mels': 'padded_mel',
|
||||
'output_lengths': 'output_lengths',
|
||||
}
|
||||
params = {k: batch[v].to(self.env['device']) for k, v in model_params.items()}
|
||||
with torch.no_grad():
|
||||
pred = self.model(**params)
|
||||
|
||||
targets = ['padded_mel', 'padded_gate']
|
||||
targets = [batch[t].to(self.env['device']) for t in targets]
|
||||
total_error += self.criterion(pred, targets).item()
|
||||
counter += 1
|
||||
self.model.train()
|
||||
|
||||
return {"validation-score": total_error / counter}
|
||||
|
|
@ -58,7 +58,7 @@ def create_loss(opt_loss, env):
|
|||
return SwitchTransformersLoadBalancingLoss(opt_loss, env)
|
||||
elif type == 'nv_tacotron2_loss':
|
||||
from models.tacotron2.loss import Tacotron2Loss
|
||||
return Tacotron2Loss()
|
||||
return Tacotron2Loss(opt_loss, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ class ConfigurableStep(Module):
|
|||
self.scaler = GradScaler(enabled=self.opt['fp16'])
|
||||
self.grads_generated = False
|
||||
self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999
|
||||
self.clip_grad_eps = opt_get(opt_step, ['clip_grad_eps'], None)
|
||||
|
||||
# This is a half-measure that can be used between anomaly_detection and running a potentially problematic
|
||||
# trainer bare. With this turned on, the optimizer will not step() if a nan grad is detected. If a model trips
|
||||
|
@ -267,6 +268,13 @@ class ConfigurableStep(Module):
|
|||
else:
|
||||
self.nan_counter = 0
|
||||
|
||||
if self.clip_grad_eps is not None:
|
||||
for pg in opt.param_groups:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(pg['params'], self.clip_grad_eps)
|
||||
if torch.isnan(grad_norm):
|
||||
nan_found = True
|
||||
self.nan_counter += 1
|
||||
|
||||
if not nan_found:
|
||||
self.scaler.step(opt)
|
||||
self.scaler.update()
|
||||
|
|
Loading…
Reference in New Issue
Block a user