forked from mrq/DL-Art-School
add an autoregressive ctc code generator
This commit is contained in:
parent 7f4fc55344
commit 8fb147e8ab

@@ -118,7 +118,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
         codes, repeats, pads = [], [], [0]
         for val, group in grouped:
             if val == 0:
-                pads[-1] = len(list(group))
+                pads[-1] = len(list(group))  # This is a very important distinction! It means the padding belongs to the character following it.
             else:
                 codes.append(val)
                 repeats.append(len(list(group)))
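
For context, a minimal standalone sketch (not part of the commit) of the grouping this hunk operates on. It assumes `grouped` comes from itertools.groupby over a raw CTC code sequence and that a fresh pad slot is appended for every non-blank group; neither detail is visible in the hunk, so both are assumptions.

import itertools

raw_ctc = [0, 0, 7, 7, 0, 0, 0, 12, 5, 5, 5]  # 0 is the CTC blank token
grouped = itertools.groupby(raw_ctc)

codes, repeats, pads = [], [], [0]
for val, group in grouped:
    if val == 0:
        pads[-1] = len(list(group))   # the padding belongs to the character following it
    else:
        codes.append(val)
        repeats.append(len(list(group)))
        pads.append(0)                # assumed: open a pad slot for the next blank run

print(codes)    # [7, 12, 5]
print(repeats)  # [2, 1, 3]
print(pads)     # [2, 3, 0, 0] -> pads[i] is the silence preceding codes[i]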

codes/models/gpt_voice/ctc_code_generator2.py (new file, 94 lines)
@@ -0,0 +1,94 @@
import functools

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import T5Config, T5Model
from x_transformers import Encoder, XTransformer

from models.gpt_voice.transformer_builders import null_position_embeddings
from models.gpt_voice.unet_diffusion_tts6 import CheckpointedLayer
from models.gpt_voice.unified_voice2 import ConditioningEncoder
from trainer.networks import register_model
from utils.util import opt_get


class CtcCodeGenerator(nn.Module):
    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=120, max_repeat=30, checkpointing=True):
        super().__init__()
        self.max_pad = max_pad
        self.max_repeat = max_repeat
        self.start_token = (self.max_repeat+1)*(self.max_pad+1)+1
        self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads)
        self.embedding = nn.Embedding(ctc_codes, model_dim)
        self.dec_embedding = nn.Embedding(self.start_token+1, model_dim)
        self.config = T5Config(
            vocab_size=1,  # T5 embedding will be removed and replaced with custom embedding.
            d_model=model_dim,
            d_kv=model_dim//num_heads,
            d_ff=model_dim*4,
            num_layers=layers,
            num_heads=num_heads,
            dropout_rate=dropout,
            feed_forward_proj='gated-gelu',
            use_cache=not checkpointing,
            gradient_checkpointing=checkpointing
        )
        self.transformer = T5Model(self.config)
        del self.transformer.encoder.embed_tokens
        del self.transformer.decoder.embed_tokens
        self.transformer.encoder.embed_tokens = functools.partial(null_position_embeddings, dim=model_dim)
        self.transformer.decoder.embed_tokens = functools.partial(null_position_embeddings, dim=model_dim)
        self.output_layer = nn.Linear(model_dim, self.start_token+1)

    def forward(self, conditioning_input, codes, pads, repeats, unpadded_lengths):
        max_len = unpadded_lengths.max()
        codes = codes[:, :max_len]
        pads = pads[:, :max_len]
        repeats = repeats[:, :max_len]

        if pads.max() > self.max_pad:
            print(f"Got unexpectedly long pads. Max: {pads.max()}, {pads}")
            pads = torch.clip(pads, 0, self.max_pad)
        if repeats.max() > self.max_repeat:
            print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}")
            repeats = torch.clip(repeats, 0, self.max_repeat)
        assert codes.max() < 36, codes.max()

        conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
        conds = []
        for j in range(conditioning_input.shape[1]):
            conds.append(self.conditioning_encoder(conditioning_input[:, j]))
        conds = torch.stack(conds, dim=1)
        h = torch.cat([conds, self.embedding(codes)], dim=1)

        labels = pads + repeats * self.max_pad + 1
        for i in range(unpadded_lengths.shape[0]):
            labels[i, unpadded_lengths[i]:] = 0
        labels_in = F.pad(labels, (1, 0), value=self.start_token)
        h_dec = self.dec_embedding(labels_in)

        h = self.transformer(inputs_embeds=h, decoder_inputs_embeds=h_dec).last_hidden_state
        logits = self.output_layer(h)
        logits = logits.permute(0, 2, 1)[:, :, :-1]  # Strip off the last token. There is no "stop" token here, so this is just an irrelevant prediction on some future that doesn't actually exist.
        loss = F.cross_entropy(logits, labels, reduction='none')

        # Ignore the first predictions of the sequences. This corresponds to the padding for the first CTC character, which is pretty much random and cannot be predicted.
        #loss = loss[1:].mean()
        return loss


@register_model
def register_ctc_code_generator2(opt_net, opt):
    return CtcCodeGenerator(**opt_get(opt_net, ['kwargs'], {}))


if __name__ == '__main__':
    model = CtcCodeGenerator()
    conds = torch.randn(4, 2, 80, 600)
    inps = torch.randint(0, 36, (4, 300))
    pads = torch.randint(0, 100, (4, 300))
    repeats = torch.randint(0, 20, (4, 300))
    loss = model(conds, inps, pads, repeats, torch.tensor([250, 300, 280, 30]))
    print(loss.shape)
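
As a reading aid (again, not part of the commit), the sketch below reproduces the (pad, repeat) -> label packing that CtcCodeGenerator.forward() performs before feeding the decoder, on toy tensors with the default max_pad/max_repeat values.

import torch
import torch.nn.functional as F

max_pad, max_repeat = 120, 30
start_token = (max_repeat + 1) * (max_pad + 1) + 1    # 3752 with the defaults

pads = torch.tensor([[5, 0, 12, 0]])
repeats = torch.tensor([[2, 1, 3, 1]])
unpadded_lengths = torch.tensor([3])                  # only the first 3 positions are real

labels = pads + repeats * max_pad + 1                 # same packing as forward()
for i in range(unpadded_lengths.shape[0]):
    labels[i, unpadded_lengths[i]:] = 0               # zero out positions past the true length
labels_in = F.pad(labels, (1, 0), value=start_token)  # decoder input gets a start token prepended

print(labels)     # tensor([[246, 121, 373,   0]])
print(labels_in)  # tensor([[3752,  246,  121,  373,    0]])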

@@ -299,7 +299,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_tts6_upsample.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
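
The new default option file presumably configures this generator through the registry entry added above. A hypothetical sketch of how constructor kwargs would flow from such an option block follows; the key names and values are illustrative assumptions, since the actual YAML is not included in this diff.

# Hypothetical option block; mirrors what register_ctc_code_generator2 does via opt_get.
opt_net = {
    'type': 'ctc_code_generator2',                      # assumed key/value, not taken from the diff
    'kwargs': {'model_dim': 512, 'layers': 10, 'num_heads': 8},
}
model = CtcCodeGenerator(**opt_net.get('kwargs', {}))   # roughly equivalent to opt_get(opt_net, ['kwargs'], {})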