forked from mrq/DL-Art-School
fix type bug
parent 0659fe3d1e
commit 5efeee6b97
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 import torch.utils.data
 import torchaudio
 from tqdm import tqdm
+from transformers import Wav2Vec2CTCTokenizer
 
 from data.audio.paired_voice_audio_dataset import CharacterTokenizer
 from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
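Background on the new import: Wav2Vec2CTCTokenizer maps text onto the character-level vocabulary used by wav2vec2's CTC head, which is what the CTC metadata handled below is expressed in. A minimal sketch of using it, assuming the stock facebook/wav2vec2-base-960h vocabulary (the checkpoint name is illustrative, not taken from this commit):

from transformers import Wav2Vec2CTCTokenizer

# Checkpoint name is an assumption for illustration only.
tok = Wav2Vec2CTCTokenizer.from_pretrained('facebook/wav2vec2-base-960h')
ids = tok('HELLO WORLD').input_ids  # character ids; '|' marks word boundaries
print(ids)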
@@ -147,6 +148,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
         rcodes = rcodes[:self.max_text_len]
+        repeats = rcodes[:self.max_text_len]
         seps = seps[:self.max_text_len]
 
         return {
             'ctc_raw_codes': rcodes,
             'ctc_separators': seps,
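The hunk above clips the per-token CTC metadata (raw codes, repeat counts, separator flags) to max_text_len so samples batch to a fixed width. Note that the added line slices rcodes rather than repeats, which looks like a copy-paste slip that overwrites the repeat counts. A standalone sketch of the intended truncation, with shapes and value ranges assumed for illustration:

import torch

max_text_len = 8
rcodes = torch.randint(0, 36, (11,))  # raw CTC token codes
repeats = torch.randint(1, 4, (11,))  # how many frames each code repeats for
seps = torch.randint(0, 2, (11,))     # word-separator flags

rcodes = rcodes[:max_text_len]
repeats = repeats[:max_text_len]      # slice repeats itself, not rcodes
seps = seps[:max_text_len]

sample = {'ctc_raw_codes': rcodes, 'ctc_separators': seps}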
@@ -284,7 +286,8 @@ if __name__ == '__main__':
         'num_conditioning_candidates': 2,
         'conditioning_length': 102400,
         'use_bpe_tokenizer': True,
-        'load_aligned_codes': False,
+        'load_aligned_codes': True,
+        'produce_ctc_metadata': True,
     }
     from data import create_dataset, create_dataloader
 
@@ -101,7 +101,7 @@ class TransformerDiffusion(nn.Module):
             rotary_pos_emb=True,
         )
         self.clvp_encoder = nn.Linear(clvp_in_dim, model_channels)
-        self.type_embedding = nn.Embedding(types)
+        self.type_embedding = nn.Embedding(types, model_channels)
 
         # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
         # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
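This hunk is the "type bug" of the commit title: torch.nn.Embedding takes two required positional arguments, num_embeddings and embedding_dim, so nn.Embedding(types) raises a TypeError as soon as the module is constructed. A minimal sketch of the fixed lookup, with dimensions assumed to match the test harness below:

import torch
import torch.nn as nn

types, model_channels = 2, 512
type_embedding = nn.Embedding(types, model_channels)  # one row per input type

type_ids = torch.LongTensor([0, 1])  # one type id per batch element
emb = type_embedding(type_ids)
print(emb.shape)  # torch.Size([2, 512])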
@@ -254,7 +254,8 @@ if __name__ == '__main__':
     cond = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
     clvp = torch.randn(2,768)
+    type = torch.LongTensor([0,1])
     model = TransformerDiffusion(512, unconditioned_percentage=.5, in_groups=8)
-    o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, return_code_pred=True)
+    o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, type=type, return_code_pred=True)
     #o = model(clip, ts, aligned_sequence, cond, aligned_latent)
 
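With the type ids threaded through, the smoke test now exercises the fixed embedding end to end. One small caveat: the local name type shadows Python's builtin inside this block, which is harmless in a throwaway test, though a name like type_ids would avoid the shadowing.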