fix conditioning free
This commit is contained in:
parent
fc0b291b21
commit
c14bf6dfb2
|
@ -149,7 +149,7 @@ class TransformerDiffusion(nn.Module):
|
|||
|
||||
def forward(self, x, timesteps, prior=None, conditioning_free=False):
|
||||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1)
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
||||
else:
|
||||
code_emb = self.input_converter(prior)
|
||||
|
||||
|
@ -260,6 +260,7 @@ def test_cheater_model():
|
|||
|
||||
print_network(model)
|
||||
o = model(clip, ts, clip)
|
||||
o = model(clip, ts, clip, conditioning_free=True)
|
||||
pg = model.get_grad_norm_parameter_groups()
|
||||
|
||||
|
||||
|
@ -274,6 +275,6 @@ def extract_cheater_encoder(in_f, out_f):
|
|||
|
||||
if __name__ == '__main__':
|
||||
#test_local_attention_mask()
|
||||
extract_cheater_encoder('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater\\models\\104500_generator_ema.pth', 'X:\\dlas\\experiments\\tfd12_self_learned_cheater_enc.pth', True)
|
||||
#extract_cheater_encoder('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater\\models\\104500_generator_ema.pth', 'X:\\dlas\\experiments\\tfd12_self_learned_cheater_enc.pth', True)
|
||||
test_cheater_model()
|
||||
#extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user