forked from mrq/DL-Art-School

slight rework

commit 8b4b5ffa72
parent 48aab2babe
@@ -24,6 +24,16 @@ def is_sequence(t):
     return t.dtype == torch.long
 
 
+class MultiGroupEmbedding(nn.Module):
+    def __init__(self, tokens, groups, dim):
+        super().__init__()
+        self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
+
+    def forward(self, x):
+        h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
+        return torch.cat(h, dim=-1)
+
+
 class TransformerDiffusion(nn.Module):
     """
     A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
@@ -35,13 +45,12 @@ class TransformerDiffusion(nn.Module):
             num_layers=8,
             in_channels=256,
             in_latent_channels=512,
-            in_vectors=8,
-            in_groups=8,
+            token_count=8,
+            in_groups=None,
             out_channels=512, # mean and variance
             dropout=0,
             use_fp16=False,
             # Parameters for regularization.
-            layer_drop=.1,
             unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
     ):
         super().__init__()
@@ -52,7 +61,6 @@ class TransformerDiffusion(nn.Module):
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
-        self.layer_drop = layer_drop
         heads = model_channels//64
 
         self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
@@ -79,7 +87,10 @@ class TransformerDiffusion(nn.Module):
         # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
         # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
         # transformer network.
-        self.embeddings = nn.ModuleList([nn.Embedding(in_vectors, model_channels//in_groups) for _ in range(in_groups)])
+        if in_groups is None:
+            self.embeddings = nn.Embedding(token_count, model_channels)
+        else:
+            self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
         self.latent_conditioner = nn.Sequential(
             nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
             Encoder(
@@ -142,8 +153,7 @@ class TransformerDiffusion(nn.Module):
         cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1)
         cond_emb = self.conditioning_encoder(cond_emb)[:, 0]
 
-        code_emb = [embedding(codes[:, :, i]) for i, embedding in enumerate(self.embeddings)]
-        code_emb = torch.cat(code_emb, dim=-1)
+        code_emb = self.embeddings(codes)
         if prenet_latent is not None:
             latent_conditioning = self.latent_conditioner(prenet_latent)
             code_emb = code_emb + latent_conditioning * self.latent_fade
@@ -242,6 +252,7 @@ class TransformerDiffusion(nn.Module):
         conds = torch.cat(conds, dim=-1)
         return conds.mean(dim=-1)
 
+
 @register_model
 def register_transformer_diffusion(opt_net, opt):
     return TransformerDiffusion(**opt_net['kwargs'])
@@ -253,7 +264,7 @@ if __name__ == '__main__':
     aligned_sequence = torch.randint(0,8,(2,100,8))
    cond = torch.randn(2, 256, 400)
    ts = torch.LongTensor([600, 600])
-    model = TransformerDiffusion(512, layer_drop=.3, unconditioned_percentage=.5)
+    model = TransformerDiffusion(512, unconditioned_percentage=.5, in_groups=8)
     o = model(clip, ts, aligned_sequence, cond, return_code_pred=True)
     #o = model(clip, ts, aligned_sequence, cond, aligned_latent)
 
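For reference, a minimal standalone sketch of the two embedding paths this commit introduces. The class body is the one added in the first hunk above; the token range and tensor shapes are assumptions taken from the test tensors in __main__ (torch.randint(0,8,(2,100,8)) with 512 model channels), not fixed requirements.

import torch
import torch.nn as nn

class MultiGroupEmbedding(nn.Module):
    def __init__(self, tokens, groups, dim):
        super().__init__()
        # One small embedding table per group; each contributes dim // groups channels.
        self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])

    def forward(self, x):
        # x: (batch, sequence, groups) integer codes; group i indexes table i.
        h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
        # Concatenate per-group embeddings back to (batch, sequence, dim).
        return torch.cat(h, dim=-1)

# Grouped path (in_groups=8): codes shaped (batch, sequence, groups).
codes = torch.randint(0, 8, (2, 100, 8))
emb = MultiGroupEmbedding(tokens=8, groups=8, dim=512)
print(emb(codes).shape)          # torch.Size([2, 100, 512])

# Ungrouped path (in_groups=None): a plain nn.Embedding over (batch, sequence) codes.
flat_codes = torch.randint(0, 8, (2, 100))
print(nn.Embedding(8, 512)(flat_codes).shape)  # torch.Size([2, 100, 512])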
@@ -1,166 +0,0 @@
-import functools
-import json
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from transformers import T5Config, T5ForConditionalGeneration
-
-from models.audio.tts.transformer_builders import null_position_embeddings
-from models.audio.tts.unified_voice2 import ConditioningEncoder
-from models.audio.tts.tacotron2.text.cleaners import english_cleaners
-from trainer.networks import register_model
-from utils.util import opt_get
-
-
-class CtcCodeGenerator(nn.Module):
-    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30, checkpointing=True):
-        super().__init__()
-        self.max_pad = max_pad
-        self.max_repeat = max_repeat
-        self.start_token = self.max_repeat*self.max_pad+1
-        self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads)
-        self.embedding = nn.Embedding(ctc_codes, model_dim)
-        self.config = T5Config(
-            vocab_size=self.start_token+1,
-            d_model=model_dim,
-            d_kv=model_dim//num_heads,
-            d_ff=model_dim*4,
-            num_layers=layers,
-            num_heads=num_heads,
-            dropout_rate=dropout,
-            feed_forward_proj='gated-gelu',
-            use_cache=not checkpointing,
-            gradient_checkpointing=checkpointing,
-            tie_word_embeddings=False,
-            tie_encoder_decoder=False,
-            decoder_start_token_id=self.start_token,
-            pad_token_id=0,
-        )
-        self.transformer = T5ForConditionalGeneration(self.config)
-        del self.transformer.encoder.embed_tokens
-        del self.transformer.shared
-        self.transformer.encoder.embed_tokens = functools.partial(null_position_embeddings, dim=model_dim)
-
-    def forward(self, conditioning_input, codes, separators, repeats, unpadded_lengths):
-        max_len = unpadded_lengths.max()
-        codes = codes[:, :max_len]
-        separators = separators[:, :max_len]
-        repeats = repeats[:, :max_len]
-        if separators.max() > self.max_pad:
-            print(f"Got unexpectedly long separators. Max: {separators.max()}, {separators}")
-            separators = torch.clip(separators, 0, self.max_pad)
-        if repeats.max() > self.max_repeat:
-            print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}")
-            repeats = torch.clip(repeats, 0, self.max_repeat)
-        assert not torch.any(repeats < 1)
-        repeats = repeats - 1  # Per above, min(repeats) is 1; make it 0 to avoid wasting a prediction slot.
-
-        assert codes.max() < 36, codes.max()
-        labels = separators + repeats * self.max_pad
-        labels = labels + 1  # We want '0' to be used as the EOS or padding token, so add 1.
-        for i in range(unpadded_lengths.shape[0]):
-            labels[i, unpadded_lengths[i]:] = 0
-
-        conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
-        conds = []
-        for j in range(conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        h = torch.cat([conds, self.embedding(codes)], dim=1)
-
-        decoder_inputs = F.pad(labels, (1, 0), value=self.start_token)[:, :-1]
-        loss = self.transformer(inputs_embeds=h, decoder_input_ids=decoder_inputs, labels=labels).loss
-        return loss
-
-    def generate(self, speech_conditioning_inputs, texts, **hf_generate_kwargs):
-        codes = []
-        max_seq = 50
-        for text in texts:
-            # First, generate CTC codes from the given texts.
-            vocab = json.loads('{" ": 4, "E": 5, "T": 6, "A": 7, "O": 8, "N": 9, "I": 10, "H": 11, "S": 12, "R": 13, "D": 14, "L": 15, "U": 16, "M": 17, "W": 18, "C": 19, "F": 20, "G": 21, "Y": 22, "P": 23, "B": 24, "V": 25, "K": 26, "\'": 27, "X": 28, "J": 29, "Q": 30, "Z": 31}')
-            text = english_cleaners(text)
-            text = text.strip().upper()
-            cd = []
-            for c in text:
-                if c not in vocab.keys():
-                    continue
-                cd.append(vocab[c])
-            codes.append(torch.tensor(cd, device=speech_conditioning_inputs.device))
-            max_seq = max(max_seq, codes[-1].shape[-1])
-        # Collate
-        for i in range(len(codes)):
-            if codes[i].shape[-1] < max_seq:
-                codes[i] = F.pad(codes[i], (0, max_seq-codes[i].shape[-1]))
-        codes = torch.stack(codes, dim=0)
-
-        conditioning_input = speech_conditioning_inputs.unsqueeze(1) if len(speech_conditioning_inputs.shape) == 3 else speech_conditioning_inputs
-        conds = []
-        for j in range(conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        h = torch.cat([conds, self.embedding(codes)], dim=1)
-        generate = self.transformer.generate(inputs_embeds=h, max_length=codes.shape[-1]+1, min_length=codes.shape[-1]+1,
-                                             bos_token_id=self.start_token,
-                                             bad_words_ids=[[0], [self.start_token]], **hf_generate_kwargs)
-        # The HF generate API returns a sequence with the BOS token included, hence the +1s above. Remove it.
-        generate = generate[:, 1:]
-
-        # De-compress the codes from the generated output
-        generate = generate - 1  # Remember above when we added 1 to the labels to avoid overlapping the EOS pad token?
-        pads = generate % self.max_pad
-        repeats = (generate // self.max_pad) + 1
-        ctc_batch = []
-        max_seq = 0
-        for bc, bp, br in zip(codes, pads, repeats):
-            ctc = []
-            for c, p, r in zip(bc, bp, br):
-                for _ in range(p):
-                    ctc.append(0)
-                for _ in range(r):
-                    ctc.append(c.item())
-            ctc_batch.append(torch.tensor(ctc, device=speech_conditioning_inputs.device))
-            max_seq = max(max_seq, ctc_batch[-1].shape[-1])
-
-        # Collate the batch
-        for i in range(len(ctc_batch)):
-            if ctc_batch[i].shape[-1] < max_seq:
-                ctc_batch[i] = F.pad(ctc_batch[i], (0, max_seq-ctc_batch[i].shape[-1]))
-        return torch.stack(ctc_batch, dim=0)
-
-
-@register_model
-def register_ctc_code_generator2(opt_net, opt):
-    return CtcCodeGenerator(**opt_get(opt_net, ['kwargs'], {}))
-
-
-def inf():
-    sd = torch.load('D:\\dlas\\experiments\\train_encoder_build_ctc_alignments\\models\\24000_generator.pth', map_location='cpu')
-    model = CtcCodeGenerator(layers=10, checkpointing=False).eval()
-    model.load_state_dict(sd)
-    raw_batch = torch.load('raw_batch.pth')
-    with torch.no_grad():
-        from scripts.audio.gen.speech_synthesis_utils import wav_to_mel
-        ref_mel = torch.cat([wav_to_mel(raw_batch['conditioning'][0])[:, :, :256],
-                             wav_to_mel(raw_batch['conditioning'][0])[:, :, :256]], dim=0).unsqueeze(0)
-        loss = model(ref_mel, raw_batch['ctc_raw_codes'][0].unsqueeze(0),
-                     raw_batch['ctc_pads'][0].unsqueeze(0),
-                     raw_batch['ctc_repeats'][0].unsqueeze(0),
-                     raw_batch['ctc_raw_lengths'][0].unsqueeze(0),)
-        #ref_mel = torch.cat([wav_to_mel(load_audio("D:\\tortoise-tts\\voices\\atkins\\1.wav", 22050))[:, :, :256],
-        #                     wav_to_mel(load_audio("D:\\tortoise-tts\\voices\\atkins\\2.wav", 22050))[:, :, :256]], dim=0).unsqueeze(0)
-        #ctc = model.generate(ref_mel, ["i suppose though it's too early for them"], num_beams=4, )
-        print("Break")
-
-
-if __name__ == '__main__':
-    inf()
-
-    model = CtcCodeGenerator()
-    conds = torch.randn(4,2,80,600)
-    inps = torch.randint(0,36, (4, 300))
-    pads = torch.randint(0,100, (4,300))
-    repeats = torch.randint(0,20, (4,300))
-    #loss = model(conds, inps, pads, repeats, torch.tensor([250, 300, 280, 30]))
-    #print(loss.shape)
-    #model.generate(conds, ["Hello, world!", "Ahoi!", "KKKKKK", "what's going on??"])
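The CtcCodeGenerator removed above packs each (separator, repeat) pair into a single T5 vocabulary token (labels = separators + repeats * max_pad, plus 1 so 0 stays free as the EOS/padding token) and unpacks generated tokens with the inverse modulo/division. A small self-contained sketch of that round trip, using the deleted file's default max_pad=121; the pack/unpack helper names are illustrative, not from the original code.

# Round trip of the (separator, repeat) packing used by the deleted CtcCodeGenerator.
MAX_PAD = 121  # default max_pad from the deleted file

def pack(separator, repeat):
    # forward(): repeats are shifted down by 1, combined with the pad count,
    # then +1 keeps 0 available as the EOS/padding token.
    return separator + (repeat - 1) * MAX_PAD + 1

def unpack(label):
    # generate(): undo the +1, then split back into pad count and repeat count.
    v = label - 1
    return v % MAX_PAD, v // MAX_PAD + 1

for sep, rep in [(0, 1), (3, 2), (120, 30)]:
    label = pack(sep, rep)
    assert unpack(label) == (sep, rep)
    print(sep, rep, '->', label)  # e.g. 3 2 -> 125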