fix type bug

James Betker 2022-05-27 11:19:30 -06:00
parent 0659fe3d1e
commit 5efeee6b97
2 changed files with 7 additions and 3 deletions


@@ -10,6 +10,7 @@ import torch.nn.functional as F
import torch.utils.data
import torchaudio
from tqdm import tqdm
+from transformers import Wav2Vec2CTCTokenizer
from data.audio.paired_voice_audio_dataset import CharacterTokenizer
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
@@ -147,6 +148,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
            rcodes = rcodes[:self.max_text_len]
            repeats = repeats[:self.max_text_len]
            seps = seps[:self.max_text_len]
        return {
            'ctc_raw_codes': rcodes,
            'ctc_separators': seps,
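
For context, these truncations are the overflow half of a pad-or-clip step that forces every CTC metadata tensor to exactly max_text_len. A minimal sketch of that pattern, using a hypothetical clip_or_pad helper rather than the repository's code:

    import torch
    import torch.nn.functional as F

    def clip_or_pad(seq: torch.Tensor, max_len: int, pad_value: int = 0) -> torch.Tensor:
        # Truncate sequences that run past max_len...
        if seq.shape[0] > max_len:
            return seq[:max_len]
        # ...and right-pad shorter ones so every sample collates to a fixed length.
        return F.pad(seq, (0, max_len - seq.shape[0]), value=pad_value)

    rcodes = clip_or_pad(torch.tensor([5, 9, 9, 2, 7]), max_len=4)  # tensor([5, 9, 9, 2])
    seps = clip_or_pad(torch.tensor([1, 0, 1]), max_len=4)          # tensor([1, 0, 1, 0])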
@@ -284,7 +286,8 @@ if __name__ == '__main__':
        'num_conditioning_candidates': 2,
        'conditioning_length': 102400,
        'use_bpe_tokenizer': True,
-        'load_aligned_codes': False,
+        'load_aligned_codes': True,
+        'produce_ctc_metadata': True,
    }
from data import create_dataset, create_dataloader
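
The flipped flags above are consumed through the factory functions imported on the last line. A hedged sketch of how the test block presumably builds a loader from them; the create_dataset/create_dataloader signatures and the 'mode' key are assumptions, not verified against the repository:

    from data import create_dataset, create_dataloader

    opt = {
        'mode': 'fast_paired_voice_audio',  # assumed mode key for FastPairedVoiceDataset
        'use_bpe_tokenizer': True,
        'load_aligned_codes': True,         # load the aligned code sequences...
        'produce_ctc_metadata': True,       # ...and emit ctc_raw_codes/ctc_separators per sample
    }
    ds = create_dataset(opt)                # assumed: factory dispatches on opt['mode']
    loader = create_dataloader(ds, opt)     # assumed: wraps ds in a torch DataLoader
    batch = next(iter(loader))              # batches should now carry the CTC metadata keys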


@@ -101,7 +101,7 @@ class TransformerDiffusion(nn.Module):
            rotary_pos_emb=True,
        )
        self.clvp_encoder = nn.Linear(clvp_in_dim, model_channels)
-        self.type_embedding = nn.Embedding(types)
+        self.type_embedding = nn.Embedding(types, model_channels)
        # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
        # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
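
This one-argument call is the type bug named in the commit message: torch.nn.Embedding requires both num_embeddings and embedding_dim, so nn.Embedding(types) raises a TypeError as soon as the module is constructed. A minimal demonstration of the corrected call, with illustrative values for types and model_channels:

    import torch
    import torch.nn as nn

    types, model_channels = 2, 512  # illustrative values

    # nn.Embedding(types) would fail here:
    # TypeError: __init__() missing 1 required positional argument: 'embedding_dim'
    type_embedding = nn.Embedding(types, model_channels)

    type_ids = torch.LongTensor([0, 1])  # one type id per batch element
    emb = type_embedding(type_ids)       # shape (2, model_channels)
    assert emb.shape == (2, model_channels)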
@@ -254,7 +254,8 @@ if __name__ == '__main__':
    cond = torch.randn(2, 256, 400)
    ts = torch.LongTensor([600, 600])
    clvp = torch.randn(2,768)
+    type = torch.LongTensor([0,1])
    model = TransformerDiffusion(512, unconditioned_percentage=.5, in_groups=8)
-    o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, return_code_pred=True)
+    o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, type=type, return_code_pred=True)
    #o = model(clip, ts, aligned_sequence, cond, aligned_latent)
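
With the constructor fixed, the updated smoke test threads type=torch.LongTensor([0, 1]) through the forward call, assigning each of the two batch elements a type id that the repaired embedding maps to a model_channels-wide vector. Before this commit the test could not reach the forward pass at all, since instantiating TransformerDiffusion already failed at the nn.Embedding call.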