fix conditioning free

This commit is contained in:
James Betker 2022-07-19 18:04:49 -06:00
parent fc0b291b21
commit c14bf6dfb2

View File

@ -149,7 +149,7 @@ class TransformerDiffusion(nn.Module):
def forward(self, x, timesteps, prior=None, conditioning_free=False):
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1)
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
else:
code_emb = self.input_converter(prior)
@ -260,6 +260,7 @@ def test_cheater_model():
print_network(model)
o = model(clip, ts, clip)
o = model(clip, ts, clip, conditioning_free=True)
pg = model.get_grad_norm_parameter_groups()
@ -274,6 +275,6 @@ def extract_cheater_encoder(in_f, out_f):
if __name__ == '__main__':
#test_local_attention_mask()
extract_cheater_encoder('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater\\models\\104500_generator_ema.pth', 'X:\\dlas\\experiments\\tfd12_self_learned_cheater_enc.pth', True)
#extract_cheater_encoder('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater\\models\\104500_generator_ema.pth', 'X:\\dlas\\experiments\\tfd12_self_learned_cheater_enc.pth', True)
test_cheater_model()
#extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)