forked from mrq/DL-Art-School
fix code_emb
parent 368dca18b1
commit 8c8efbe131
@@ -210,8 +210,8 @@ class TransformerDiffusion(nn.Module):
     def timestep_independent(self, prior, expected_seq_len):
         if self.new_code_expansion:
-            code_emb = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1)
-        code_emb = self.ar_input(code_emb) if self.ar_prior else self.input_converter(code_emb)
+            prior = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1)
+        code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior)
         code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
 
         # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
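The fix above addresses a latent NameError: the old code assigned the interpolated tensor to code_emb inside the if-branch, yet the line after the branch also read code_emb, which was never defined when new_code_expansion was False. Reassigning prior instead gives both paths a defined input. A minimal, self-contained sketch of the corrected control flow follows; expand_prior and the tensor shapes are illustrative assumptions, not code from this repo.

import torch
import torch.nn.functional as F

def expand_prior(prior, expected_seq_len, new_code_expansion=True):
    # Sketch of the fixed branch: F.interpolate's 'linear' mode expects
    # (batch, channels, length), hence the permute round-trip on a
    # (batch, length, channels) prior.
    if new_code_expansion:
        prior = F.interpolate(prior.permute(0, 2, 1), size=expected_seq_len,
                              mode='linear').permute(0, 2, 1)
    # Both paths now read `prior`; the pre-fix code read `code_emb` here,
    # which was only ever assigned inside the branch above.
    return prior

prior = torch.randn(2, 100, 256)               # (batch, seq, input_vec_dim)
print(expand_prior(prior, 500).shape)          # torch.Size([2, 500, 256])
print(expand_prior(prior, 500, False).shape)   # torch.Size([2, 100, 256])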
@@ -732,14 +732,14 @@ def test_cheater_model():
 
     # For music:
     model = TransformerDiffusionWithCheaterLatent(in_channels=256, out_channels=512,
-                                                  model_channels=1024, contraction_dim=512,
-                                                  prenet_channels=1024, num_heads=8,
-                                                  input_vec_dim=256, num_layers=12, prenet_layers=6,
+                                                  model_channels=1536, contraction_dim=768,
+                                                  prenet_channels=1024, num_heads=12,
+                                                  input_vec_dim=256, num_layers=20, prenet_layers=6,
                                                   dropout=.1, new_code_expansion=True,
                                                   )
-    diff_weights = torch.load('extracted_diff.pth')
-    model.diff.load_state_dict(diff_weights, strict=False)
-    cheater_ar_weights = torch.load('X:\\dlas\\experiments\\train_music_gpt_cheater\\models\\19500_generator_ema.pth')
+    #diff_weights = torch.load('extracted_diff.pth')
+    #model.diff.load_state_dict(diff_weights, strict=False)
+    cheater_ar_weights = torch.load('X:\\dlas\\experiments\\train_music_gpt_cheater\\models\\60000_generator_ema.pth')
     cheater_ar = GptMusicLower(dim=1024, encoder_out_dim=256, layers=16, fp16=False, num_target_vectors=8192, num_vaes=4,
                                vqargs= {'positional_dims': 1, 'channels': 64,
                                         'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
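For context on the checkpoint lines above: load_state_dict(..., strict=False) tolerates checkpoints whose keys do not exactly match the module, which is what lets the extracted diffusion weights load partially. A hedged sketch of that pattern, using a toy module and an in-memory state dict rather than the repo's models and paths:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))

# Toy "checkpoint": covers only the first layer and carries a stray key.
ckpt = {'0.weight': torch.zeros(8, 4), '0.bias': torch.zeros(8),
        'extra.weight': torch.ones(1)}

missing, unexpected = model.load_state_dict(ckpt, strict=False)
print(missing)      # ['1.weight', '1.bias'] -- keep their current values
print(unexpected)   # ['extra.weight']       -- ignored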