This commit is contained in:
James Betker 2021-08-13 15:02:18 -06:00
parent 81e91c99de
commit cdee31c60b
4 changed files with 122 additions and 11 deletions

View File

@ -0,0 +1,107 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import munchify
from models.gpt_voice.lucidrains_gpt import Transformer
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.text import symbols
from trainer.networks import register_model
from utils.util import opt_get
class ResBlock(nn.Module):
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
nn.BatchNorm1d(chan),
nn.ReLU(),
nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
nn.BatchNorm1d(chan)
)
def forward(self, x):
return F.relu(self.net(x) + x)
class MelEncoder(nn.Module):
def __init__(self, channels, mel_channels=80):
super().__init__()
self.channels = channels
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=7, padding=3),
ResBlock(channels//4),
ResBlock(channels//4),
nn.Conv1d(channels//4, channels//2, kernel_size=5, stride=2, padding=2),
nn.BatchNorm1d(channels//2),
nn.ReLU(),
ResBlock(channels//2),
ResBlock(channels//2),
ResBlock(channels//2),
nn.Conv1d(channels//2, channels, kernel_size=5, stride=2, padding=2),
ResBlock(channels),
ResBlock(channels),
ResBlock(channels)
)
def forward(self, x):
return self.encoder(x)
class GptAsr(nn.Module):
MAX_SYMBOLS_PER_PHRASE = 200
MAX_MEL_FRAMES = 1000 // 4
NUMBER_SYMBOLS = len(symbols)
NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS
def __init__(self, layers=8, model_dim=512, heads=8):
super().__init__()
self.model_dim = model_dim
self.max_mel_frames = self.MAX_MEL_FRAMES
self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
self.mel_encoder = MelEncoder(model_dim)
self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
self.mel_pos_embedding = nn.Embedding(self.MAX_MEL_FRAMES, model_dim)
self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+self.MAX_MEL_FRAMES, heads=heads,
attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.MAX_MEL_FRAMES)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
def forward(self, mel_inputs, text_targets):
text_targets = F.pad(text_targets, (0, self.MAX_SYMBOLS_PER_PHRASE-text_targets.shape[1]))
text_emb = self.text_embedding(text_targets)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
mel_emb = self.mel_encoder(mel_inputs)
mel_emb = F.pad(mel_emb, (0, self.MAX_MEL_FRAMES-mel_emb.shape[-1]))
mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
emb = torch.cat([mel_emb, text_emb], dim=1)
enc = self.gpt(emb)
# Compute loss
text_logits = self.final_norm(enc[:, self.MAX_MEL_FRAMES:])
text_logits = self.text_head(text_logits)
text_logits = text_logits.permute(0,2,1)
loss_text = F.cross_entropy(text_logits[:,:,1:], text_targets[:,:-1].long())
return loss_text.mean()
@register_model
def register_gpt_asr(opt_net, opt):
return GptAsr(**opt_get(opt_net, ['kwargs'], {}))
if __name__ == '__main__':
gpt = GptAsr()
l = gpt(torch.randn(2,80,800),
torch.randint(high=len(symbols), size=(2,180)))
print(l.shape)
#o = gpt.infer(torch.randint(high=24, size=(2,60)))
#print(o.shape)

View File

@ -108,7 +108,7 @@ def stable_softmax(t, dim = -1, alpha = 32 ** 2):
# classes
class Attention(nn.Module):
def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
def __init__(self, dim, seq_len, non_causal_sequence_partition = 0, heads = 8, dim_head = 64, dropout = 0., stable = False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
@ -116,7 +116,7 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.stable = stable
self.causal = causal
self.non_causal_sequence_partition = non_causal_sequence_partition
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
@ -141,10 +141,14 @@ class Attention(nn.Module):
dots.masked_fill_(~mask, mask_value)
del mask
if self.causal:
i, j = dots.shape[-2:]
mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
dots.masked_fill_(mask, mask_value)
i, j = dots.shape[-2:]
mask = torch.ones(i, j, device = device).triu_(j - i + 1)
if self.non_causal_sequence_partition > 0:
non_causal_mask = torch.ones((i, j), device=device)
non_causal_mask[:, :self.non_causal_sequence_partition] = 0
mask = mask * non_causal_mask
dots.masked_fill_(mask.bool(), mask_value)
attn = softmax(dots, dim=-1)
@ -162,21 +166,21 @@ class Transformer(nn.Module):
depth,
seq_len,
reversible = False,
causal = True,
heads = 8,
dim_head = 64,
ff_mult = 4,
attn_dropout = 0.,
ff_dropout = 0.,
sparse_attn = False,
stable = False
stable = False,
non_causal_sequence_partition=0,
):
super().__init__()
layers = nn.ModuleList([])
sparse_layer = cast_tuple(sparse_attn, depth)
for ind, sparse_attn in zip(range(depth), sparse_layer):
attn = Attention(dim, stable=stable, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
attn = Attention(dim, stable=stable, non_causal_sequence_partition = non_causal_sequence_partition, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
layers.append(nn.ModuleList([
LayerScale(dim, ind + 1, PreNorm(dim, attn)),

View File

@ -282,7 +282,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_mozcv.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mozcv.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -60,7 +60,7 @@ class GeneratorInjector(Injector):
results = method(*params)
new_state = {}
if isinstance(self.output, list):
# Only dereference tuples or lists, not tensors.
# Only dereference tuples or lists, not tensors. IF YOU REACH THIS ERROR, REMOVE THE BRACES AROUND YOUR OUTPUTS IN THE YAML CONFIG
assert isinstance(results, list) or isinstance(results, tuple)
for i, k in enumerate(self.output):
new_state[k] = results[i]