tacotron2, ready for prime time!
This commit is contained in:
parent
86fd3ad7fd
commit
1ff434218e
|
@ -67,8 +67,8 @@ def create_dataset(dataset_opt, return_collate=False):
|
||||||
from data.audio.nv_tacotron_dataset import TextMelCollate as C
|
from data.audio.nv_tacotron_dataset import TextMelCollate as C
|
||||||
from models.tacotron2.hparams import create_hparams
|
from models.tacotron2.hparams import create_hparams
|
||||||
default_params = create_hparams()
|
default_params = create_hparams()
|
||||||
dataset_opt.update(default_params)
|
default_params.update(dataset_opt)
|
||||||
dataset_opt = munchify(dataset_opt)
|
dataset_opt = munchify(default_params)
|
||||||
collate = C(dataset_opt.n_frames_per_step)
|
collate = C(dataset_opt.n_frames_per_step)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
||||||
|
|
|
@ -111,8 +111,13 @@ class TextMelCollate():
|
||||||
gate_padded[i, mel.size(1)-1:] = 1
|
gate_padded[i, mel.size(1)-1:] = 1
|
||||||
output_lengths[i] = mel.size(1)
|
output_lengths[i] = mel.size(1)
|
||||||
|
|
||||||
return text_padded, input_lengths, mel_padded, gate_padded, \
|
return {
|
||||||
output_lengths
|
'padded_text': text_padded,
|
||||||
|
'input_lengths': input_lengths,
|
||||||
|
'padded_mel': mel_padded,
|
||||||
|
'padded_gate': gate_padded,
|
||||||
|
'output_lengths': output_lengths
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
0
codes/models/tacotron2/__init__.py
Normal file
0
codes/models/tacotron2/__init__.py
Normal file
|
@ -1,9 +1,34 @@
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from trainer.losses import ConfigurableLoss
|
||||||
|
|
||||||
class Tacotron2Loss(nn.Module):
|
|
||||||
|
class Tacotron2Loss(ConfigurableLoss):
|
||||||
|
def __init__(self, opt_loss, env):
|
||||||
|
super().__init__(opt_loss, env)
|
||||||
|
self.mel_target_key = opt_loss['mel_target_key']
|
||||||
|
self.mel_output_key = opt_loss['mel_output_key']
|
||||||
|
self.mel_output_postnet_key = opt_loss['mel_output_postnet_key']
|
||||||
|
self.gate_target_key = opt_loss['gate_target_key']
|
||||||
|
self.gate_output_key = opt_loss['gate_output_key']
|
||||||
|
|
||||||
|
def forward(self, _, state):
|
||||||
|
mel_target, gate_target = state[self.mel_target_key], state[self.gate_target_key]
|
||||||
|
mel_target.requires_grad = False
|
||||||
|
gate_target.requires_grad = False
|
||||||
|
gate_target = gate_target.view(-1, 1)
|
||||||
|
|
||||||
|
mel_out, mel_out_postnet, gate_out = state[self.mel_output_key], state[self.mel_output_postnet_key], state[self.gate_output_key]
|
||||||
|
gate_out = gate_out.view(-1, 1)
|
||||||
|
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
|
||||||
|
nn.MSELoss()(mel_out_postnet, mel_target)
|
||||||
|
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
|
||||||
|
return mel_loss + gate_loss
|
||||||
|
|
||||||
|
|
||||||
|
class Tacotron2LossRaw(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Tacotron2Loss, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, model_output, targets):
|
def forward(self, model_output, targets):
|
||||||
mel_target, gate_target = targets[0], targets[1]
|
mel_target, gate_target = targets[0], targets[1]
|
||||||
|
@ -16,4 +41,4 @@ class Tacotron2Loss(nn.Module):
|
||||||
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
|
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
|
||||||
nn.MSELoss()(mel_out_postnet, mel_target)
|
nn.MSELoss()(mel_out_postnet, mel_target)
|
||||||
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
|
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
|
||||||
return mel_loss + gate_loss
|
return mel_loss + gate_loss
|
||||||
|
|
|
@ -3,9 +3,10 @@ from scipy.io.wavfile import read
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def get_mask_from_lengths(lengths):
|
def get_mask_from_lengths(lengths, max_len=None):
|
||||||
max_len = torch.max(lengths).item()
|
if max_len is None:
|
||||||
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len, device=lengths.device))
|
max_len = torch.max(lengths).item()
|
||||||
|
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).to(lengths.device)
|
||||||
mask = (ids < lengths.unsqueeze(1)).bool()
|
mask = (ids < lengths.unsqueeze(1)).bool()
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
|
@ -4,11 +4,11 @@ from munch import munchify
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from layers import ConvNorm, LinearNorm
|
from models.tacotron2.layers import ConvNorm, LinearNorm
|
||||||
from models.tacotron2.hparams import create_hparams
|
from models.tacotron2.hparams import create_hparams
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from taco_utils import to_gpu, get_mask_from_lengths
|
from models.tacotron2.taco_utils import get_mask_from_lengths
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get, checkpoint
|
||||||
|
|
||||||
|
|
||||||
class LocationLayer(nn.Module):
|
class LocationLayer(nn.Module):
|
||||||
|
@ -74,7 +74,7 @@ class Attention(nn.Module):
|
||||||
attention_hidden_state: attention rnn last output
|
attention_hidden_state: attention rnn last output
|
||||||
memory: encoder outputs
|
memory: encoder outputs
|
||||||
processed_memory: processed encoder outputs
|
processed_memory: processed encoder outputs
|
||||||
attention_weights_cat: previous and cummulative attention weights
|
attention_weights_cat: previous and cumulative attention weights
|
||||||
mask: binary mask for padded data
|
mask: binary mask for padded data
|
||||||
"""
|
"""
|
||||||
alignment = self.get_alignment_energies(
|
alignment = self.get_alignment_energies(
|
||||||
|
@ -408,8 +408,7 @@ class Decoder(nn.Module):
|
||||||
mel_outputs, gate_outputs, alignments = [], [], []
|
mel_outputs, gate_outputs, alignments = [], [], []
|
||||||
while len(mel_outputs) < decoder_inputs.size(0) - 1:
|
while len(mel_outputs) < decoder_inputs.size(0) - 1:
|
||||||
decoder_input = decoder_inputs[len(mel_outputs)]
|
decoder_input = decoder_inputs[len(mel_outputs)]
|
||||||
mel_output, gate_output, attention_weights = self.decode(
|
mel_output, gate_output, attention_weights = self.decode(decoder_input)
|
||||||
decoder_input)
|
|
||||||
mel_outputs += [mel_output.squeeze(1)]
|
mel_outputs += [mel_output.squeeze(1)]
|
||||||
gate_outputs += [gate_output.squeeze(1)]
|
gate_outputs += [gate_output.squeeze(1)]
|
||||||
alignments += [attention_weights]
|
alignments += [attention_weights]
|
||||||
|
@ -474,23 +473,10 @@ class Tacotron2(nn.Module):
|
||||||
self.decoder = Decoder(hparams)
|
self.decoder = Decoder(hparams)
|
||||||
self.postnet = Postnet(hparams)
|
self.postnet = Postnet(hparams)
|
||||||
|
|
||||||
def parse_batch(self, batch):
|
|
||||||
text_padded, input_lengths, mel_padded, gate_padded, \
|
|
||||||
output_lengths = batch
|
|
||||||
text_padded = to_gpu(text_padded).long()
|
|
||||||
input_lengths = to_gpu(input_lengths).long()
|
|
||||||
max_len = torch.max(input_lengths.data).item()
|
|
||||||
mel_padded = to_gpu(mel_padded).float()
|
|
||||||
gate_padded = to_gpu(gate_padded).float()
|
|
||||||
output_lengths = to_gpu(output_lengths).long()
|
|
||||||
|
|
||||||
return (
|
|
||||||
(text_padded, input_lengths, mel_padded, max_len, output_lengths),
|
|
||||||
(mel_padded, gate_padded))
|
|
||||||
|
|
||||||
def parse_output(self, outputs, output_lengths=None):
|
def parse_output(self, outputs, output_lengths=None):
|
||||||
if self.mask_padding and output_lengths is not None:
|
if self.mask_padding and output_lengths is not None:
|
||||||
mask = ~get_mask_from_lengths(output_lengths)
|
mask_fill = outputs[0].shape[-1]
|
||||||
|
mask = ~get_mask_from_lengths(output_lengths, mask_fill)
|
||||||
mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
|
mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
|
||||||
mask = mask.permute(1, 0, 2)
|
mask = mask.permute(1, 0, 2)
|
||||||
|
|
||||||
|
@ -500,8 +486,7 @@ class Tacotron2(nn.Module):
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, text_inputs, text_lengths, mels, output_lengths):
|
||||||
text_inputs, text_lengths, mels, max_len, output_lengths = inputs
|
|
||||||
text_lengths, output_lengths = text_lengths.data, output_lengths.data
|
text_lengths, output_lengths = text_lengths.data, output_lengths.data
|
||||||
|
|
||||||
embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
|
embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
|
||||||
|
@ -535,9 +520,8 @@ class Tacotron2(nn.Module):
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_nv_tacotron2(opt_net, opt):
|
def register_nv_tacotron2(opt_net, opt):
|
||||||
kw = opt_get(opt_net, ['kwargs'], {})
|
|
||||||
hparams = create_hparams()
|
hparams = create_hparams()
|
||||||
hparams.update(kw)
|
hparams.update(opt_net)
|
||||||
hparams = munchify(hparams)
|
hparams = munchify(hparams)
|
||||||
return Tacotron2(hparams)
|
return Tacotron2(hparams)
|
||||||
|
|
||||||
|
|
|
@ -300,7 +300,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_latent_unet_diffusion_sm.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tacotron2_lj.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -275,6 +275,10 @@ class ExtensibleTrainer(BaseModel):
|
||||||
# Record visual outputs for usage in debugging and testing.
|
# Record visual outputs for usage in debugging and testing.
|
||||||
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
|
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
|
||||||
def fix_image(img):
|
def fix_image(img):
|
||||||
|
if opt_get(self.opt, ['logger', 'is_mel_spectrogram'], False):
|
||||||
|
img = img.unsqueeze(dim=1)
|
||||||
|
# Normalize so spectrogram is easier to view.
|
||||||
|
img = (img - img.mean()) / img.std()
|
||||||
if img.shape[1] > 3:
|
if img.shape[1] > 3:
|
||||||
img = img[:, :3, :, :]
|
img = img[:, :3, :, :]
|
||||||
if opt_get(self.opt, ['logger', 'reverse_n1_to_1'], False):
|
if opt_get(self.opt, ['logger', 'reverse_n1_to_1'], False):
|
||||||
|
|
44
codes/trainer/eval/mel_evaluator.py
Normal file
44
codes/trainer/eval/mel_evaluator.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import trainer.eval.evaluator as evaluator
|
||||||
|
|
||||||
|
from data import create_dataset
|
||||||
|
from data.audio.nv_tacotron_dataset import TextMelCollate
|
||||||
|
from models.tacotron2.loss import Tacotron2LossRaw
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
# Evaluates the performance of a MEL spectrogram predictor.
|
||||||
|
class MelEvaluator(evaluator.Evaluator):
|
||||||
|
def __init__(self, model, opt_eval, env):
|
||||||
|
super().__init__(model, opt_eval, env, uses_all_ddp=True)
|
||||||
|
self.batch_sz = opt_eval['batch_size']
|
||||||
|
self.dataset = create_dataset(opt_eval['dataset'])
|
||||||
|
assert self.batch_sz is not None
|
||||||
|
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=1, collate_fn=TextMelCollate(n_frames_per_step=1))
|
||||||
|
self.criterion = Tacotron2LossRaw()
|
||||||
|
|
||||||
|
def perform_eval(self):
|
||||||
|
counter = 0
|
||||||
|
total_error = 0
|
||||||
|
self.model.eval()
|
||||||
|
for batch in tqdm(self.dataloader):
|
||||||
|
model_params = {
|
||||||
|
'text_inputs': 'padded_text',
|
||||||
|
'text_lengths': 'input_lengths',
|
||||||
|
'mels': 'padded_mel',
|
||||||
|
'output_lengths': 'output_lengths',
|
||||||
|
}
|
||||||
|
params = {k: batch[v].to(self.env['device']) for k, v in model_params.items()}
|
||||||
|
with torch.no_grad():
|
||||||
|
pred = self.model(**params)
|
||||||
|
|
||||||
|
targets = ['padded_mel', 'padded_gate']
|
||||||
|
targets = [batch[t].to(self.env['device']) for t in targets]
|
||||||
|
total_error += self.criterion(pred, targets).item()
|
||||||
|
counter += 1
|
||||||
|
self.model.train()
|
||||||
|
|
||||||
|
return {"validation-score": total_error / counter}
|
||||||
|
|
|
@ -58,7 +58,7 @@ def create_loss(opt_loss, env):
|
||||||
return SwitchTransformersLoadBalancingLoss(opt_loss, env)
|
return SwitchTransformersLoadBalancingLoss(opt_loss, env)
|
||||||
elif type == 'nv_tacotron2_loss':
|
elif type == 'nv_tacotron2_loss':
|
||||||
from models.tacotron2.loss import Tacotron2Loss
|
from models.tacotron2.loss import Tacotron2Loss
|
||||||
return Tacotron2Loss()
|
return Tacotron2Loss(opt_loss, env)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ class ConfigurableStep(Module):
|
||||||
self.scaler = GradScaler(enabled=self.opt['fp16'])
|
self.scaler = GradScaler(enabled=self.opt['fp16'])
|
||||||
self.grads_generated = False
|
self.grads_generated = False
|
||||||
self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999
|
self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999
|
||||||
|
self.clip_grad_eps = opt_get(opt_step, ['clip_grad_eps'], None)
|
||||||
|
|
||||||
# This is a half-measure that can be used between anomaly_detection and running a potentially problematic
|
# This is a half-measure that can be used between anomaly_detection and running a potentially problematic
|
||||||
# trainer bare. With this turned on, the optimizer will not step() if a nan grad is detected. If a model trips
|
# trainer bare. With this turned on, the optimizer will not step() if a nan grad is detected. If a model trips
|
||||||
|
@ -267,6 +268,13 @@ class ConfigurableStep(Module):
|
||||||
else:
|
else:
|
||||||
self.nan_counter = 0
|
self.nan_counter = 0
|
||||||
|
|
||||||
|
if self.clip_grad_eps is not None:
|
||||||
|
for pg in opt.param_groups:
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(pg['params'], self.clip_grad_eps)
|
||||||
|
if torch.isnan(grad_norm):
|
||||||
|
nan_found = True
|
||||||
|
self.nan_counter += 1
|
||||||
|
|
||||||
if not nan_found:
|
if not nan_found:
|
||||||
self.scaler.step(opt)
|
self.scaler.step(opt)
|
||||||
self.scaler.update()
|
self.scaler.update()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user