Initial checkin of nvidia tacotron model & dataset
These two are tested, full support for training to come.
This commit is contained in:
parent
3801d5d55e
commit
86fd3ad7fd
|
@ -2,11 +2,12 @@
|
|||
import logging
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from munch import munchify
|
||||
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
|
||||
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None, collate_fn=None):
|
||||
phase = dataset_opt['phase']
|
||||
if phase == 'train':
|
||||
if opt_get(opt, ['dist'], False):
|
||||
|
@ -21,15 +22,17 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
|
|||
shuffle = True
|
||||
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
|
||||
num_workers=num_workers, sampler=sampler, drop_last=True,
|
||||
pin_memory=True)
|
||||
pin_memory=True, collate_fn=collate_fn)
|
||||
else:
|
||||
batch_size = dataset_opt['batch_size'] or 1
|
||||
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
|
||||
pin_memory=True)
|
||||
pin_memory=True, collate_fn=collate_fn)
|
||||
|
||||
|
||||
def create_dataset(dataset_opt):
|
||||
def create_dataset(dataset_opt, return_collate=False):
|
||||
mode = dataset_opt['mode']
|
||||
collate = None
|
||||
|
||||
# datasets for image restoration
|
||||
if mode == 'fullimage':
|
||||
from data.full_image_dataset import FullImageDataset as D
|
||||
|
@ -59,8 +62,19 @@ def create_dataset(dataset_opt):
|
|||
from data.random_dataset import RandomDataset as D
|
||||
elif mode == 'zipfile':
|
||||
from data.zip_file_dataset import ZipFileDataset as D
|
||||
elif mode == 'nv_tacotron':
|
||||
from data.audio.nv_tacotron_dataset import TextMelLoader as D
|
||||
from data.audio.nv_tacotron_dataset import TextMelCollate as C
|
||||
from models.tacotron2.hparams import create_hparams
|
||||
default_params = create_hparams()
|
||||
dataset_opt.update(default_params)
|
||||
dataset_opt = munchify(dataset_opt)
|
||||
collate = C(dataset_opt.n_frames_per_step)
|
||||
else:
|
||||
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
||||
dataset = D(dataset_opt)
|
||||
|
||||
return dataset
|
||||
if return_collate:
|
||||
return dataset, collate
|
||||
else:
|
||||
return dataset
|
||||
|
|
127
codes/data/audio/nv_tacotron_dataset.py
Normal file
127
codes/data/audio/nv_tacotron_dataset.py
Normal file
|
@ -0,0 +1,127 @@
|
|||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
|
||||
import models.tacotron2.layers as layers
|
||||
from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text
|
||||
|
||||
from models.tacotron2.text import text_to_sequence
|
||||
|
||||
|
||||
class TextMelLoader(torch.utils.data.Dataset):
|
||||
"""
|
||||
1) loads audio,text pairs
|
||||
2) normalizes text and converts them to sequences of one-hot vectors
|
||||
3) computes mel-spectrograms from audio files.
|
||||
"""
|
||||
def __init__(self, hparams):
|
||||
self.path = os.path.dirname(hparams['path'])
|
||||
self.audiopaths_and_text = load_filepaths_and_text(hparams['path'])
|
||||
self.text_cleaners = hparams.text_cleaners
|
||||
self.max_wav_value = hparams.max_wav_value
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.load_mel_from_disk = hparams.load_mel_from_disk
|
||||
self.stft = layers.TacotronSTFT(
|
||||
hparams.filter_length, hparams.hop_length, hparams.win_length,
|
||||
hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
|
||||
hparams.mel_fmax)
|
||||
random.seed(hparams.seed)
|
||||
random.shuffle(self.audiopaths_and_text)
|
||||
|
||||
def get_mel_text_pair(self, audiopath_and_text):
|
||||
# separate filename and text
|
||||
audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
|
||||
audiopath = os.path.join(self.path, audiopath)
|
||||
text = self.get_text(text)
|
||||
mel = self.get_mel(audiopath)
|
||||
return (text, mel)
|
||||
|
||||
def get_mel(self, filename):
|
||||
if not self.load_mel_from_disk:
|
||||
audio, sampling_rate = load_wav_to_torch(filename)
|
||||
if sampling_rate != self.stft.sampling_rate:
|
||||
raise ValueError("{} {} SR doesn't match target {} SR".format(
|
||||
sampling_rate, self.stft.sampling_rate))
|
||||
audio_norm = audio / self.max_wav_value
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
|
||||
melspec = self.stft.mel_spectrogram(audio_norm)
|
||||
melspec = torch.squeeze(melspec, 0)
|
||||
else:
|
||||
melspec = torch.from_numpy(np.load(filename))
|
||||
assert melspec.size(0) == self.stft.n_mel_channels, (
|
||||
'Mel dimension mismatch: given {}, expected {}'.format(
|
||||
melspec.size(0), self.stft.n_mel_channels))
|
||||
|
||||
return melspec
|
||||
|
||||
def get_text(self, text):
|
||||
text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
|
||||
return text_norm
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.get_mel_text_pair(self.audiopaths_and_text[index])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths_and_text)
|
||||
|
||||
|
||||
class TextMelCollate():
|
||||
""" Zero-pads model inputs and targets based on number of frames per setep
|
||||
"""
|
||||
def __init__(self, n_frames_per_step):
|
||||
self.n_frames_per_step = n_frames_per_step
|
||||
|
||||
def __call__(self, batch):
|
||||
"""Collate's training batch from normalized text and mel-spectrogram
|
||||
PARAMS
|
||||
------
|
||||
batch: [text_normalized, mel_normalized]
|
||||
"""
|
||||
# Right zero-pad all one-hot text sequences to max input length
|
||||
input_lengths, ids_sorted_decreasing = torch.sort(
|
||||
torch.LongTensor([len(x[0]) for x in batch]),
|
||||
dim=0, descending=True)
|
||||
max_input_len = input_lengths[0]
|
||||
|
||||
text_padded = torch.LongTensor(len(batch), max_input_len)
|
||||
text_padded.zero_()
|
||||
for i in range(len(ids_sorted_decreasing)):
|
||||
text = batch[ids_sorted_decreasing[i]][0]
|
||||
text_padded[i, :text.size(0)] = text
|
||||
|
||||
# Right zero-pad mel-spec
|
||||
num_mels = batch[0][1].size(0)
|
||||
max_target_len = max([x[1].size(1) for x in batch])
|
||||
if max_target_len % self.n_frames_per_step != 0:
|
||||
max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
|
||||
assert max_target_len % self.n_frames_per_step == 0
|
||||
|
||||
# include mel padded and gate padded
|
||||
mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
|
||||
mel_padded.zero_()
|
||||
gate_padded = torch.FloatTensor(len(batch), max_target_len)
|
||||
gate_padded.zero_()
|
||||
output_lengths = torch.LongTensor(len(batch))
|
||||
for i in range(len(ids_sorted_decreasing)):
|
||||
mel = batch[ids_sorted_decreasing[i]][1]
|
||||
mel_padded[i, :, :mel.size(1)] = mel
|
||||
gate_padded[i, mel.size(1)-1:] = 1
|
||||
output_lengths[i] = mel.size(1)
|
||||
|
||||
return text_padded, input_lengths, mel_padded, gate_padded, \
|
||||
output_lengths
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
params = {
|
||||
'mode': 'nv_tacotron',
|
||||
'path': 'E:\\4k6k\\datasets\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
|
||||
|
||||
}
|
||||
from data import create_dataset
|
||||
ds = create_dataset(params)
|
||||
j = ds[0]
|
||||
print(j)
|
32
codes/models/tacotron2/LICENSE
Normal file
32
codes/models/tacotron2/LICENSE
Normal file
|
@ -0,0 +1,32 @@
|
|||
This directory contains works with the below licenses, which should be considered in addition
|
||||
to the base repository license.
|
||||
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2018, NVIDIA Corporation
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
93
codes/models/tacotron2/audio_processing.py
Normal file
93
codes/models/tacotron2/audio_processing.py
Normal file
|
@ -0,0 +1,93 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from scipy.signal import get_window
|
||||
import librosa.util as librosa_util
|
||||
|
||||
|
||||
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
|
||||
n_fft=800, dtype=np.float32, norm=None):
|
||||
"""
|
||||
# from librosa 0.6
|
||||
Compute the sum-square envelope of a window function at a given hop length.
|
||||
|
||||
This is used to estimate modulation effects induced by windowing
|
||||
observations in short-time fourier transforms.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
window : string, tuple, number, callable, or list-like
|
||||
Window specification, as in `get_window`
|
||||
|
||||
n_frames : int > 0
|
||||
The number of analysis frames
|
||||
|
||||
hop_length : int > 0
|
||||
The number of samples to advance between frames
|
||||
|
||||
win_length : [optional]
|
||||
The length of the window function. By default, this matches `n_fft`.
|
||||
|
||||
n_fft : int > 0
|
||||
The length of each analysis frame.
|
||||
|
||||
dtype : np.dtype
|
||||
The data type of the output
|
||||
|
||||
Returns
|
||||
-------
|
||||
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
|
||||
The sum-squared envelope of the window function
|
||||
"""
|
||||
if win_length is None:
|
||||
win_length = n_fft
|
||||
|
||||
n = n_fft + hop_length * (n_frames - 1)
|
||||
x = np.zeros(n, dtype=dtype)
|
||||
|
||||
# Compute the squared window at the desired length
|
||||
win_sq = get_window(window, win_length, fftbins=True)
|
||||
win_sq = librosa_util.normalize(win_sq, norm=norm)**2
|
||||
win_sq = librosa_util.pad_center(win_sq, n_fft)
|
||||
|
||||
# Fill the envelope
|
||||
for i in range(n_frames):
|
||||
sample = i * hop_length
|
||||
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
|
||||
return x
|
||||
|
||||
|
||||
def griffin_lim(magnitudes, stft_fn, n_iters=30):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
magnitudes: spectrogram magnitudes
|
||||
stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
|
||||
"""
|
||||
|
||||
angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
|
||||
angles = angles.astype(np.float32)
|
||||
angles = torch.autograd.Variable(torch.from_numpy(angles))
|
||||
signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
|
||||
|
||||
for i in range(n_iters):
|
||||
_, angles = stft_fn.transform(signal)
|
||||
signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
|
||||
return signal
|
||||
|
||||
|
||||
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
C: compression factor
|
||||
"""
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression(x, C=1):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
C: compression factor used to compress
|
||||
"""
|
||||
return torch.exp(x) / C
|
95
codes/models/tacotron2/hparams.py
Normal file
95
codes/models/tacotron2/hparams.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
import tensorflow as tf
|
||||
from models.tacotron2.text import symbols
|
||||
|
||||
|
||||
def create_hparams(hparams_string=None, verbose=False):
|
||||
"""Create model hyperparameters. Parse nondefault from given string."""
|
||||
|
||||
hparams = dict(
|
||||
################################
|
||||
# Experiment Parameters #
|
||||
################################
|
||||
epochs=500,
|
||||
iters_per_checkpoint=1000,
|
||||
seed=1234,
|
||||
dynamic_loss_scaling=True,
|
||||
fp16_run=False,
|
||||
distributed_run=False,
|
||||
dist_backend="nccl",
|
||||
dist_url="tcp://localhost:54321",
|
||||
cudnn_enabled=True,
|
||||
cudnn_benchmark=False,
|
||||
ignore_layers=['embedding.weight'],
|
||||
|
||||
################################
|
||||
# Data Parameters #
|
||||
################################
|
||||
load_mel_from_disk=False,
|
||||
training_files='filelists/ljs_audio_text_train_filelist.txt',
|
||||
validation_files='filelists/ljs_audio_text_val_filelist.txt',
|
||||
text_cleaners=['english_cleaners'],
|
||||
|
||||
################################
|
||||
# Audio Parameters #
|
||||
################################
|
||||
max_wav_value=32768.0,
|
||||
sampling_rate=22050,
|
||||
filter_length=1024,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
n_mel_channels=80,
|
||||
mel_fmin=0.0,
|
||||
mel_fmax=8000.0,
|
||||
|
||||
################################
|
||||
# Model Parameters #
|
||||
################################
|
||||
n_symbols=len(symbols),
|
||||
symbols_embedding_dim=512,
|
||||
|
||||
# Encoder parameters
|
||||
encoder_kernel_size=5,
|
||||
encoder_n_convolutions=3,
|
||||
encoder_embedding_dim=512,
|
||||
|
||||
# Decoder parameters
|
||||
n_frames_per_step=1, # currently only 1 is supported
|
||||
decoder_rnn_dim=1024,
|
||||
prenet_dim=256,
|
||||
max_decoder_steps=1000,
|
||||
gate_threshold=0.5,
|
||||
p_attention_dropout=0.1,
|
||||
p_decoder_dropout=0.1,
|
||||
|
||||
# Attention parameters
|
||||
attention_rnn_dim=1024,
|
||||
attention_dim=128,
|
||||
|
||||
# Location Layer parameters
|
||||
attention_location_n_filters=32,
|
||||
attention_location_kernel_size=31,
|
||||
|
||||
# Mel-post processing network parameters
|
||||
postnet_embedding_dim=512,
|
||||
postnet_kernel_size=5,
|
||||
postnet_n_convolutions=5,
|
||||
|
||||
################################
|
||||
# Optimization Hyperparameters #
|
||||
################################
|
||||
use_saved_learning_rate=False,
|
||||
learning_rate=1e-3,
|
||||
weight_decay=1e-6,
|
||||
grad_clip_thresh=1.0,
|
||||
batch_size=64,
|
||||
mask_padding=True # set model's padded outputs to padded values
|
||||
)
|
||||
|
||||
if hparams_string:
|
||||
tf.logging.info('Parsing command line hparams: %s', hparams_string)
|
||||
hparams.parse(hparams_string)
|
||||
|
||||
if verbose:
|
||||
tf.logging.info('Final parsed hparams: %s', hparams.values())
|
||||
|
||||
return hparams
|
80
codes/models/tacotron2/layers.py
Normal file
80
codes/models/tacotron2/layers.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
import torch
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from models.tacotron2.audio_processing import dynamic_range_compression
|
||||
from models.tacotron2.audio_processing import dynamic_range_decompression
|
||||
from models.tacotron2.stft import STFT
|
||||
|
||||
|
||||
class LinearNorm(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
||||
super(LinearNorm, self).__init__()
|
||||
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
||||
|
||||
torch.nn.init.xavier_uniform_(
|
||||
self.linear_layer.weight,
|
||||
gain=torch.nn.init.calculate_gain(w_init_gain))
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear_layer(x)
|
||||
|
||||
|
||||
class ConvNorm(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
|
||||
padding=None, dilation=1, bias=True, w_init_gain='linear'):
|
||||
super(ConvNorm, self).__init__()
|
||||
if padding is None:
|
||||
assert(kernel_size % 2 == 1)
|
||||
padding = int(dilation * (kernel_size - 1) / 2)
|
||||
|
||||
self.conv = torch.nn.Conv1d(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation,
|
||||
bias=bias)
|
||||
|
||||
torch.nn.init.xavier_uniform_(
|
||||
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
|
||||
|
||||
def forward(self, signal):
|
||||
conv_signal = self.conv(signal)
|
||||
return conv_signal
|
||||
|
||||
|
||||
class TacotronSTFT(torch.nn.Module):
|
||||
def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
|
||||
n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
|
||||
mel_fmax=8000.0):
|
||||
super(TacotronSTFT, self).__init__()
|
||||
self.n_mel_channels = n_mel_channels
|
||||
self.sampling_rate = sampling_rate
|
||||
self.stft_fn = STFT(filter_length, hop_length, win_length)
|
||||
mel_basis = librosa_mel_fn(
|
||||
sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
|
||||
mel_basis = torch.from_numpy(mel_basis).float()
|
||||
self.register_buffer('mel_basis', mel_basis)
|
||||
|
||||
def spectral_normalize(self, magnitudes):
|
||||
output = dynamic_range_compression(magnitudes)
|
||||
return output
|
||||
|
||||
def spectral_de_normalize(self, magnitudes):
|
||||
output = dynamic_range_decompression(magnitudes)
|
||||
return output
|
||||
|
||||
def mel_spectrogram(self, y):
|
||||
"""Computes mel-spectrograms from a batch of waves
|
||||
PARAMS
|
||||
------
|
||||
y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
|
||||
|
||||
RETURNS
|
||||
-------
|
||||
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
|
||||
"""
|
||||
assert(torch.min(y.data) >= -1)
|
||||
assert(torch.max(y.data) <= 1)
|
||||
|
||||
magnitudes, phases = self.stft_fn.transform(y)
|
||||
magnitudes = magnitudes.data
|
||||
mel_output = torch.matmul(self.mel_basis, magnitudes)
|
||||
mel_output = self.spectral_normalize(mel_output)
|
||||
return mel_output
|
19
codes/models/tacotron2/loss.py
Normal file
19
codes/models/tacotron2/loss.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
from torch import nn
|
||||
|
||||
|
||||
class Tacotron2Loss(nn.Module):
|
||||
def __init__(self):
|
||||
super(Tacotron2Loss, self).__init__()
|
||||
|
||||
def forward(self, model_output, targets):
|
||||
mel_target, gate_target = targets[0], targets[1]
|
||||
mel_target.requires_grad = False
|
||||
gate_target.requires_grad = False
|
||||
gate_target = gate_target.view(-1, 1)
|
||||
|
||||
mel_out, mel_out_postnet, gate_out, _ = model_output
|
||||
gate_out = gate_out.view(-1, 1)
|
||||
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
|
||||
nn.MSELoss()(mel_out_postnet, mel_target)
|
||||
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
|
||||
return mel_loss + gate_loss
|
141
codes/models/tacotron2/stft.py
Normal file
141
codes/models/tacotron2/stft.py
Normal file
|
@ -0,0 +1,141 @@
|
|||
"""
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2017, Prem Seetharaman
|
||||
All rights reserved.
|
||||
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice,
|
||||
this list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this
|
||||
list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from this
|
||||
software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
||||
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
||||
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
from scipy.signal import get_window
|
||||
from librosa.util import pad_center, tiny
|
||||
from models.tacotron2.audio_processing import window_sumsquare
|
||||
|
||||
|
||||
class STFT(torch.nn.Module):
|
||||
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
||||
def __init__(self, filter_length=800, hop_length=200, win_length=800,
|
||||
window='hann'):
|
||||
super(STFT, self).__init__()
|
||||
self.filter_length = filter_length
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.window = window
|
||||
self.forward_transform = None
|
||||
scale = self.filter_length / self.hop_length
|
||||
fourier_basis = np.fft.fft(np.eye(self.filter_length))
|
||||
|
||||
cutoff = int((self.filter_length / 2 + 1))
|
||||
fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
|
||||
np.imag(fourier_basis[:cutoff, :])])
|
||||
|
||||
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
|
||||
inverse_basis = torch.FloatTensor(
|
||||
np.linalg.pinv(scale * fourier_basis).T[:, None, :])
|
||||
|
||||
if window is not None:
|
||||
assert(filter_length >= win_length)
|
||||
# get window and zero center pad it to filter_length
|
||||
fft_window = get_window(window, win_length, fftbins=True)
|
||||
fft_window = pad_center(fft_window, filter_length)
|
||||
fft_window = torch.from_numpy(fft_window).float()
|
||||
|
||||
# window the bases
|
||||
forward_basis *= fft_window
|
||||
inverse_basis *= fft_window
|
||||
|
||||
self.register_buffer('forward_basis', forward_basis.float())
|
||||
self.register_buffer('inverse_basis', inverse_basis.float())
|
||||
|
||||
def transform(self, input_data):
|
||||
num_batches = input_data.size(0)
|
||||
num_samples = input_data.size(1)
|
||||
|
||||
self.num_samples = num_samples
|
||||
|
||||
# similar to librosa, reflect-pad the input
|
||||
input_data = input_data.view(num_batches, 1, num_samples)
|
||||
input_data = F.pad(
|
||||
input_data.unsqueeze(1),
|
||||
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
|
||||
mode='reflect')
|
||||
input_data = input_data.squeeze(1)
|
||||
|
||||
forward_transform = F.conv1d(
|
||||
input_data,
|
||||
Variable(self.forward_basis, requires_grad=False),
|
||||
stride=self.hop_length,
|
||||
padding=0)
|
||||
|
||||
cutoff = int((self.filter_length / 2) + 1)
|
||||
real_part = forward_transform[:, :cutoff, :]
|
||||
imag_part = forward_transform[:, cutoff:, :]
|
||||
|
||||
magnitude = torch.sqrt(real_part**2 + imag_part**2)
|
||||
phase = torch.autograd.Variable(
|
||||
torch.atan2(imag_part.data, real_part.data))
|
||||
|
||||
return magnitude, phase
|
||||
|
||||
def inverse(self, magnitude, phase):
|
||||
recombine_magnitude_phase = torch.cat(
|
||||
[magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
|
||||
|
||||
inverse_transform = F.conv_transpose1d(
|
||||
recombine_magnitude_phase,
|
||||
Variable(self.inverse_basis, requires_grad=False),
|
||||
stride=self.hop_length,
|
||||
padding=0)
|
||||
|
||||
if self.window is not None:
|
||||
window_sum = window_sumsquare(
|
||||
self.window, magnitude.size(-1), hop_length=self.hop_length,
|
||||
win_length=self.win_length, n_fft=self.filter_length,
|
||||
dtype=np.float32)
|
||||
# remove modulation effects
|
||||
approx_nonzero_indices = torch.from_numpy(
|
||||
np.where(window_sum > tiny(window_sum))[0])
|
||||
window_sum = torch.autograd.Variable(
|
||||
torch.from_numpy(window_sum), requires_grad=False)
|
||||
window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
|
||||
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
|
||||
|
||||
# scale by hop ratio
|
||||
inverse_transform *= float(self.filter_length) / self.hop_length
|
||||
|
||||
inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
|
||||
inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
|
||||
|
||||
return inverse_transform
|
||||
|
||||
def forward(self, input_data):
|
||||
self.magnitude, self.phase = self.transform(input_data)
|
||||
reconstruction = self.inverse(self.magnitude, self.phase)
|
||||
return reconstruction
|
29
codes/models/tacotron2/taco_utils.py
Normal file
29
codes/models/tacotron2/taco_utils.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
import numpy as np
|
||||
from scipy.io.wavfile import read
|
||||
import torch
|
||||
|
||||
|
||||
def get_mask_from_lengths(lengths):
|
||||
max_len = torch.max(lengths).item()
|
||||
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len, device=lengths.device))
|
||||
mask = (ids < lengths.unsqueeze(1)).bool()
|
||||
return mask
|
||||
|
||||
|
||||
def load_wav_to_torch(full_path):
|
||||
sampling_rate, data = read(full_path)
|
||||
return torch.FloatTensor(data.astype(np.float32)), sampling_rate
|
||||
|
||||
|
||||
def load_filepaths_and_text(filename, split="|"):
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
filepaths_and_text = [line.strip().split(split) for line in f]
|
||||
return filepaths_and_text
|
||||
|
||||
|
||||
def to_gpu(x):
|
||||
x = x.contiguous()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
x = x.cuda(non_blocking=True)
|
||||
return torch.autograd.Variable(x)
|
549
codes/models/tacotron2/tacotron2.py
Normal file
549
codes/models/tacotron2/tacotron2.py
Normal file
|
@ -0,0 +1,549 @@
|
|||
from math import sqrt
|
||||
import torch
|
||||
from munch import munchify
|
||||
from torch.autograd import Variable
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from layers import ConvNorm, LinearNorm
|
||||
from models.tacotron2.hparams import create_hparams
|
||||
from trainer.networks import register_model
|
||||
from taco_utils import to_gpu, get_mask_from_lengths
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
class LocationLayer(nn.Module):
|
||||
def __init__(self, attention_n_filters, attention_kernel_size,
|
||||
attention_dim):
|
||||
super(LocationLayer, self).__init__()
|
||||
padding = int((attention_kernel_size - 1) / 2)
|
||||
self.location_conv = ConvNorm(2, attention_n_filters,
|
||||
kernel_size=attention_kernel_size,
|
||||
padding=padding, bias=False, stride=1,
|
||||
dilation=1)
|
||||
self.location_dense = LinearNorm(attention_n_filters, attention_dim,
|
||||
bias=False, w_init_gain='tanh')
|
||||
|
||||
def forward(self, attention_weights_cat):
|
||||
processed_attention = self.location_conv(attention_weights_cat)
|
||||
processed_attention = processed_attention.transpose(1, 2)
|
||||
processed_attention = self.location_dense(processed_attention)
|
||||
return processed_attention
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
|
||||
attention_location_n_filters, attention_location_kernel_size):
|
||||
super(Attention, self).__init__()
|
||||
self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
|
||||
bias=False, w_init_gain='tanh')
|
||||
self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
|
||||
w_init_gain='tanh')
|
||||
self.v = LinearNorm(attention_dim, 1, bias=False)
|
||||
self.location_layer = LocationLayer(attention_location_n_filters,
|
||||
attention_location_kernel_size,
|
||||
attention_dim)
|
||||
self.score_mask_value = -float("inf")
|
||||
|
||||
def get_alignment_energies(self, query, processed_memory,
|
||||
attention_weights_cat):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
query: decoder output (batch, n_mel_channels * n_frames_per_step)
|
||||
processed_memory: processed encoder outputs (B, T_in, attention_dim)
|
||||
attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
|
||||
|
||||
RETURNS
|
||||
-------
|
||||
alignment (batch, max_time)
|
||||
"""
|
||||
|
||||
processed_query = self.query_layer(query.unsqueeze(1))
|
||||
processed_attention_weights = self.location_layer(attention_weights_cat)
|
||||
energies = self.v(torch.tanh(
|
||||
processed_query + processed_attention_weights + processed_memory))
|
||||
|
||||
energies = energies.squeeze(-1)
|
||||
return energies
|
||||
|
||||
def forward(self, attention_hidden_state, memory, processed_memory,
|
||||
attention_weights_cat, mask):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
attention_hidden_state: attention rnn last output
|
||||
memory: encoder outputs
|
||||
processed_memory: processed encoder outputs
|
||||
attention_weights_cat: previous and cummulative attention weights
|
||||
mask: binary mask for padded data
|
||||
"""
|
||||
alignment = self.get_alignment_energies(
|
||||
attention_hidden_state, processed_memory, attention_weights_cat)
|
||||
|
||||
if mask is not None:
|
||||
alignment.data.masked_fill_(mask, self.score_mask_value)
|
||||
|
||||
attention_weights = F.softmax(alignment, dim=1)
|
||||
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
|
||||
attention_context = attention_context.squeeze(1)
|
||||
|
||||
return attention_context, attention_weights
|
||||
|
||||
|
||||
class Prenet(nn.Module):
|
||||
def __init__(self, in_dim, sizes):
|
||||
super(Prenet, self).__init__()
|
||||
in_sizes = [in_dim] + sizes[:-1]
|
||||
self.layers = nn.ModuleList(
|
||||
[LinearNorm(in_size, out_size, bias=False)
|
||||
for (in_size, out_size) in zip(in_sizes, sizes)])
|
||||
|
||||
def forward(self, x):
|
||||
for linear in self.layers:
|
||||
x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
|
||||
return x
|
||||
|
||||
|
||||
class Postnet(nn.Module):
|
||||
"""Postnet
|
||||
- Five 1-d convolution with 512 channels and kernel size 5
|
||||
"""
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(Postnet, self).__init__()
|
||||
self.convolutions = nn.ModuleList()
|
||||
|
||||
self.convolutions.append(
|
||||
nn.Sequential(
|
||||
ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
|
||||
kernel_size=hparams.postnet_kernel_size, stride=1,
|
||||
padding=int((hparams.postnet_kernel_size - 1) / 2),
|
||||
dilation=1, w_init_gain='tanh'),
|
||||
nn.BatchNorm1d(hparams.postnet_embedding_dim))
|
||||
)
|
||||
|
||||
for i in range(1, hparams.postnet_n_convolutions - 1):
|
||||
self.convolutions.append(
|
||||
nn.Sequential(
|
||||
ConvNorm(hparams.postnet_embedding_dim,
|
||||
hparams.postnet_embedding_dim,
|
||||
kernel_size=hparams.postnet_kernel_size, stride=1,
|
||||
padding=int((hparams.postnet_kernel_size - 1) / 2),
|
||||
dilation=1, w_init_gain='tanh'),
|
||||
nn.BatchNorm1d(hparams.postnet_embedding_dim))
|
||||
)
|
||||
|
||||
self.convolutions.append(
|
||||
nn.Sequential(
|
||||
ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
|
||||
kernel_size=hparams.postnet_kernel_size, stride=1,
|
||||
padding=int((hparams.postnet_kernel_size - 1) / 2),
|
||||
dilation=1, w_init_gain='linear'),
|
||||
nn.BatchNorm1d(hparams.n_mel_channels))
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for i in range(len(self.convolutions) - 1):
|
||||
x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
|
||||
x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Encoder module:
|
||||
- Three 1-d convolution banks
|
||||
- Bidirectional LSTM
|
||||
"""
|
||||
def __init__(self, hparams):
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
convolutions = []
|
||||
for _ in range(hparams.encoder_n_convolutions):
|
||||
conv_layer = nn.Sequential(
|
||||
ConvNorm(hparams.encoder_embedding_dim,
|
||||
hparams.encoder_embedding_dim,
|
||||
kernel_size=hparams.encoder_kernel_size, stride=1,
|
||||
padding=int((hparams.encoder_kernel_size - 1) / 2),
|
||||
dilation=1, w_init_gain='relu'),
|
||||
nn.BatchNorm1d(hparams.encoder_embedding_dim))
|
||||
convolutions.append(conv_layer)
|
||||
self.convolutions = nn.ModuleList(convolutions)
|
||||
|
||||
self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
|
||||
int(hparams.encoder_embedding_dim / 2), 1,
|
||||
batch_first=True, bidirectional=True)
|
||||
|
||||
def forward(self, x, input_lengths):
|
||||
for conv in self.convolutions:
|
||||
x = F.dropout(F.relu(conv(x)), 0.5, self.training)
|
||||
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
# pytorch tensor are not reversible, hence the conversion
|
||||
input_lengths = input_lengths.cpu().numpy()
|
||||
x = nn.utils.rnn.pack_padded_sequence(
|
||||
x, input_lengths, batch_first=True)
|
||||
|
||||
self.lstm.flatten_parameters()
|
||||
outputs, _ = self.lstm(x)
|
||||
|
||||
outputs, _ = nn.utils.rnn.pad_packed_sequence(
|
||||
outputs, batch_first=True)
|
||||
|
||||
return outputs
|
||||
|
||||
def inference(self, x):
|
||||
for conv in self.convolutions:
|
||||
x = F.dropout(F.relu(conv(x)), 0.5, self.training)
|
||||
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
self.lstm.flatten_parameters()
|
||||
outputs, _ = self.lstm(x)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, hparams):
|
||||
super(Decoder, self).__init__()
|
||||
self.n_mel_channels = hparams.n_mel_channels
|
||||
self.n_frames_per_step = hparams.n_frames_per_step
|
||||
self.encoder_embedding_dim = hparams.encoder_embedding_dim
|
||||
self.attention_rnn_dim = hparams.attention_rnn_dim
|
||||
self.decoder_rnn_dim = hparams.decoder_rnn_dim
|
||||
self.prenet_dim = hparams.prenet_dim
|
||||
self.max_decoder_steps = hparams.max_decoder_steps
|
||||
self.gate_threshold = hparams.gate_threshold
|
||||
self.p_attention_dropout = hparams.p_attention_dropout
|
||||
self.p_decoder_dropout = hparams.p_decoder_dropout
|
||||
|
||||
self.prenet = Prenet(
|
||||
hparams.n_mel_channels * hparams.n_frames_per_step,
|
||||
[hparams.prenet_dim, hparams.prenet_dim])
|
||||
|
||||
self.attention_rnn = nn.LSTMCell(
|
||||
hparams.prenet_dim + hparams.encoder_embedding_dim,
|
||||
hparams.attention_rnn_dim)
|
||||
|
||||
self.attention_layer = Attention(
|
||||
hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
|
||||
hparams.attention_dim, hparams.attention_location_n_filters,
|
||||
hparams.attention_location_kernel_size)
|
||||
|
||||
self.decoder_rnn = nn.LSTMCell(
|
||||
hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
|
||||
hparams.decoder_rnn_dim, 1)
|
||||
|
||||
self.linear_projection = LinearNorm(
|
||||
hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
|
||||
hparams.n_mel_channels * hparams.n_frames_per_step)
|
||||
|
||||
self.gate_layer = LinearNorm(
|
||||
hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
|
||||
bias=True, w_init_gain='sigmoid')
|
||||
|
||||
def get_go_frame(self, memory):
|
||||
""" Gets all zeros frames to use as first decoder input
|
||||
PARAMS
|
||||
------
|
||||
memory: decoder outputs
|
||||
|
||||
RETURNS
|
||||
-------
|
||||
decoder_input: all zeros frames
|
||||
"""
|
||||
B = memory.size(0)
|
||||
decoder_input = Variable(memory.data.new(
|
||||
B, self.n_mel_channels * self.n_frames_per_step).zero_())
|
||||
return decoder_input
|
||||
|
||||
def initialize_decoder_states(self, memory, mask):
|
||||
""" Initializes attention rnn states, decoder rnn states, attention
|
||||
weights, attention cumulative weights, attention context, stores memory
|
||||
and stores processed memory
|
||||
PARAMS
|
||||
------
|
||||
memory: Encoder outputs
|
||||
mask: Mask for padded data if training, expects None for inference
|
||||
"""
|
||||
B = memory.size(0)
|
||||
MAX_TIME = memory.size(1)
|
||||
|
||||
self.attention_hidden = Variable(memory.data.new(
|
||||
B, self.attention_rnn_dim).zero_())
|
||||
self.attention_cell = Variable(memory.data.new(
|
||||
B, self.attention_rnn_dim).zero_())
|
||||
|
||||
self.decoder_hidden = Variable(memory.data.new(
|
||||
B, self.decoder_rnn_dim).zero_())
|
||||
self.decoder_cell = Variable(memory.data.new(
|
||||
B, self.decoder_rnn_dim).zero_())
|
||||
|
||||
self.attention_weights = Variable(memory.data.new(
|
||||
B, MAX_TIME).zero_())
|
||||
self.attention_weights_cum = Variable(memory.data.new(
|
||||
B, MAX_TIME).zero_())
|
||||
self.attention_context = Variable(memory.data.new(
|
||||
B, self.encoder_embedding_dim).zero_())
|
||||
|
||||
self.memory = memory
|
||||
self.processed_memory = self.attention_layer.memory_layer(memory)
|
||||
self.mask = mask
|
||||
|
||||
def parse_decoder_inputs(self, decoder_inputs):
|
||||
""" Prepares decoder inputs, i.e. mel outputs
|
||||
PARAMS
|
||||
------
|
||||
decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs
|
||||
|
||||
RETURNS
|
||||
-------
|
||||
inputs: processed decoder inputs
|
||||
|
||||
"""
|
||||
# (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
|
||||
decoder_inputs = decoder_inputs.transpose(1, 2)
|
||||
decoder_inputs = decoder_inputs.view(
|
||||
decoder_inputs.size(0),
|
||||
int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
|
||||
# (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
|
||||
decoder_inputs = decoder_inputs.transpose(0, 1)
|
||||
return decoder_inputs
|
||||
|
||||
def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
|
||||
""" Prepares decoder outputs for output
|
||||
PARAMS
|
||||
------
|
||||
mel_outputs:
|
||||
gate_outputs: gate output energies
|
||||
alignments:
|
||||
|
||||
RETURNS
|
||||
-------
|
||||
mel_outputs:
|
||||
gate_outpust: gate output energies
|
||||
alignments:
|
||||
"""
|
||||
# (T_out, B) -> (B, T_out)
|
||||
alignments = torch.stack(alignments).transpose(0, 1)
|
||||
# (T_out, B) -> (B, T_out)
|
||||
gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
|
||||
gate_outputs = gate_outputs.contiguous()
|
||||
# (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
|
||||
mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
|
||||
# decouple frames per step
|
||||
mel_outputs = mel_outputs.view(
|
||||
mel_outputs.size(0), -1, self.n_mel_channels)
|
||||
# (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
|
||||
mel_outputs = mel_outputs.transpose(1, 2)
|
||||
|
||||
return mel_outputs, gate_outputs, alignments
|
||||
|
||||
def decode(self, decoder_input):
|
||||
""" Decoder step using stored states, attention and memory
|
||||
PARAMS
|
||||
------
|
||||
decoder_input: previous mel output
|
||||
|
||||
RETURNS
|
||||
-------
|
||||
mel_output:
|
||||
gate_output: gate output energies
|
||||
attention_weights:
|
||||
"""
|
||||
cell_input = torch.cat((decoder_input, self.attention_context), -1)
|
||||
self.attention_hidden, self.attention_cell = self.attention_rnn(
|
||||
cell_input, (self.attention_hidden, self.attention_cell))
|
||||
self.attention_hidden = F.dropout(
|
||||
self.attention_hidden, self.p_attention_dropout, self.training)
|
||||
|
||||
attention_weights_cat = torch.cat(
|
||||
(self.attention_weights.unsqueeze(1),
|
||||
self.attention_weights_cum.unsqueeze(1)), dim=1)
|
||||
self.attention_context, self.attention_weights = self.attention_layer(
|
||||
self.attention_hidden, self.memory, self.processed_memory,
|
||||
attention_weights_cat, self.mask)
|
||||
|
||||
self.attention_weights_cum += self.attention_weights
|
||||
decoder_input = torch.cat(
|
||||
(self.attention_hidden, self.attention_context), -1)
|
||||
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
|
||||
decoder_input, (self.decoder_hidden, self.decoder_cell))
|
||||
self.decoder_hidden = F.dropout(
|
||||
self.decoder_hidden, self.p_decoder_dropout, self.training)
|
||||
|
||||
decoder_hidden_attention_context = torch.cat(
|
||||
(self.decoder_hidden, self.attention_context), dim=1)
|
||||
decoder_output = self.linear_projection(
|
||||
decoder_hidden_attention_context)
|
||||
|
||||
gate_prediction = self.gate_layer(decoder_hidden_attention_context)
|
||||
return decoder_output, gate_prediction, self.attention_weights
|
||||
|
||||
def forward(self, memory, decoder_inputs, memory_lengths):
|
||||
""" Decoder forward pass for training
|
||||
PARAMS
|
||||
------
|
||||
memory: Encoder outputs
|
||||
decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
|
||||
memory_lengths: Encoder output lengths for attention masking.
|
||||
|
||||
RETURNS
|
||||
-------
|
||||
mel_outputs: mel outputs from the decoder
|
||||
gate_outputs: gate outputs from the decoder
|
||||
alignments: sequence of attention weights from the decoder
|
||||
"""
|
||||
|
||||
decoder_input = self.get_go_frame(memory).unsqueeze(0)
|
||||
decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
|
||||
decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
|
||||
decoder_inputs = self.prenet(decoder_inputs)
|
||||
|
||||
self.initialize_decoder_states(
|
||||
memory, mask=~get_mask_from_lengths(memory_lengths))
|
||||
|
||||
mel_outputs, gate_outputs, alignments = [], [], []
|
||||
while len(mel_outputs) < decoder_inputs.size(0) - 1:
|
||||
decoder_input = decoder_inputs[len(mel_outputs)]
|
||||
mel_output, gate_output, attention_weights = self.decode(
|
||||
decoder_input)
|
||||
mel_outputs += [mel_output.squeeze(1)]
|
||||
gate_outputs += [gate_output.squeeze(1)]
|
||||
alignments += [attention_weights]
|
||||
|
||||
mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
|
||||
mel_outputs, gate_outputs, alignments)
|
||||
|
||||
return mel_outputs, gate_outputs, alignments
|
||||
|
||||
def inference(self, memory):
|
||||
""" Decoder inference
|
||||
PARAMS
|
||||
------
|
||||
memory: Encoder outputs
|
||||
|
||||
RETURNS
|
||||
-------
|
||||
mel_outputs: mel outputs from the decoder
|
||||
gate_outputs: gate outputs from the decoder
|
||||
alignments: sequence of attention weights from the decoder
|
||||
"""
|
||||
decoder_input = self.get_go_frame(memory)
|
||||
|
||||
self.initialize_decoder_states(memory, mask=None)
|
||||
|
||||
mel_outputs, gate_outputs, alignments = [], [], []
|
||||
while True:
|
||||
decoder_input = self.prenet(decoder_input)
|
||||
mel_output, gate_output, alignment = self.decode(decoder_input)
|
||||
|
||||
mel_outputs += [mel_output.squeeze(1)]
|
||||
gate_outputs += [gate_output]
|
||||
alignments += [alignment]
|
||||
|
||||
if torch.sigmoid(gate_output.data) > self.gate_threshold:
|
||||
break
|
||||
elif len(mel_outputs) == self.max_decoder_steps:
|
||||
print("Warning! Reached max decoder steps")
|
||||
break
|
||||
|
||||
decoder_input = mel_output
|
||||
|
||||
mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
|
||||
mel_outputs, gate_outputs, alignments)
|
||||
|
||||
return mel_outputs, gate_outputs, alignments
|
||||
|
||||
|
||||
class Tacotron2(nn.Module):
|
||||
def __init__(self, hparams):
|
||||
super(Tacotron2, self).__init__()
|
||||
self.mask_padding = hparams.mask_padding
|
||||
self.fp16_run = hparams.fp16_run
|
||||
self.n_mel_channels = hparams.n_mel_channels
|
||||
self.n_frames_per_step = hparams.n_frames_per_step
|
||||
self.embedding = nn.Embedding(
|
||||
hparams.n_symbols, hparams.symbols_embedding_dim)
|
||||
std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
|
||||
val = sqrt(3.0) * std # uniform bounds for std
|
||||
self.embedding.weight.data.uniform_(-val, val)
|
||||
self.encoder = Encoder(hparams)
|
||||
self.decoder = Decoder(hparams)
|
||||
self.postnet = Postnet(hparams)
|
||||
|
||||
def parse_batch(self, batch):
|
||||
text_padded, input_lengths, mel_padded, gate_padded, \
|
||||
output_lengths = batch
|
||||
text_padded = to_gpu(text_padded).long()
|
||||
input_lengths = to_gpu(input_lengths).long()
|
||||
max_len = torch.max(input_lengths.data).item()
|
||||
mel_padded = to_gpu(mel_padded).float()
|
||||
gate_padded = to_gpu(gate_padded).float()
|
||||
output_lengths = to_gpu(output_lengths).long()
|
||||
|
||||
return (
|
||||
(text_padded, input_lengths, mel_padded, max_len, output_lengths),
|
||||
(mel_padded, gate_padded))
|
||||
|
||||
def parse_output(self, outputs, output_lengths=None):
|
||||
if self.mask_padding and output_lengths is not None:
|
||||
mask = ~get_mask_from_lengths(output_lengths)
|
||||
mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
|
||||
mask = mask.permute(1, 0, 2)
|
||||
|
||||
outputs[0].data.masked_fill_(mask, 0.0)
|
||||
outputs[1].data.masked_fill_(mask, 0.0)
|
||||
outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies
|
||||
|
||||
return outputs
|
||||
|
||||
def forward(self, inputs):
|
||||
text_inputs, text_lengths, mels, max_len, output_lengths = inputs
|
||||
text_lengths, output_lengths = text_lengths.data, output_lengths.data
|
||||
|
||||
embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
|
||||
|
||||
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
||||
|
||||
mel_outputs, gate_outputs, alignments = self.decoder(
|
||||
encoder_outputs, mels, memory_lengths=text_lengths)
|
||||
|
||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
|
||||
return self.parse_output(
|
||||
[mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
|
||||
output_lengths)
|
||||
|
||||
def inference(self, inputs):
|
||||
embedded_inputs = self.embedding(inputs).transpose(1, 2)
|
||||
encoder_outputs = self.encoder.inference(embedded_inputs)
|
||||
mel_outputs, gate_outputs, alignments = self.decoder.inference(
|
||||
encoder_outputs)
|
||||
|
||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
|
||||
outputs = self.parse_output(
|
||||
[mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@register_model
|
||||
def register_nv_tacotron2(opt_net, opt):
|
||||
kw = opt_get(opt_net, ['kwargs'], {})
|
||||
hparams = create_hparams()
|
||||
hparams.update(kw)
|
||||
hparams = munchify(hparams)
|
||||
return Tacotron2(hparams)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tron = register_nv_tacotron2({}, {})
|
||||
inputs = torch.randint(high=24, size=(1,12)), torch.tensor([12]), torch.randn((1,80,749)), 800, torch.tensor([749])
|
||||
out = tron(inputs)
|
||||
print(out)
|
19
codes/models/tacotron2/text/LICENSE
Normal file
19
codes/models/tacotron2/text/LICENSE
Normal file
|
@ -0,0 +1,19 @@
|
|||
Copyright (c) 2017 Keith Ito
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
74
codes/models/tacotron2/text/__init__.py
Normal file
74
codes/models/tacotron2/text/__init__.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
""" from https://github.com/keithito/tacotron """
|
||||
import re
|
||||
from models.tacotron2.text import cleaners
|
||||
from models.tacotron2.text.symbols import symbols
|
||||
|
||||
|
||||
# Mappings from symbol to numeric ID and vice versa:
|
||||
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
||||
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
|
||||
|
||||
# Regular expression matching text enclosed in curly braces:
|
||||
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
|
||||
|
||||
|
||||
def text_to_sequence(text, cleaner_names):
|
||||
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
||||
|
||||
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
|
||||
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
|
||||
|
||||
Args:
|
||||
text: string to convert to a sequence
|
||||
cleaner_names: names of the cleaner functions to run the text through
|
||||
|
||||
Returns:
|
||||
List of integers corresponding to the symbols in the text
|
||||
'''
|
||||
sequence = []
|
||||
|
||||
# Check for curly braces and treat their contents as ARPAbet:
|
||||
while len(text):
|
||||
m = _curly_re.match(text)
|
||||
if not m:
|
||||
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
|
||||
break
|
||||
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
|
||||
sequence += _arpabet_to_sequence(m.group(2))
|
||||
text = m.group(3)
|
||||
|
||||
return sequence
|
||||
|
||||
|
||||
def sequence_to_text(sequence):
|
||||
'''Converts a sequence of IDs back to a string'''
|
||||
result = ''
|
||||
for symbol_id in sequence:
|
||||
if symbol_id in _id_to_symbol:
|
||||
s = _id_to_symbol[symbol_id]
|
||||
# Enclose ARPAbet back in curly braces:
|
||||
if len(s) > 1 and s[0] == '@':
|
||||
s = '{%s}' % s[1:]
|
||||
result += s
|
||||
return result.replace('}{', ' ')
|
||||
|
||||
|
||||
def _clean_text(text, cleaner_names):
|
||||
for name in cleaner_names:
|
||||
cleaner = getattr(cleaners, name)
|
||||
if not cleaner:
|
||||
raise Exception('Unknown cleaner: %s' % name)
|
||||
text = cleaner(text)
|
||||
return text
|
||||
|
||||
|
||||
def _symbols_to_sequence(symbols):
|
||||
return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
|
||||
|
||||
|
||||
def _arpabet_to_sequence(text):
|
||||
return _symbols_to_sequence(['@' + s for s in text.split()])
|
||||
|
||||
|
||||
def _should_keep_symbol(s):
|
||||
return s in _symbol_to_id and s is not '_' and s is not '~'
|
90
codes/models/tacotron2/text/cleaners.py
Normal file
90
codes/models/tacotron2/text/cleaners.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
""" from https://github.com/keithito/tacotron """
|
||||
|
||||
'''
|
||||
Cleaners are transformations that run over the input text at both training and eval time.
|
||||
|
||||
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
||||
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
||||
1. "english_cleaners" for English text
|
||||
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
||||
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
||||
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
||||
the symbols in symbols.py to match your data).
|
||||
'''
|
||||
|
||||
import re
|
||||
from unidecode import unidecode
|
||||
from .numbers import normalize_numbers
|
||||
|
||||
|
||||
# Regular expression matching whitespace:
|
||||
_whitespace_re = re.compile(r'\s+')
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
|
||||
('mrs', 'misess'),
|
||||
('mr', 'mister'),
|
||||
('dr', 'doctor'),
|
||||
('st', 'saint'),
|
||||
('co', 'company'),
|
||||
('jr', 'junior'),
|
||||
('maj', 'major'),
|
||||
('gen', 'general'),
|
||||
('drs', 'doctors'),
|
||||
('rev', 'reverend'),
|
||||
('lt', 'lieutenant'),
|
||||
('hon', 'honorable'),
|
||||
('sgt', 'sergeant'),
|
||||
('capt', 'captain'),
|
||||
('esq', 'esquire'),
|
||||
('ltd', 'limited'),
|
||||
('col', 'colonel'),
|
||||
('ft', 'fort'),
|
||||
]]
|
||||
|
||||
|
||||
def expand_abbreviations(text):
|
||||
for regex, replacement in _abbreviations:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
def expand_numbers(text):
|
||||
return normalize_numbers(text)
|
||||
|
||||
|
||||
def lowercase(text):
|
||||
return text.lower()
|
||||
|
||||
|
||||
def collapse_whitespace(text):
|
||||
return re.sub(_whitespace_re, ' ', text)
|
||||
|
||||
|
||||
def convert_to_ascii(text):
|
||||
return unidecode(text)
|
||||
|
||||
|
||||
def basic_cleaners(text):
|
||||
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def transliteration_cleaners(text):
|
||||
'''Pipeline for non-English text that transliterates to ASCII.'''
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def english_cleaners(text):
|
||||
'''Pipeline for English text, including number and abbreviation expansion.'''
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = expand_numbers(text)
|
||||
text = expand_abbreviations(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
65
codes/models/tacotron2/text/cmudict.py
Normal file
65
codes/models/tacotron2/text/cmudict.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
""" from https://github.com/keithito/tacotron """
|
||||
|
||||
import re
|
||||
|
||||
|
||||
valid_symbols = [
|
||||
'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
|
||||
'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
|
||||
'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
|
||||
'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
|
||||
'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
|
||||
'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
|
||||
'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
|
||||
]
|
||||
|
||||
_valid_symbol_set = set(valid_symbols)
|
||||
|
||||
|
||||
class CMUDict:
|
||||
'''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
|
||||
def __init__(self, file_or_path, keep_ambiguous=True):
|
||||
if isinstance(file_or_path, str):
|
||||
with open(file_or_path, encoding='latin-1') as f:
|
||||
entries = _parse_cmudict(f)
|
||||
else:
|
||||
entries = _parse_cmudict(file_or_path)
|
||||
if not keep_ambiguous:
|
||||
entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
|
||||
self._entries = entries
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self._entries)
|
||||
|
||||
|
||||
def lookup(self, word):
|
||||
'''Returns list of ARPAbet pronunciations of the given word.'''
|
||||
return self._entries.get(word.upper())
|
||||
|
||||
|
||||
|
||||
_alt_re = re.compile(r'\([0-9]+\)')
|
||||
|
||||
|
||||
def _parse_cmudict(file):
|
||||
cmudict = {}
|
||||
for line in file:
|
||||
if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
|
||||
parts = line.split(' ')
|
||||
word = re.sub(_alt_re, '', parts[0])
|
||||
pronunciation = _get_pronunciation(parts[1])
|
||||
if pronunciation:
|
||||
if word in cmudict:
|
||||
cmudict[word].append(pronunciation)
|
||||
else:
|
||||
cmudict[word] = [pronunciation]
|
||||
return cmudict
|
||||
|
||||
|
||||
def _get_pronunciation(s):
|
||||
parts = s.strip().split(' ')
|
||||
for part in parts:
|
||||
if part not in _valid_symbol_set:
|
||||
return None
|
||||
return ' '.join(parts)
|
71
codes/models/tacotron2/text/numbers.py
Normal file
71
codes/models/tacotron2/text/numbers.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
""" from https://github.com/keithito/tacotron """
|
||||
|
||||
import inflect
|
||||
import re
|
||||
|
||||
|
||||
_inflect = inflect.engine()
|
||||
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
||||
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
||||
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
|
||||
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
|
||||
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
||||
_number_re = re.compile(r'[0-9]+')
|
||||
|
||||
|
||||
def _remove_commas(m):
|
||||
return m.group(1).replace(',', '')
|
||||
|
||||
|
||||
def _expand_decimal_point(m):
|
||||
return m.group(1).replace('.', ' point ')
|
||||
|
||||
|
||||
def _expand_dollars(m):
|
||||
match = m.group(1)
|
||||
parts = match.split('.')
|
||||
if len(parts) > 2:
|
||||
return match + ' dollars' # Unexpected format
|
||||
dollars = int(parts[0]) if parts[0] else 0
|
||||
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||
if dollars and cents:
|
||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
||||
elif dollars:
|
||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||
return '%s %s' % (dollars, dollar_unit)
|
||||
elif cents:
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return '%s %s' % (cents, cent_unit)
|
||||
else:
|
||||
return 'zero dollars'
|
||||
|
||||
|
||||
def _expand_ordinal(m):
|
||||
return _inflect.number_to_words(m.group(0))
|
||||
|
||||
|
||||
def _expand_number(m):
|
||||
num = int(m.group(0))
|
||||
if num > 1000 and num < 3000:
|
||||
if num == 2000:
|
||||
return 'two thousand'
|
||||
elif num > 2000 and num < 2010:
|
||||
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
||||
elif num % 100 == 0:
|
||||
return _inflect.number_to_words(num // 100) + ' hundred'
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword='')
|
||||
|
||||
|
||||
def normalize_numbers(text):
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
text = re.sub(_pounds_re, r'\1 pounds', text)
|
||||
text = re.sub(_dollars_re, _expand_dollars, text)
|
||||
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
||||
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
||||
text = re.sub(_number_re, _expand_number, text)
|
||||
return text
|
18
codes/models/tacotron2/text/symbols.py
Normal file
18
codes/models/tacotron2/text/symbols.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
""" from https://github.com/keithito/tacotron """
|
||||
|
||||
'''
|
||||
Defines the set of symbols used in text input to the model.
|
||||
|
||||
The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
|
||||
from models.tacotron2.text import cmudict
|
||||
|
||||
_pad = '_'
|
||||
_punctuation = '!\'(),.:;? '
|
||||
_special = '-'
|
||||
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
|
||||
|
||||
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
|
||||
_arpabet = ['@' + s for s in cmudict.valid_symbols]
|
||||
|
||||
# Export all symbols:
|
||||
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
|
|
@ -1,5 +1,5 @@
|
|||
# Fundamentals
|
||||
numpy
|
||||
opencv-python
|
||||
pyyaml
|
||||
tb-nightly
|
||||
future
|
||||
|
@ -11,12 +11,20 @@ munch
|
|||
tqdm
|
||||
scp
|
||||
tensorboard
|
||||
pytorch_fid==0.1.1
|
||||
kornia
|
||||
linear_attention_transformer
|
||||
vector_quantize_pytorch
|
||||
orjson
|
||||
einops
|
||||
gsa-pytorch
|
||||
lambda-networks
|
||||
|
||||
# For image generation stuff
|
||||
opencv-python
|
||||
kornia
|
||||
pytorch_ssim
|
||||
gsa-pytorch
|
||||
vector_quantize_pytorch
|
||||
pytorch_fid==0.1.1
|
||||
|
||||
# For audio generation stuff
|
||||
inflect==0.2.5
|
||||
librosa==0.6.0
|
||||
Unidecode==1.0.22
|
|
@ -107,7 +107,7 @@ class Trainer:
|
|||
dataset_ratio = 1 # enlarge the size of each epoch
|
||||
for phase, dataset_opt in opt['datasets'].items():
|
||||
if phase == 'train':
|
||||
self.train_set = create_dataset(dataset_opt)
|
||||
self.train_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
|
||||
train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
|
||||
total_iters = int(opt['train']['niter'])
|
||||
self.total_epochs = int(math.ceil(total_iters / train_size))
|
||||
|
@ -116,15 +116,15 @@ class Trainer:
|
|||
self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
|
||||
else:
|
||||
self.train_sampler = None
|
||||
self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler)
|
||||
self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler, collate_fn=collate_fn)
|
||||
if self.rank <= 0:
|
||||
self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
||||
len(self.train_set), train_size))
|
||||
self.logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
|
||||
self.total_epochs, total_iters))
|
||||
elif phase == 'val':
|
||||
self.val_set = create_dataset(dataset_opt)
|
||||
self.val_loader = create_dataloader(self.val_set, dataset_opt, opt, None)
|
||||
self.val_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
|
||||
self.val_loader = create_dataloader(self.val_set, dataset_opt, opt, None, collate_fn=collate_fn)
|
||||
if self.rank <= 0:
|
||||
self.logger.info('Number of val images in [{:s}]: {:d}'.format(
|
||||
dataset_opt['name'], len(self.val_set)))
|
||||
|
|
|
@ -56,6 +56,9 @@ def create_loss(opt_loss, env):
|
|||
elif type == 'switch_transformer_balance':
|
||||
from models.switched_conv.mixture_of_experts import SwitchTransformersLoadBalancingLoss
|
||||
return SwitchTransformersLoadBalancingLoss(opt_loss, env)
|
||||
elif type == 'nv_tacotron2_loss':
|
||||
from models.tacotron2.loss import Tacotron2Loss
|
||||
return Tacotron2Loss()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user