James Betker 2021-08-13 15:02:18 -06:00
parent 81e91c99de
commit cdee31c60b
4 changed files with 122 additions and 11 deletions

View File

@@ -0,0 +1,107 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import munchify

from models.gpt_voice.lucidrains_gpt import Transformer
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.text import symbols
from trainer.networks import register_model
from utils.util import opt_get


class ResBlock(nn.Module):
    def __init__(self, chan):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
            nn.BatchNorm1d(chan),
            nn.ReLU(),
            nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
            nn.BatchNorm1d(chan)
        )

    def forward(self, x):
        return F.relu(self.net(x) + x)


class MelEncoder(nn.Module):
    def __init__(self, channels, mel_channels=80):
        super().__init__()
        self.channels = channels
        self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=7, padding=3),
                                     ResBlock(channels//4),
                                     ResBlock(channels//4),
                                     nn.Conv1d(channels//4, channels//2, kernel_size=5, stride=2, padding=2),
                                     nn.BatchNorm1d(channels//2),
                                     nn.ReLU(),
                                     ResBlock(channels//2),
                                     ResBlock(channels//2),
                                     ResBlock(channels//2),
                                     nn.Conv1d(channels//2, channels, kernel_size=5, stride=2, padding=2),
                                     ResBlock(channels),
                                     ResBlock(channels),
                                     ResBlock(channels)
                                     )

    def forward(self, x):
        return self.encoder(x)


class GptAsr(nn.Module):
    MAX_SYMBOLS_PER_PHRASE = 200
    MAX_MEL_FRAMES = 1000 // 4
    NUMBER_SYMBOLS = len(symbols)
    NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS

    def __init__(self, layers=8, model_dim=512, heads=8):
        super().__init__()
        self.model_dim = model_dim
        self.max_mel_frames = self.MAX_MEL_FRAMES
        self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
        self.mel_encoder = MelEncoder(model_dim)
        self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
        self.mel_pos_embedding = nn.Embedding(self.MAX_MEL_FRAMES, model_dim)
        self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+self.MAX_MEL_FRAMES, heads=heads,
                               attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.MAX_MEL_FRAMES)
        self.final_norm = nn.LayerNorm(model_dim)
        self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)

    def forward(self, mel_inputs, text_targets):
        text_targets = F.pad(text_targets, (0, self.MAX_SYMBOLS_PER_PHRASE-text_targets.shape[1]))
        text_emb = self.text_embedding(text_targets)
        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))

        mel_emb = self.mel_encoder(mel_inputs)
        mel_emb = F.pad(mel_emb, (0, self.MAX_MEL_FRAMES-mel_emb.shape[-1]))
        mel_emb = mel_emb.permute(0,2,1).contiguous()
        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))

        emb = torch.cat([mel_emb, text_emb], dim=1)
        enc = self.gpt(emb)

        # Compute loss
        text_logits = self.final_norm(enc[:, self.MAX_MEL_FRAMES:])
        text_logits = self.text_head(text_logits)
        text_logits = text_logits.permute(0,2,1)
        loss_text = F.cross_entropy(text_logits[:,:,1:], text_targets[:,:-1].long())

        return loss_text.mean()


@register_model
def register_gpt_asr(opt_net, opt):
    return GptAsr(**opt_get(opt_net, ['kwargs'], {}))


if __name__ == '__main__':
    gpt = GptAsr()
    l = gpt(torch.randn(2,80,800),
            torch.randint(high=len(symbols), size=(2,180)))
    print(l.shape)

    #o = gpt.infer(torch.randint(high=24, size=(2,60)))
    #print(o.shape)
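
A quick standalone sanity check of the shape bookkeeping above (assumes only torch and the classes in this file): MelEncoder's two stride-2 convolutions shrink the time axis by 4x, which is where MAX_MEL_FRAMES = 1000 // 4 comes from, and the Transformer's seq_len is the mel budget plus the text budget plus one.

enc = MelEncoder(512)
mels = torch.randn(2, 80, 800)       # (batch, mel_channels, frames)
print(enc(mels).shape)               # torch.Size([2, 512, 200]); GptAsr.forward pads this out to 250
print(1 + GptAsr.MAX_SYMBOLS_PER_PHRASE + GptAsr.MAX_MEL_FRAMES)   # 451, the seq_len handed to Transformer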

View File

@@ -108,7 +108,7 @@ def stable_softmax(t, dim = -1, alpha = 32 ** 2):
 # classes
 
 class Attention(nn.Module):
-    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
+    def __init__(self, dim, seq_len, non_causal_sequence_partition = 0, heads = 8, dim_head = 64, dropout = 0., stable = False):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
@@ -116,7 +116,7 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
         self.stable = stable
 
-        self.causal = causal
+        self.non_causal_sequence_partition = non_causal_sequence_partition
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
@@ -141,10 +141,14 @@ class Attention(nn.Module):
             dots.masked_fill_(~mask, mask_value)
             del mask
 
-        if self.causal:
-            i, j = dots.shape[-2:]
-            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
-            dots.masked_fill_(mask, mask_value)
+        i, j = dots.shape[-2:]
+        mask = torch.ones(i, j, device = device).triu_(j - i + 1)
+        if self.non_causal_sequence_partition > 0:
+            non_causal_mask = torch.ones((i, j), device=device)
+            non_causal_mask[:, :self.non_causal_sequence_partition] = 0
+            mask = mask * non_causal_mask
+
+        dots.masked_fill_(mask.bool(), mask_value)
 
         attn = softmax(dots, dim=-1)
@@ -162,21 +166,21 @@ class Transformer(nn.Module):
         depth,
         seq_len,
         reversible = False,
-        causal = True,
         heads = 8,
         dim_head = 64,
         ff_mult = 4,
         attn_dropout = 0.,
         ff_dropout = 0.,
         sparse_attn = False,
-        stable = False
+        stable = False,
+        non_causal_sequence_partition=0,
     ):
         super().__init__()
         layers = nn.ModuleList([])
         sparse_layer = cast_tuple(sparse_attn, depth)
 
         for ind, sparse_attn in zip(range(depth), sparse_layer):
-            attn = Attention(dim, stable=stable, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
+            attn = Attention(dim, stable=stable, non_causal_sequence_partition = non_causal_sequence_partition, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
 
             layers.append(nn.ModuleList([
                 LayerScale(dim, ind + 1, PreNorm(dim, attn)),
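
For intuition, a tiny standalone sketch (not part of the commit) of the mask this change builds, using the same 1 = blocked convention as the masked_fill_ call above: keys inside the first non_causal_sequence_partition positions (the MEL frames in GptAsr) stay visible to every query, while keys past the partition remain causally masked.

import torch

def partitioned_causal_mask(seq_len, non_causal_len):
    blocked = torch.ones(seq_len, seq_len).triu_(1)   # ordinary causal mask: block future keys
    if non_causal_len > 0:
        blocked[:, :non_causal_len] = 0               # keys inside the partition are always attendable
    return blocked.bool()

print(partitioned_causal_mask(6, 3).int())
# columns 0-2 are never masked; columns 3-5 keep the usual upper-triangular (causal) pattern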

View File

@@ -282,7 +282,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_mozcv.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mozcv.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()

View File

@@ -60,7 +60,7 @@ class GeneratorInjector(Injector):
             results = method(*params)
         new_state = {}
         if isinstance(self.output, list):
-            # Only dereference tuples or lists, not tensors.
+            # Only dereference tuples or lists, not tensors. IF YOU REACH THIS ERROR, REMOVE THE BRACES AROUND YOUR OUTPUTS IN THE YAML CONFIG
            assert isinstance(results, list) or isinstance(results, tuple)
             for i, k in enumerate(self.output):
                 new_state[k] = results[i]
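
The capitalized hint refers to the list versus single-name form of the injector's output field in the training YAML. A condensed illustration (not repo code; the local names are only placeholders): when the outputs are declared as a list but the wrapped network returns a single tensor, this is the assert that fires, so declaring a single output name avoids the unpacking path.

import torch

results = torch.zeros(4)     # the wrapped network returned one bare tensor
output = ['loss']            # list-form outputs, i.e. the "braces" in the YAML
new_state = {}
if isinstance(output, list):
    assert isinstance(results, (list, tuple))   # fails for a bare tensor, producing the error above
    for i, k in enumerate(output):
        new_state[k] = results[i]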