diff --git a/codes/models/audio/music/transformer_diffusion.py b/codes/models/audio/music/transformer_diffusion.py
index e85e232c..ae995060 100644
--- a/codes/models/audio/music/transformer_diffusion.py
+++ b/codes/models/audio/music/transformer_diffusion.py
@@ -24,6 +24,16 @@ def is_sequence(t):
     return t.dtype == torch.long
 
 
+class MultiGroupEmbedding(nn.Module):
+    def __init__(self, tokens, groups, dim):
+        super().__init__()
+        self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
+
+    def forward(self, x):
+        h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
+        return torch.cat(h, dim=-1)
+
+
 class TransformerDiffusion(nn.Module):
     """
     A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
@@ -35,13 +45,12 @@ class TransformerDiffusion(nn.Module):
             num_layers=8,
             in_channels=256,
             in_latent_channels=512,
-            in_vectors=8,
-            in_groups=8,
+            token_count=8,
+            in_groups=None,
             out_channels=512, # mean and variance
             dropout=0,
             use_fp16=False,
             # Parameters for regularization.
-            layer_drop=.1,
             unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
     ):
         super().__init__()
@@ -52,7 +61,6 @@ class TransformerDiffusion(nn.Module):
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
-        self.layer_drop = layer_drop
         heads = model_channels//64
 
         self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
@@ -79,7 +87,10 @@ class TransformerDiffusion(nn.Module):
         # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
         # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
         # transformer network.
-        self.embeddings = nn.ModuleList([nn.Embedding(in_vectors, model_channels//in_groups) for _ in range(in_groups)])
+        if in_groups is None:
+            self.embeddings = nn.Embedding(token_count, model_channels)
+        else:
+            self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
         self.latent_conditioner = nn.Sequential(
             nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
             Encoder(
@@ -142,8 +153,7 @@ class TransformerDiffusion(nn.Module):
         cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1)
         cond_emb = self.conditioning_encoder(cond_emb)[:, 0]
 
-        code_emb = [embedding(codes[:, :, i]) for i, embedding in enumerate(self.embeddings)]
-        code_emb = torch.cat(code_emb, dim=-1)
+        code_emb = self.embeddings(codes)
         if prenet_latent is not None:
             latent_conditioning = self.latent_conditioner(prenet_latent)
             code_emb = code_emb + latent_conditioning * self.latent_fade
@@ -242,6 +252,7 @@ class TransformerDiffusion(nn.Module):
         conds = torch.cat(conds, dim=-1)
         return conds.mean(dim=-1)
 
+
 @register_model
 def register_transformer_diffusion(opt_net, opt):
     return TransformerDiffusion(**opt_net['kwargs'])
@@ -253,7 +264,7 @@ if __name__ == '__main__':
     aligned_sequence = torch.randint(0,8,(2,100,8))
     cond = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
-    model = TransformerDiffusion(512, layer_drop=.3, unconditioned_percentage=.5)
+    model = TransformerDiffusion(512, unconditioned_percentage=.5, in_groups=8)
     o = model(clip, ts, aligned_sequence, cond, return_code_pred=True)
     #o = model(clip, ts, aligned_sequence, cond, aligned_latent)
 
diff --git a/codes/models/audio/tts/ctc_code_generator2.py b/codes/models/audio/tts/ctc_code_generator2.py
deleted file mode 100644
index 694afaee..00000000
--- a/codes/models/audio/tts/ctc_code_generator2.py
+++ /dev/null
@@ -1,166 +0,0 @@
-import functools
-import json
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from transformers import T5Config, T5ForConditionalGeneration
-
-from models.audio.tts.transformer_builders import null_position_embeddings
-from models.audio.tts.unified_voice2 import ConditioningEncoder
-from models.audio.tts.tacotron2.text.cleaners import english_cleaners
-from trainer.networks import register_model
-from utils.util import opt_get
-
-
-class CtcCodeGenerator(nn.Module):
-    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30, checkpointing=True):
-        super().__init__()
-        self.max_pad = max_pad
-        self.max_repeat = max_repeat
-        self.start_token = self.max_repeat*self.max_pad+1
-        self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads)
-        self.embedding = nn.Embedding(ctc_codes, model_dim)
-        self.config = T5Config(
-            vocab_size=self.start_token+1,
-            d_model=model_dim,
-            d_kv=model_dim//num_heads,
-            d_ff=model_dim*4,
-            num_layers=layers,
-            num_heads=num_heads,
-            dropout_rate=dropout,
-            feed_forward_proj='gated-gelu',
-            use_cache=not checkpointing,
-            gradient_checkpointing=checkpointing,
-            tie_word_embeddings=False,
-            tie_encoder_decoder=False,
-            decoder_start_token_id=self.start_token,
-            pad_token_id=0,
-        )
-        self.transformer = T5ForConditionalGeneration(self.config)
-        del self.transformer.encoder.embed_tokens
-        del self.transformer.shared
-        self.transformer.encoder.embed_tokens = functools.partial(null_position_embeddings, dim=model_dim)
-
-    def forward(self, conditioning_input, codes, separators, repeats, unpadded_lengths):
-        max_len = unpadded_lengths.max()
-        codes = codes[:, :max_len]
-        separators = separators[:, :max_len]
-        repeats = repeats[:, :max_len]
-        if separators.max() > self.max_pad:
-            print(f"Got unexpectedly long separators. Max: {separators.max()}, {separators}")
-            separators = torch.clip(separators, 0, self.max_pad)
-        if repeats.max() > self.max_repeat:
-            print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}")
-            repeats = torch.clip(repeats, 0, self.max_repeat)
-        assert not torch.any(repeats < 1)
-        repeats = repeats - 1 # Per above, min(repeats) is 1; make it 0 to avoid wasting a prediction slot.
-
-        assert codes.max() < 36, codes.max()
-        labels = separators + repeats * self.max_pad
-        labels = labels + 1 # We want '0' to be used as the EOS or padding token, so add 1.
-        for i in range(unpadded_lengths.shape[0]):
-            labels[i, unpadded_lengths[i]:] = 0
-
-        conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
-        conds = []
-        for j in range(conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        h = torch.cat([conds, self.embedding(codes)], dim=1)
-
-        decoder_inputs = F.pad(labels, (1, 0), value=self.start_token)[:, :-1]
-        loss = self.transformer(inputs_embeds=h, decoder_input_ids=decoder_inputs, labels=labels).loss
-        return loss
-
-    def generate(self, speech_conditioning_inputs, texts, **hf_generate_kwargs):
-        codes = []
-        max_seq = 50
-        for text in texts:
-            # First, generate CTC codes from the given texts.
-            vocab = json.loads('{" ": 4, "E": 5, "T": 6, "A": 7, "O": 8, "N": 9, "I": 10, "H": 11, "S": 12, "R": 13, "D": 14, "L": 15, "U": 16, "M": 17, "W": 18, "C": 19, "F": 20, "G": 21, "Y": 22, "P": 23, "B": 24, "V": 25, "K": 26, "\'": 27, "X": 28, "J": 29, "Q": 30, "Z": 31}')
-            text = english_cleaners(text)
-            text = text.strip().upper()
-            cd = []
-            for c in text:
-                if c not in vocab.keys():
-                    continue
-                cd.append(vocab[c])
-            codes.append(torch.tensor(cd, device=speech_conditioning_inputs.device))
-            max_seq = max(max_seq, codes[-1].shape[-1])
-        # Collate
-        for i in range(len(codes)):
-            if codes[i].shape[-1] < max_seq:
-                codes[i] = F.pad(codes[i], (0, max_seq-codes[i].shape[-1]))
-        codes = torch.stack(codes, dim=0)
-
-        conditioning_input = speech_conditioning_inputs.unsqueeze(1) if len(speech_conditioning_inputs.shape) == 3 else speech_conditioning_inputs
-        conds = []
-        for j in range(conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        h = torch.cat([conds, self.embedding(codes)], dim=1)
-        generate = self.transformer.generate(inputs_embeds=h, max_length=codes.shape[-1]+1, min_length=codes.shape[-1]+1,
-                                             bos_token_id=self.start_token,
-                                             bad_words_ids=[[0], [self.start_token]], **hf_generate_kwargs)
-        # The HF generate API returns a sequence with the BOS token included, hence the +1s above. Remove it.
-        generate = generate[:, 1:]
-
-        # De-compress the codes from the generated output
-        generate = generate - 1 # Remember above when we added 1 to the labels to avoid overlapping the EOS pad token?
-        pads = generate % self.max_pad
-        repeats = (generate // self.max_pad) + 1
-        ctc_batch = []
-        max_seq = 0
-        for bc, bp, br in zip(codes, pads, repeats):
-            ctc = []
-            for c, p, r in zip(bc, bp, br):
-                for _ in range(p):
-                    ctc.append(0)
-                for _ in range(r):
-                    ctc.append(c.item())
-            ctc_batch.append(torch.tensor(ctc, device=speech_conditioning_inputs.device))
-            max_seq = max(max_seq, ctc_batch[-1].shape[-1])
-
-        # Collate the batch
-        for i in range(len(ctc_batch)):
-            if ctc_batch[i].shape[-1] < max_seq:
-                ctc_batch[i] = F.pad(ctc_batch[i], (0, max_seq-ctc_batch[i].shape[-1]))
-        return torch.stack(ctc_batch, dim=0)
-
-
-@register_model
-def register_ctc_code_generator2(opt_net, opt):
-    return CtcCodeGenerator(**opt_get(opt_net, ['kwargs'], {}))
-
-
-def inf():
-    sd = torch.load('D:\\dlas\\experiments\\train_encoder_build_ctc_alignments\\models\\24000_generator.pth', map_location='cpu')
-    model = CtcCodeGenerator(layers=10, checkpointing=False).eval()
-    model.load_state_dict(sd)
-    raw_batch = torch.load('raw_batch.pth')
-    with torch.no_grad():
-        from scripts.audio.gen.speech_synthesis_utils import wav_to_mel
-        ref_mel = torch.cat([wav_to_mel(raw_batch['conditioning'][0])[:, :, :256],
-                             wav_to_mel(raw_batch['conditioning'][0])[:, :, :256]], dim=0).unsqueeze(0)
-        loss = model(ref_mel, raw_batch['ctc_raw_codes'][0].unsqueeze(0),
-                     raw_batch['ctc_pads'][0].unsqueeze(0),
-                     raw_batch['ctc_repeats'][0].unsqueeze(0),
-                     raw_batch['ctc_raw_lengths'][0].unsqueeze(0),)
-        #ref_mel = torch.cat([wav_to_mel(load_audio("D:\\tortoise-tts\\voices\\atkins\\1.wav", 22050))[:, :, :256],
-        #                     wav_to_mel(load_audio("D:\\tortoise-tts\\voices\\atkins\\2.wav", 22050))[:, :, :256]], dim=0).unsqueeze(0)
-        #ctc = model.generate(ref_mel, ["i suppose though it's too early for them"], num_beams=4, )
-        print("Break")
-
-
-if __name__ == '__main__':
-    inf()
-
-    model = CtcCodeGenerator()
-    conds = torch.randn(4,2,80,600)
-    inps = torch.randint(0,36, (4, 300))
-    pads = torch.randint(0,100, (4,300))
-    repeats = torch.randint(0,20, (4,300))
-    #loss = model(conds, inps, pads, repeats, torch.tensor([250, 300, 280, 30]))
-    #print(loss.shape)
-    #model.generate(conds, ["Hello, world!", "Ahoi!", "KKKKKK", "what's going on??"])
\ No newline at end of file
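
Reviewer note (not part of the patch): a minimal, self-contained sketch of what the new embedding switch in transformer_diffusion.py does, using MultiGroupEmbedding and the token_count/in_groups arguments as they appear in the diff. The shapes mirror the __main__ test above; the batch size and sequence length are illustrative assumptions only.

import torch
import torch.nn as nn


class MultiGroupEmbedding(nn.Module):
    """Same module as added in the diff: one nn.Embedding per group, each producing
    dim // groups channels, concatenated on the last axis."""
    def __init__(self, tokens, groups, dim):
        super().__init__()
        self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])

    def forward(self, x):
        # x: (batch, seq, groups) LongTensor of codes
        h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
        return torch.cat(h, dim=-1)  # (batch, seq, dim)


# Mirrors the constructor switch: a flat nn.Embedding when in_groups is None,
# a MultiGroupEmbedding otherwise.
model_channels, token_count, in_groups = 512, 8, 8
grouped = MultiGroupEmbedding(token_count, in_groups, model_channels)
flat = nn.Embedding(token_count, model_channels)

codes_grouped = torch.randint(0, token_count, (2, 100, in_groups))  # as in the __main__ test
codes_flat = torch.randint(0, token_count, (2, 100))

print(grouped(codes_grouped).shape)  # torch.Size([2, 100, 512])
print(flat(codes_flat).shape)        # torch.Size([2, 100, 512])

Splitting dim across the groups keeps the concatenated embedding at model_channels regardless of how many code groups are supplied, which is why `code_emb = self.embeddings(codes)` can replace the old per-group loop without touching the rest of the network.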
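
Reviewer note (not part of the patch): for reference, the deleted CtcCodeGenerator packed each (separator, repeat) pair into a single label as separators + (repeats - 1) * max_pad + 1, reserving 0 for the EOS/padding token, and inverted this in generate() with a modulo and floor division. A small standalone sketch of that round trip, using the file's default max_pad=121; the sample values are made up for illustration.

import torch

max_pad = 121  # default from the deleted CtcCodeGenerator

# Pack, as in forward(): repeats is shifted to start at 0 before mixing,
# and +1 keeps 0 free for EOS/padding.
separators = torch.tensor([3, 0, 17])
repeats = torch.tensor([1, 4, 2])
labels = separators + (repeats - 1) * max_pad + 1

# Unpack, as in generate().
v = labels - 1
rec_pads = v % max_pad
rec_repeats = (v // max_pad) + 1

assert torch.equal(rec_pads, separators)
assert torch.equal(rec_repeats, repeats)
print(labels)  # tensor([  4, 364, 139])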