It works!

This commit is contained in:
James Betker 2021-08-04 20:07:45 -06:00
parent 36c7c1fbdb
commit 341f28dd82
5 changed files with 395 additions and 55 deletions

View File

@@ -1,24 +1,22 @@
import os
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
from torch import LongTensor
from tqdm import tqdm
import models.tacotron2.layers as layers
from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text
from models.tacotron2.text import text_to_sequence
from utils.util import opt_get
from models.tacotron2.taco_utils import load_filepaths_and_text
from models.tacotron2.text import symbols
import torch.nn.functional as F
from models.tacotron2.text import text_to_sequence
class GptTtsDataset(torch.utils.data.Dataset):
NUMBER_SYMBOLS = len(symbols)+3
TEXT_START_TOKEN = LongTensor([NUMBER_SYMBOLS-3])
TEXT_STOP_TOKEN = LongTensor([NUMBER_SYMBOLS-2])
MAX_SYMBOLS_PER_PHRASE = 200
NUMBER_SYMBOLS = len(symbols)
NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
TEXT_START_TOKEN = LongTensor([NUMBER_TEXT_TOKENS-1])
TEXT_STOP_TOKEN = LongTensor([NUMBER_TEXT_TOKENS-2])
def __init__(self, opt):
self.path = os.path.dirname(opt['path'])
@@ -49,11 +47,11 @@ class GptTtsDataset(torch.utils.data.Dataset):
class GptTtsCollater():
NUMBER_SYMBOLS = len(symbols)+3
TEXT_PAD_TOKEN = NUMBER_SYMBOLS-1
MAX_SYMBOLS_PER_PHRASE = 200
NUMBER_SYMBOLS = len(symbols)
NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
def __init__(self, opt):
self.MEL_DICTIONARY_SIZE = opt['mel_vocab_size']+3
self.MEL_PAD_TOKEN = self.MEL_DICTIONARY_SIZE-1
@@ -64,9 +62,13 @@ class GptTtsCollater():
max_mel_len = max(mel_lens)
texts = []
qmels = []
# These are the sequential "background" tokens used to pad the text tokens, as described in the DALL-E paper.
text_range_embedding = torch.arange(max_text_len) + self.NUMBER_SYMBOLS
for b in batch:
text, qmel, _ = b
texts.append(F.pad(text, (0, max_text_len-len(text)), value=self.TEXT_PAD_TOKEN))
text = F.pad(text, (0, max_text_len-len(text)), value=0)
text = torch.where(text == 0, text_range_embedding, text)
texts.append(text)
qmels.append(F.pad(qmel, (0, max_mel_len-len(qmel)), value=self.MEL_PAD_TOKEN))
filenames = [j[2] for j in batch]
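
A minimal standalone sketch of the DALL-E-style padding used in the collater above: each pad position gets its own "background" token id (offset by the symbol count) rather than one shared PAD token. The values here are hypothetical; the real collater takes NUMBER_SYMBOLS from the tacotron2 symbol set. Note that a genuine symbol id of 0 would also be overwritten, since the replacement is keyed off the pad value.

import torch
import torch.nn.functional as F

NUMBER_SYMBOLS = 148                       # assumed symbol vocabulary size
max_text_len = 8

text = torch.LongTensor([12, 5, 99])       # one tokenized phrase
padded = F.pad(text, (0, max_text_len - len(text)), value=0)

# Replace every pad slot with a position-specific token: NUMBER_SYMBOLS + position.
text_range_embedding = torch.arange(max_text_len) + NUMBER_SYMBOLS
padded = torch.where(padded == 0, text_range_embedding, padded)
# padded -> tensor([ 12,   5,  99, 151, 152, 153, 154, 155])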

View File

@@ -1,48 +1,37 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import munchify
from torch import LongTensor
from tqdm import tqdm
from models.arch_util import ConvGnSilu
from models.gpt_voice.pixelshuffle_1d import PixelUnshuffle1D, PixelShuffle1D
from models.tacotron2 import hparams
from models.gpt_voice.lucidrains_gpt import Transformer
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.tacotron2 import Postnet
from models.tacotron2.text import symbols
from models.gpt_voice.min_gpt import GPT, GPTConfig
from trainer.networks import register_model
class GptTts(nn.Module):
NUMBER_SYMBOLS = len(symbols)+3
TEXT_START_TOKEN = NUMBER_SYMBOLS-3
TEXT_STOP_TOKEN = NUMBER_SYMBOLS-2
TEXT_PAD_TOKEN = NUMBER_SYMBOLS-1
MAX_SYMBOLS_PER_PHRASE = 200
NUMBER_SYMBOLS = len(symbols)
NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
MEL_DICTIONARY_SIZE = 512+3
MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2
MEL_PAD_TOKEN = MEL_DICTIONARY_SIZE-1
def __init__(self):
super().__init__()
model_dim = 512
max_symbols_per_phrase = 200
max_mel_frames = 900 * 3 // 8 # The VQVAE outputs 3/8 of the input mel as tokens.
mel_dim=80
max_mel_frames = 900 * 3 // 8 # 900 is the max number of MEL frames. The VQVAE outputs 3/8 of the input mel as tokens.
self.model_dim = model_dim
self.max_mel_frames = max_mel_frames
self.text_embedding = nn.Embedding(self.NUMBER_SYMBOLS, model_dim)
self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim)
# The positional embeddings below are additively applied to the token embeddings.
self.text_pos_embedding = nn.Embedding(max_symbols_per_phrase, model_dim)
self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False)
#self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False)
self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.NUMBER_SYMBOLS)
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE)
def forward(self, text_inputs, text_lengths, mel_targets, output_lengths):
@@ -55,8 +44,8 @@ class GptTts(nn.Module):
# Compute logits for text and mel heads
text_logits = self.final_norm(enc[:, :text_emb.shape[1]])
text_logits = self.text_head(text_logits)
mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
text_logits = self.text_head(text_logits)
mel_logits = self.mel_head(mel_logits)
# Compute loss
@@ -67,25 +56,18 @@ class GptTts(nn.Module):
mel_logits = mel_logits.permute(0,2,1)[:,:,:-1]
loss_mel = F.cross_entropy(mel_logits, mel_targets, reduction='none')
# Apply a reduction factor across MEL_PAD and TEXT_PAD tokens.
pad_loss_reduction_factor = .01
text_pad_mask = ~get_mask_from_lengths(text_lengths-1, text_inputs.shape[1]-1) # -1 to strip off <BOS>, which is accounted for in text_lengths and output_lengths.
mel_pad_mask = ~get_mask_from_lengths(output_lengths-1, mel_targets.shape[1])
loss_text = loss_text * torch.ones_like(loss_text).masked_fill_(text_pad_mask, pad_loss_reduction_factor)
loss_mel = loss_mel * torch.ones_like(loss_mel).masked_fill_(mel_pad_mask, pad_loss_reduction_factor)
# Fix up mel_logits so it can go into a VAE decoder as well.
mel_codes = torch.argmax(F.softmax(mel_logits, dim=1), dim=1)
mel_pad_mask = ~get_mask_from_lengths(output_lengths-1, mel_targets.shape[1])
mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask, 0)
mel_codes = mel_codes[:,:
-1] # Strip off <EOS> token too (or padding). The important part is that the output sequence length is identical to the VAE input.
mel_codes = mel_codes[:,:-1] # Strip off <EOS> token too (or padding). The important part is that the output sequence length is identical to the VAE input.
extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE-3 # The VAE doesn't know about START/STOP/PAD
mel_codes = mel_codes * extra_mask
return loss_text.mean(), loss_mel.mean(), mel_codes
# This class also returns the mel_targets for validation purposes. Format those.
mel_targets = mel_targets[:,:-1]
mel_targets = mel_targets * (mel_targets < self.MEL_DICTIONARY_SIZE-3)
return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets
def inference(self, text_inputs):
text_emb = self.text_embedding(text_inputs)
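
A rough sketch of the pad-loss down-weighting in the forward() above, using a hypothetical downweight_pad_loss helper with a simple length-based mask standing in for the repo's get_mask_from_lengths; the per-token cross-entropy at padded positions is scaled by 0.01 rather than dropped entirely.

import torch

def downweight_pad_loss(loss, lengths, pad_factor=0.01):
    # loss: (batch, seq) per-token cross entropy; lengths: (batch,) valid token counts
    positions = torch.arange(loss.shape[1], device=loss.device)
    valid = positions[None, :] < lengths[:, None]
    weights = torch.where(valid, torch.ones_like(loss),
                          torch.full_like(loss, pad_factor))
    return loss * weights

loss = torch.rand(2, 5)
weighted = downweight_pad_loss(loss, torch.tensor([3, 5]))   # pad slots scaled by 0.01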

View File

@@ -0,0 +1,193 @@
from inspect import isfunction
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
# helpers
from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence
from utils.util import sequential_checkpoint
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, depth = 1):
if isinstance(val, list):
val = tuple(val)
return val if isinstance(val, tuple) else (val,) * depth
class DivideMax(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
maxes = x.amax(dim = self.dim, keepdim = True)
return x / maxes
# LayerScale, https://arxiv.org/abs/2103.17239: scale each residual branch by a small learned per-channel factor (smaller init for deeper networks).
class LayerScale(nn.Module):
def __init__(self, dim, depth, fn):
super().__init__()
if depth <= 18:
init_eps = 0.1
elif depth > 18 and depth <= 24:
init_eps = 1e-5
else:
init_eps = 1e-6
scale = torch.zeros(1, 1, dim).fill_(init_eps)
self.scale = nn.Parameter(scale)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim = -1)
return x * F.gelu(gates)
class FeedForward(nn.Module):
def __init__(self, dim, dropout = 0., mult = 4.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim)
)
def forward(self, x):
return self.net(x)
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
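# Softmax with the max subtracted at a reduced scale (t / alpha), which avoids overflow
# for very large logit magnitudes (e.g. under fp16). Attention uses this when stable=True.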
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
t = t / alpha
t = t - torch.amax(t, dim = dim, keepdim = True)
return (t * alpha).softmax(dim = dim)
# classes
class Attention(nn.Module):
def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.seq_len = seq_len
self.scale = dim_head ** -0.5
self.stable = stable
self.causal = causal
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask = None):
b, n, _, h, device = *x.shape, self.heads, x.device
softmax = torch.softmax if not self.stable else stable_softmax
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
q = q * self.scale
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
mask_value = max_neg_value(dots)
if exists(mask):
mask = rearrange(mask, 'b j -> b () () j')
dots.masked_fill_(~mask, mask_value)
del mask
if self.causal:
i, j = dots.shape[-2:]
mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
dots.masked_fill_(mask, mask_value)
attn = softmax(dots, dim=-1)
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out
class Transformer(nn.Module):
def __init__(
self,
*,
dim,
depth,
seq_len,
reversible = False,
causal = True,
heads = 8,
dim_head = 64,
ff_mult = 4,
attn_dropout = 0.,
ff_dropout = 0.,
sparse_attn = False,
stable = False
):
super().__init__()
layers = nn.ModuleList([])
sparse_layer = cast_tuple(sparse_attn, depth)
for ind, sparse_attn in zip(range(depth), sparse_layer):
attn = Attention(dim, stable=stable, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
layers.append(nn.ModuleList([
LayerScale(dim, ind + 1, PreNorm(dim, attn)),
LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout)))
]))
execute_type = ReversibleSequence if reversible else SequentialSequence
route_attn = ((True, False),) * depth
attn_route_map = {'mask': route_attn}
self.layers = execute_type(layers, args_route = attn_route_map)
def forward(self, x):
return self.layers(x)
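
A usage sketch matching how GptTts configures this Transformer above (dim=512, depth=8, heads=8, seq_len = 1 + MAX_SYMBOLS_PER_PHRASE + max_mel_frames = 538); the module maps a (batch, seq, dim) embedding sequence to an output of the same shape using causal self-attention.

import torch
from models.gpt_voice.lucidrains_gpt import Transformer

model = Transformer(dim=512, depth=8, seq_len=1 + 200 + 337, heads=8)
x = torch.randn(2, 538, 512)     # (batch, sequence, model_dim) token embeddings
y = model(x)                     # same shape; each position attends only to earlier ones
assert y.shape == x.shape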

View File

@@ -0,0 +1,156 @@
import torch
import torch.nn as nn
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states
# for routing arguments into the functions of the reversible layer
def route_args(router, args, depth):
routed_args = [(dict(), dict()) for _ in range(depth)]
matched_keys = [key for key in args.keys() if key in router]
for key in matched_keys:
val = args[key]
for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
return routed_args
# following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
def __init__(self, net):
super().__init__()
self.net = net
self.cpu_state = None
self.cuda_in_fwd = None
self.gpu_devices = None
self.gpu_states = None
def record_rng(self, *args):
self.cpu_state = torch.get_rng_state()
if torch.cuda._initialized:
self.cuda_in_fwd = True
self.gpu_devices, self.gpu_states = get_device_states(*args)
def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
if record_rng:
self.record_rng(*args)
if not set_rng:
return self.net(*args, **kwargs)
rng_devices = []
if self.cuda_in_fwd:
rng_devices = self.gpu_devices
with torch.random.fork_rng(devices=rng_devices, enabled=True):
torch.set_rng_state(self.cpu_state)
if self.cuda_in_fwd:
set_device_states(self.gpu_devices, self.gpu_states)
return self.net(*args, **kwargs)
# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is confirmed working, refactor and send PR back to source
class ReversibleBlock(nn.Module):
def __init__(self, f, g):
super().__init__()
self.f = Deterministic(f)
self.g = Deterministic(g)
def forward(self, x, f_args = {}, g_args = {}):
x1, x2 = torch.chunk(x, 2, dim=2)
y1, y2 = None, None
with torch.no_grad():
y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
return torch.cat([y1, y2], dim=2)
def backward_pass(self, y, dy, f_args = {}, g_args = {}):
y1, y2 = torch.chunk(y, 2, dim=2)
del y
dy1, dy2 = torch.chunk(dy, 2, dim=2)
del dy
with torch.enable_grad():
y1.requires_grad = True
gy1 = self.g(y1, set_rng=True, **g_args)
torch.autograd.backward(gy1, dy2)
with torch.no_grad():
x2 = y2 - gy1
del y2, gy1
dx1 = dy1 + y1.grad
del dy1
y1.grad = None
with torch.enable_grad():
x2.requires_grad = True
fx2 = self.f(x2, set_rng=True, **f_args)
torch.autograd.backward(fx2, dx1, retain_graph=True)
with torch.no_grad():
x1 = y1 - fx2
del y1, fx2
dx2 = dy2 + x2.grad
del dy2
x2.grad = None
x = torch.cat([x1, x2.detach()], dim=2)
dx = torch.cat([dx1, dx2], dim=2)
return x, dx
class _ReversibleFunction(Function):
@staticmethod
def forward(ctx, x, blocks, args):
ctx.args = args
for block, kwarg in zip(blocks, args):
x = block(x, **kwarg)
ctx.y = x.detach()
ctx.blocks = blocks
return x
@staticmethod
def backward(ctx, dy):
y = ctx.y
args = ctx.args
for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
y, dy = block.backward_pass(y, dy, **kwargs)
return dy, None, None
class SequentialSequence(nn.Module):
def __init__(self, layers, args_route = {}, layer_dropout = 0.):
super().__init__()
assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
self.layers = layers
self.args_route = args_route
self.layer_dropout = layer_dropout
def forward(self, x, **kwargs):
args = route_args(self.args_route, kwargs, len(self.layers))
layers_and_args = list(zip(self.layers, args))
for (f, g), (f_args, g_args) in layers_and_args:
x = x + f(x, **f_args)
x = x + g(x, **g_args)
return x
class ReversibleSequence(nn.Module):
def __init__(self, blocks, args_route = {}):
super().__init__()
self.args_route = args_route
self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])
def forward(self, x, **kwargs):
x = torch.cat([x, x], dim=-1)
blocks = self.blocks
args = route_args(self.args_route, kwargs, len(blocks))
args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))
out = _ReversibleFunction.apply(x, blocks, args)
return torch.stack(out.chunk(2, dim=-1)).mean(dim=0)
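
A small numerical check (not part of this file, with placeholder linear layers standing in for f and g) of the identity that backward_pass relies on: given y1 = x1 + f(x2) and y2 = x2 + g(y1), the block's inputs can be recovered from its outputs, so intermediate activations do not need to be stored for backprop.

import torch
import torch.nn as nn

f, g = nn.Linear(4, 4), nn.Linear(4, 4)
x1, x2 = torch.randn(1, 4), torch.randn(1, 4)

with torch.no_grad():
    y1 = x1 + f(x2)            # forward half 1
    y2 = x2 + g(y1)            # forward half 2
    x2_rec = y2 - g(y1)        # invert half 2
    x1_rec = y1 - f(x2_rec)    # invert half 1

assert torch.allclose(x1, x1_rec, atol=1e-6)
assert torch.allclose(x2, x2_rec, atol=1e-6)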

View File

@@ -18,7 +18,8 @@ def get_scheduler_for_name(name, optimizers, scheduler_opt):
weights=scheduler_opt['restart_weights'],
gamma=scheduler_opt['lr_gamma'],
clear_state=scheduler_opt['clear_state'],
force_lr=scheduler_opt['force_lr'])
force_lr=scheduler_opt['force_lr'],
warmup_steps=scheduler_opt['warmup_steps'])
elif name == 'ProgressiveMultiStepLR':
sched = ProgressiveMultiStepLR(o, scheduler_opt['gen_lr_steps'],
scheduler_opt['progressive_starts'],
@@ -55,7 +56,7 @@ class ProgressiveMultiStepLR(_LRScheduler):
class MultiStepLR_Restart(_LRScheduler):
def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
clear_state=False, force_lr=False, last_epoch=-1):
clear_state=False, force_lr=False, last_epoch=-1, warmup_steps=0):
self.milestones = Counter(milestones)
self.gamma = gamma
self.clear_state = clear_state
@@ -63,11 +64,13 @@ class MultiStepLR_Restart(_LRScheduler):
self.restarts = [v + 1 for v in self.restarts]
self.restart_weights = weights if weights else [1]
self.force_lr = force_lr
self.warmup_steps = warmup_steps
assert len(self.restarts) == len(
self.restart_weights), 'restarts and their weights do not match.'
super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
def get_lr(self):
# Note to self: for the purposes of this trainer, "last_epoch" should read "last_step"
if self.force_lr:
return [group['initial_lr'] for group in self.optimizer.param_groups]
if self.last_epoch in self.restarts:
@@ -75,6 +78,9 @@ class MultiStepLR_Restart(_LRScheduler):
self.optimizer.state = defaultdict(dict)
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
if self.last_epoch < self.warmup_steps:
factor = 1 - (self.warmup_steps - self.last_epoch) / self.warmup_steps
return [group['initial_lr'] * factor for group in self.optimizer.param_groups]
if self.last_epoch not in self.milestones:
return [group['lr'] for group in self.optimizer.param_groups]
return [
@@ -148,8 +154,8 @@ if __name__ == "__main__":
restart_weights = [1, 1, 1]
scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
clear_state=False)
clear_state=False, warmup_steps=20000)
'''
##############################
# Cosine Annealing Restart
##############################
@@ -165,11 +171,12 @@ if __name__ == "__main__":
scheduler = CosineAnnealingLR_Restart(optimizer, T_period, warmup=10000, eta_min=1e-8, restarts=restarts,
weights=restart_weights)
'''
##############################
# Draw figure
##############################
N_iter = 500000
N_iter = 100000
lr_l = list(range(N_iter))
for i in range(N_iter):
scheduler.step()
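
The warmup branch added to get_lr() above reduces to a linear ramp: factor = 1 - (warmup_steps - step) / warmup_steps = step / warmup_steps, so the learning rate climbs from 0 to the full initial_lr over the first warmup_steps iterations. A quick check with hypothetical values:

warmup_steps = 20000
initial_lr = 1e-4
for step in (0, 5000, 10000, 20000):
    factor = 1 - (warmup_steps - step) / warmup_steps
    print(step, initial_lr * factor)   # ramps 0 -> 2.5e-05 -> 5e-05 -> 1e-04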