diff --git a/codes/data/audio/gpt_tts_dataset.py b/codes/data/audio/gpt_tts_dataset.py index d871995b..7d1aff98 100644 --- a/codes/data/audio/gpt_tts_dataset.py +++ b/codes/data/audio/gpt_tts_dataset.py @@ -1,24 +1,22 @@ import os -import random -import numpy as np + import torch +import torch.nn.functional as F import torch.utils.data from torch import LongTensor from tqdm import tqdm -import models.tacotron2.layers as layers -from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text - -from models.tacotron2.text import text_to_sequence -from utils.util import opt_get +from models.tacotron2.taco_utils import load_filepaths_and_text from models.tacotron2.text import symbols -import torch.nn.functional as F +from models.tacotron2.text import text_to_sequence class GptTtsDataset(torch.utils.data.Dataset): - NUMBER_SYMBOLS = len(symbols)+3 - TEXT_START_TOKEN = LongTensor([NUMBER_SYMBOLS-3]) - TEXT_STOP_TOKEN = LongTensor([NUMBER_SYMBOLS-2]) + MAX_SYMBOLS_PER_PHRASE = 200 + NUMBER_SYMBOLS = len(symbols) + NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2 + TEXT_START_TOKEN = LongTensor([NUMBER_TEXT_TOKENS-1]) + TEXT_STOP_TOKEN = LongTensor([NUMBER_TEXT_TOKENS-2]) def __init__(self, opt): self.path = os.path.dirname(opt['path']) @@ -49,11 +47,11 @@ class GptTtsDataset(torch.utils.data.Dataset): class GptTtsCollater(): - NUMBER_SYMBOLS = len(symbols)+3 - TEXT_PAD_TOKEN = NUMBER_SYMBOLS-1 + MAX_SYMBOLS_PER_PHRASE = 200 + NUMBER_SYMBOLS = len(symbols) + NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2 def __init__(self, opt): - self.MEL_DICTIONARY_SIZE = opt['mel_vocab_size']+3 self.MEL_PAD_TOKEN = self.MEL_DICTIONARY_SIZE-1 @@ -64,9 +62,13 @@ class GptTtsCollater(): max_mel_len = max(mel_lens) texts = [] qmels = [] + # This is the sequential "background" tokens that are used as padding for text tokens, as specified in the DALLE paper. 
+ text_range_embedding = torch.arange(max_text_len) + self.NUMBER_SYMBOLS for b in batch: text, qmel, _ = b - texts.append(F.pad(text, (0, max_text_len-len(text)), value=self.TEXT_PAD_TOKEN)) + text = F.pad(text, (0, max_text_len-len(text)), value=0) + text = torch.where(text == 0, text_range_embedding, text) + texts.append(text) qmels.append(F.pad(qmel, (0, max_mel_len-len(qmel)), value=self.MEL_PAD_TOKEN)) filenames = [j[2] for j in batch] diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py index b7508bc3..5db11c8b 100644 --- a/codes/models/gpt_voice/gpt_tts.py +++ b/codes/models/gpt_voice/gpt_tts.py @@ -1,48 +1,37 @@ import torch import torch.nn as nn import torch.nn.functional as F -from munch import munchify -from torch import LongTensor -from tqdm import tqdm -from models.arch_util import ConvGnSilu -from models.gpt_voice.pixelshuffle_1d import PixelUnshuffle1D, PixelShuffle1D -from models.tacotron2 import hparams +from models.gpt_voice.lucidrains_gpt import Transformer from models.tacotron2.taco_utils import get_mask_from_lengths -from models.tacotron2.tacotron2 import Postnet from models.tacotron2.text import symbols -from models.gpt_voice.min_gpt import GPT, GPTConfig from trainer.networks import register_model class GptTts(nn.Module): - NUMBER_SYMBOLS = len(symbols)+3 - TEXT_START_TOKEN = NUMBER_SYMBOLS-3 - TEXT_STOP_TOKEN = NUMBER_SYMBOLS-2 - TEXT_PAD_TOKEN = NUMBER_SYMBOLS-1 + MAX_SYMBOLS_PER_PHRASE = 200 + NUMBER_SYMBOLS = len(symbols) + NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2 MEL_DICTIONARY_SIZE = 512+3 MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3 MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2 - MEL_PAD_TOKEN = MEL_DICTIONARY_SIZE-1 def __init__(self): super().__init__() model_dim = 512 - max_symbols_per_phrase = 200 - max_mel_frames = 900 * 3 // 8 # The VQVAE outputs 3/8 of the input mel as tokens. - mel_dim=80 + max_mel_frames = 900 * 3 // 8 # 900 is the max number of MEL frames. The VQVAE outputs 3/8 of the input mel as tokens. 
self.model_dim = model_dim self.max_mel_frames = max_mel_frames - self.text_embedding = nn.Embedding(self.NUMBER_SYMBOLS, model_dim) + self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim) self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim) - # *_tags are additively applied to - self.text_pos_embedding = nn.Embedding(max_symbols_per_phrase, model_dim) + self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim) self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim) - self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False) + #self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False) + self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8) self.final_norm = nn.LayerNorm(model_dim) - self.text_head = nn.Linear(model_dim, self.NUMBER_SYMBOLS) + self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS) self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE) def forward(self, text_inputs, text_lengths, mel_targets, output_lengths): @@ -55,8 +44,8 @@ class GptTts(nn.Module): # Compute logits for text and mel heads text_logits = self.final_norm(enc[:, :text_emb.shape[1]]) - text_logits = self.text_head(text_logits) mel_logits = self.final_norm(enc[:, text_emb.shape[1]:]) + text_logits = self.text_head(text_logits) mel_logits = self.mel_head(mel_logits) # Compute loss @@ -67,25 +56,18 @@ class GptTts(nn.Module): mel_logits = mel_logits.permute(0,2,1)[:,:,:-1] loss_mel = F.cross_entropy(mel_logits, mel_targets, reduction='none') - # Apply a reduction factor across MEL_PAD and TEXT_PAD tokens. - pad_loss_reduction_factor = .01 - text_pad_mask = ~get_mask_from_lengths(text_lengths-1, text_inputs.shape[1]-1) # -1 to strip off the stop token, which is accounted for in text_lengths and output_lengths. - mel_pad_mask = ~get_mask_from_lengths(output_lengths-1, mel_targets.shape[1]) - loss_text = loss_text * torch.ones_like(loss_text).masked_fill_(text_pad_mask, pad_loss_reduction_factor) - loss_mel = loss_mel * torch.ones_like(loss_mel).masked_fill_(mel_pad_mask, pad_loss_reduction_factor) - # Fix up mel_logits so it can go into a VAE decoder as well. mel_codes = torch.argmax(F.softmax(mel_logits, dim=1), dim=1) + mel_pad_mask = ~get_mask_from_lengths(output_lengths-1, mel_targets.shape[1]) mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask, 0) - mel_codes = mel_codes[:,:-1] # Strip off the stop token too (or padding). The important part is that the output sequence length is identical to the VAE input. + mel_codes = mel_codes[:,:-1] # Strip off the stop token too (or padding). The important part is that the output sequence length is identical to the VAE input. extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE-3 # The VAE doesn't know about START/STOP/PAD mel_codes = mel_codes * extra_mask - return loss_text.mean(), loss_mel.mean(), mel_codes + # This class also returns the mel_targets for validation purposes. Format those.
+ mel_targets = mel_targets[:,:-1] + mel_targets = mel_targets * (mel_targets < self.MEL_DICTIONARY_SIZE-3) + return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets def inference(self, text_inputs): text_emb = self.text_embedding(text_inputs) diff --git a/codes/models/gpt_voice/lucidrains_gpt.py b/codes/models/gpt_voice/lucidrains_gpt.py new file mode 100644 index 00000000..3aea6f6d --- /dev/null +++ b/codes/models/gpt_voice/lucidrains_gpt.py @@ -0,0 +1,193 @@ +from inspect import isfunction + +import torch +from torch import nn, einsum +import torch.nn.functional as F +from einops import rearrange + +# helpers +from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence +from utils.util import sequential_checkpoint + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def cast_tuple(val, depth = 1): + if isinstance(val, list): + val = tuple(val) + return val if isinstance(val, tuple) else (val,) * depth + + +class DivideMax(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + maxes = x.amax(dim = self.dim, keepdim = True) + return x / maxes + + +# https://arxiv.org/abs/2103.17239 +class LayerScale(nn.Module): + def __init__(self, dim, depth, fn): + super().__init__() + if depth <= 18: + init_eps = 0.1 + elif depth > 18 and depth <= 24: + init_eps = 1e-5 + else: + init_eps = 1e-6 + + scale = torch.zeros(1, 1, dim).fill_(init_eps) + self.scale = nn.Parameter(scale) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) * self.scale + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + + +class GEGLU(nn.Module): + def forward(self, x): + x, gates = x.chunk(2, dim = -1) + return x * F.gelu(gates) + + +class FeedForward(nn.Module): + def __init__(self, dim, dropout = 0., mult = 4.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, dim * mult * 2), + GEGLU(), + nn.Dropout(dropout), + nn.Linear(dim * mult, dim) + ) + + def forward(self, x): + return self.net(x) + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def stable_softmax(t, dim = -1, alpha = 32 ** 2): + t = t / alpha + t = t - torch.amax(t, dim = dim, keepdim = True) + return (t * alpha).softmax(dim = dim) + + +# classes +class Attention(nn.Module): + def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.seq_len = seq_len + self.scale = dim_head ** -0.5 + + self.stable = stable + self.causal = causal + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x, mask = None): + b, n, _, h, device = *x.shape, self.heads, x.device + softmax = torch.softmax if not self.stable else stable_softmax + + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + q = q * self.scale + + dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) + mask_value = max_neg_value(dots) + + if exists(mask): + mask = 
rearrange(mask, 'b j -> b () () j') + dots.masked_fill_(~mask, mask_value) + del mask + + if self.causal: + i, j = dots.shape[-2:] + mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() + dots.masked_fill_(mask, mask_value) + + attn = softmax(dots, dim=-1) + + out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return out + + +class Transformer(nn.Module): + def __init__( + self, + *, + dim, + depth, + seq_len, + reversible = False, + causal = True, + heads = 8, + dim_head = 64, + ff_mult = 4, + attn_dropout = 0., + ff_dropout = 0., + sparse_attn = False, + stable = False + ): + super().__init__() + layers = nn.ModuleList([]) + sparse_layer = cast_tuple(sparse_attn, depth) + + for ind, sparse_attn in zip(range(depth), sparse_layer): + attn = Attention(dim, stable=stable, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout) + + layers.append(nn.ModuleList([ + LayerScale(dim, ind + 1, PreNorm(dim, attn)), + LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))) + ])) + + execute_type = ReversibleSequence if reversible else SequentialSequence + route_attn = ((True, False),) * depth + attn_route_map = {'mask': route_attn} + + self.layers = execute_type(layers, args_route = attn_route_map) + + def forward(self, x): + return self.layers(x) diff --git a/codes/models/gpt_voice/reversible.py b/codes/models/gpt_voice/reversible.py new file mode 100644 index 00000000..97a3dd64 --- /dev/null +++ b/codes/models/gpt_voice/reversible.py @@ -0,0 +1,156 @@ +import torch +import torch.nn as nn +from torch.autograd.function import Function +from torch.utils.checkpoint import get_device_states, set_device_states + +# for routing arguments into the functions of the reversible layer +def route_args(router, args, depth): + routed_args = [(dict(), dict()) for _ in range(depth)] + matched_keys = [key for key in args.keys() if key in router] + + for key in matched_keys: + val = args[key] + for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): + new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) + routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) + return routed_args + +# following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html +class Deterministic(nn.Module): + def __init__(self, net): + super().__init__() + self.net = net + self.cpu_state = None + self.cuda_in_fwd = None + self.gpu_devices = None + self.gpu_states = None + + def record_rng(self, *args): + self.cpu_state = torch.get_rng_state() + if torch.cuda._initialized: + self.cuda_in_fwd = True + self.gpu_devices, self.gpu_states = get_device_states(*args) + + def forward(self, *args, record_rng = False, set_rng = False, **kwargs): + if record_rng: + self.record_rng(*args) + + if not set_rng: + return self.net(*args, **kwargs) + + rng_devices = [] + if self.cuda_in_fwd: + rng_devices = self.gpu_devices + + with torch.random.fork_rng(devices=rng_devices, enabled=True): + torch.set_rng_state(self.cpu_state) + if self.cuda_in_fwd: + set_device_states(self.gpu_devices, self.gpu_states) + return self.net(*args, **kwargs) + +# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py +# once multi-GPU is confirmed working, refactor and send PR back to source +class ReversibleBlock(nn.Module): + def __init__(self, 
f, g): + super().__init__() + self.f = Deterministic(f) + self.g = Deterministic(g) + + def forward(self, x, f_args = {}, g_args = {}): + x1, x2 = torch.chunk(x, 2, dim=2) + y1, y2 = None, None + + with torch.no_grad(): + y1 = x1 + self.f(x2, record_rng=self.training, **f_args) + y2 = x2 + self.g(y1, record_rng=self.training, **g_args) + + return torch.cat([y1, y2], dim=2) + + def backward_pass(self, y, dy, f_args = {}, g_args = {}): + y1, y2 = torch.chunk(y, 2, dim=2) + del y + + dy1, dy2 = torch.chunk(dy, 2, dim=2) + del dy + + with torch.enable_grad(): + y1.requires_grad = True + gy1 = self.g(y1, set_rng=True, **g_args) + torch.autograd.backward(gy1, dy2) + + with torch.no_grad(): + x2 = y2 - gy1 + del y2, gy1 + + dx1 = dy1 + y1.grad + del dy1 + y1.grad = None + + with torch.enable_grad(): + x2.requires_grad = True + fx2 = self.f(x2, set_rng=True, **f_args) + torch.autograd.backward(fx2, dx1, retain_graph=True) + + with torch.no_grad(): + x1 = y1 - fx2 + del y1, fx2 + + dx2 = dy2 + x2.grad + del dy2 + x2.grad = None + + x = torch.cat([x1, x2.detach()], dim=2) + dx = torch.cat([dx1, dx2], dim=2) + + return x, dx + +class _ReversibleFunction(Function): + @staticmethod + def forward(ctx, x, blocks, args): + ctx.args = args + for block, kwarg in zip(blocks, args): + x = block(x, **kwarg) + ctx.y = x.detach() + ctx.blocks = blocks + return x + + @staticmethod + def backward(ctx, dy): + y = ctx.y + args = ctx.args + for block, kwargs in zip(ctx.blocks[::-1], args[::-1]): + y, dy = block.backward_pass(y, dy, **kwargs) + return dy, None, None + +class SequentialSequence(nn.Module): + def __init__(self, layers, args_route = {}, layer_dropout = 0.): + super().__init__() + assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' + self.layers = layers + self.args_route = args_route + self.layer_dropout = layer_dropout + + def forward(self, x, **kwargs): + args = route_args(self.args_route, kwargs, len(self.layers)) + layers_and_args = list(zip(self.layers, args)) + + for (f, g), (f_args, g_args) in layers_and_args: + x = x + f(x, **f_args) + x = x + g(x, **g_args) + return x + +class ReversibleSequence(nn.Module): + def __init__(self, blocks, args_route = {}): + super().__init__() + self.args_route = args_route + self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks]) + + def forward(self, x, **kwargs): + x = torch.cat([x, x], dim=-1) + + blocks = self.blocks + args = route_args(self.args_route, kwargs, len(blocks)) + args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args)) + + out = _ReversibleFunction.apply(x, blocks, args) + return torch.stack(out.chunk(2, dim=-1)).mean(dim=0) \ No newline at end of file diff --git a/codes/trainer/lr_scheduler.py b/codes/trainer/lr_scheduler.py index 151ea1c4..c58e294f 100644 --- a/codes/trainer/lr_scheduler.py +++ b/codes/trainer/lr_scheduler.py @@ -18,7 +18,8 @@ def get_scheduler_for_name(name, optimizers, scheduler_opt): weights=scheduler_opt['restart_weights'], gamma=scheduler_opt['lr_gamma'], clear_state=scheduler_opt['clear_state'], - force_lr=scheduler_opt['force_lr']) + force_lr=scheduler_opt['force_lr'], + warmup_steps=scheduler_opt['warmup_steps']) elif name == 'ProgressiveMultiStepLR': sched = ProgressiveMultiStepLR(o, scheduler_opt['gen_lr_steps'], scheduler_opt['progressive_starts'], @@ -55,7 +56,7 @@ class ProgressiveMultiStepLR(_LRScheduler): class MultiStepLR_Restart(_LRScheduler): def __init__(self, optimizer, 
milestones, restarts=None, weights=None, gamma=0.1, - clear_state=False, force_lr=False, last_epoch=-1): + clear_state=False, force_lr=False, last_epoch=-1, warmup_steps=0): self.milestones = Counter(milestones) self.gamma = gamma self.clear_state = clear_state @@ -63,11 +64,13 @@ class MultiStepLR_Restart(_LRScheduler): self.restarts = [v + 1 for v in self.restarts] self.restart_weights = weights if weights else [1] self.force_lr = force_lr + self.warmup_steps = warmup_steps assert len(self.restarts) == len( self.restart_weights), 'restarts and their weights do not match.' super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) def get_lr(self): + # Note to self: for the purposes of this trainer, "last_epoch" should read "last_step" if self.force_lr: return [group['initial_lr'] for group in self.optimizer.param_groups] if self.last_epoch in self.restarts: @@ -75,6 +78,9 @@ class MultiStepLR_Restart(_LRScheduler): self.optimizer.state = defaultdict(dict) weight = self.restart_weights[self.restarts.index(self.last_epoch)] return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch < self.warmup_steps: + factor = 1 - (self.warmup_steps - self.last_epoch) / self.warmup_steps + return [group['initial_lr'] * factor for group in self.optimizer.param_groups] if self.last_epoch not in self.milestones: return [group['lr'] for group in self.optimizer.param_groups] return [ @@ -148,8 +154,8 @@ if __name__ == "__main__": restart_weights = [1, 1, 1] scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, - clear_state=False) - + clear_state=False, warmup_steps=20000) + ''' ############################## # Cosine Annealing Restart ############################## @@ -165,11 +171,12 @@ if __name__ == "__main__": scheduler = CosineAnnealingLR_Restart(optimizer, T_period, warmup=10000, eta_min=1e-8, restarts=restarts, weights=restart_weights) + ''' ############################## # Draw figure ############################## - N_iter = 500000 + N_iter = 100000 lr_l = list(range(N_iter)) for i in range(N_iter): scheduler.step()
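
The new `GptTtsCollater` no longer pads text with a single `TEXT_PAD_TOKEN`; as the comment in `gpt_tts_dataset.py` notes, padded positions instead receive sequential "background" tokens in the style of the DALL-E paper, offset past the real symbol vocabulary. A minimal standalone sketch of that padding scheme (the vocabulary size and token ids below are assumed for illustration, not taken from the code):

```python
import torch
import torch.nn.functional as F

NUMBER_SYMBOLS = 148   # assumed stand-in for len(symbols)
texts = [torch.tensor([5, 9, 3]), torch.tensor([7, 2])]

max_text_len = max(len(t) for t in texts)
# One unique "background" id per position, starting just past the symbol vocabulary.
text_range_embedding = torch.arange(max_text_len) + NUMBER_SYMBOLS

padded = []
for text in texts:
    text = F.pad(text, (0, max_text_len - len(text)), value=0)  # pad with zeros first
    text = torch.where(text == 0, text_range_embedding, text)   # zeros -> positional pad ids
    padded.append(text)

print(torch.stack(padded))
# tensor([[  5,   9,   3],
#         [  7,   2, 150]])
```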
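
The switch from `min_gpt.GPT` to `lucidrains_gpt.Transformer` keeps the same overall layout in `GptTts.forward`: text and mel token ids are embedded into a shared `model_dim`, given per-modality positional embeddings, concatenated along the sequence axis, pushed through one causal transformer, and the output is split back so each half feeds its own head. A toy-scale sketch of that data flow, with shrunken sizes and an `nn.Identity` stub standing in for the transformer (the real module also applies `final_norm` before the heads and shifts targets for the cross-entropy losses, which this omits):

```python
import torch
import torch.nn as nn

model_dim, n_text_tokens, n_mel_tokens = 16, 32, 24
text_len, mel_len = 10, 12

text_embedding = nn.Embedding(n_text_tokens, model_dim)
mel_embedding = nn.Embedding(n_mel_tokens, model_dim)
text_pos_embedding = nn.Embedding(text_len, model_dim)
mel_pos_embedding = nn.Embedding(mel_len, model_dim)
gpt = nn.Identity()                                  # stand-in for lucidrains_gpt.Transformer
text_head = nn.Linear(model_dim, n_text_tokens)
mel_head = nn.Linear(model_dim, n_mel_tokens)

text_inputs = torch.randint(0, n_text_tokens, (2, text_len))
mel_targets = torch.randint(0, n_mel_tokens, (2, mel_len))

text_emb = text_embedding(text_inputs) + text_pos_embedding(torch.arange(text_len))
mel_emb = mel_embedding(mel_targets) + mel_pos_embedding(torch.arange(mel_len))
enc = gpt(torch.cat([text_emb, mel_emb], dim=1))     # (2, text_len + mel_len, model_dim)

text_logits = text_head(enc[:, :text_emb.shape[1]])  # (2, 10, n_text_tokens)
mel_logits = mel_head(enc[:, text_emb.shape[1]:])    # (2, 12, n_mel_tokens)
```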
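
The `warmup_steps` argument added to `MultiStepLR_Restart` linearly ramps the learning rate from 0 up to `initial_lr` over the first `warmup_steps` steps, after which the usual milestone/gamma schedule takes over. A small sketch of just that factor, using an assumed 20k-step warmup and a 1e-4 base rate:

```python
def warmup_factor(step, warmup_steps):
    # Same expression as in get_lr(): 1 - (warmup_steps - step) / warmup_steps == step / warmup_steps
    return 1 - (warmup_steps - step) / warmup_steps

initial_lr, warmup_steps = 1e-4, 20000
for step in (0, 5000, 10000, 19999, 20000):
    lr = initial_lr * warmup_factor(step, warmup_steps) if step < warmup_steps else initial_lr
    print(step, lr)  # ramps roughly 0.0 -> 2.5e-05 -> 5e-05 -> ~1e-04, then holds at initial_lr
```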