diff --git a/codes/data/audio/fast_paired_dataset.py b/codes/data/audio/fast_paired_dataset.py
index dec869a4..c1d2c24f 100644
--- a/codes/data/audio/fast_paired_dataset.py
+++ b/codes/data/audio/fast_paired_dataset.py
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 import torch.utils.data
 import torchaudio
 from tqdm import tqdm
+from transformers import Wav2Vec2CTCTokenizer
 
 from data.audio.paired_voice_audio_dataset import CharacterTokenizer
 from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
@@ -147,6 +148,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             rcodes = rcodes[:self.max_text_len]
             repeats = rcodes[:self.max_text_len]
             seps = seps[:self.max_text_len]
+
         return {
             'ctc_raw_codes': rcodes,
             'ctc_separators': seps,
@@ -284,7 +286,8 @@ if __name__ == '__main__':
         'num_conditioning_candidates': 2,
         'conditioning_length': 102400,
         'use_bpe_tokenizer': True,
-        'load_aligned_codes': False,
+        'load_aligned_codes': True,
+        'produce_ctc_metadata': True,
     }
     from data import create_dataset, create_dataloader
 
diff --git a/codes/models/audio/tts/transformer_diffusion_tts.py b/codes/models/audio/tts/transformer_diffusion_tts.py
index 9e14a7a7..953494a7 100644
--- a/codes/models/audio/tts/transformer_diffusion_tts.py
+++ b/codes/models/audio/tts/transformer_diffusion_tts.py
@@ -101,7 +101,7 @@ class TransformerDiffusion(nn.Module):
             rotary_pos_emb=True,
         )
         self.clvp_encoder = nn.Linear(clvp_in_dim, model_channels)
-        self.type_embedding = nn.Embedding(types)
+        self.type_embedding = nn.Embedding(types, model_channels)
 
         # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
         # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
@@ -254,7 +254,8 @@ if __name__ == '__main__':
     cond = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
     clvp = torch.randn(2,768)
+    type = torch.LongTensor([0,1])
     model = TransformerDiffusion(512, unconditioned_percentage=.5, in_groups=8)
-    o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, return_code_pred=True)
+    o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, type=type, return_code_pred=True)
     #o = model(clip, ts, aligned_sequence, cond, aligned_latent)
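
Below is a minimal, self-contained sketch, not the repository's actual `TransformerDiffusion`, illustrating the two things the second file's changes rely on: `nn.Embedding` must be constructed with both a vocabulary size and an embedding dimension, and the per-sample `type` tensor passed in the smoke test is looked up and mixed into the model's hidden channels. The `DummyTypeConditionedModel` name and its shapes are hypothetical.

```python
import torch
import torch.nn as nn


class DummyTypeConditionedModel(nn.Module):
    """Hypothetical stand-in for a type-conditioned backbone."""

    def __init__(self, types=2, model_channels=512):
        super().__init__()
        # nn.Embedding requires (num_embeddings, embedding_dim); the
        # single-argument call removed by the diff raises a TypeError.
        self.type_embedding = nn.Embedding(types, model_channels)

    def forward(self, x, type):
        # Look up one embedding per batch element and broadcast it across
        # the sequence dimension as an additive conditioning signal.
        return x + self.type_embedding(type).unsqueeze(1)


if __name__ == '__main__':
    model = DummyTypeConditionedModel(types=2, model_channels=512)
    x = torch.randn(2, 100, 512)      # (batch, sequence, channels)
    type = torch.LongTensor([0, 1])   # one type id per batch element, as in the diff
    print(model(x, type).shape)       # torch.Size([2, 100, 512])
```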