forked from mrq/DL-Art-School
fix conditioning free
This commit is contained in:
parent
fc0b291b21
commit
c14bf6dfb2
|
@ -149,7 +149,7 @@ class TransformerDiffusion(nn.Module):
|
||||||
|
|
||||||
def forward(self, x, timesteps, prior=None, conditioning_free=False):
|
def forward(self, x, timesteps, prior=None, conditioning_free=False):
|
||||||
if conditioning_free:
|
if conditioning_free:
|
||||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1)
|
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
||||||
else:
|
else:
|
||||||
code_emb = self.input_converter(prior)
|
code_emb = self.input_converter(prior)
|
||||||
|
|
||||||
|
@ -260,6 +260,7 @@ def test_cheater_model():
|
||||||
|
|
||||||
print_network(model)
|
print_network(model)
|
||||||
o = model(clip, ts, clip)
|
o = model(clip, ts, clip)
|
||||||
|
o = model(clip, ts, clip, conditioning_free=True)
|
||||||
pg = model.get_grad_norm_parameter_groups()
|
pg = model.get_grad_norm_parameter_groups()
|
||||||
|
|
||||||
|
|
||||||
|
@ -274,6 +275,6 @@ def extract_cheater_encoder(in_f, out_f):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
#test_local_attention_mask()
|
#test_local_attention_mask()
|
||||||
extract_cheater_encoder('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater\\models\\104500_generator_ema.pth', 'X:\\dlas\\experiments\\tfd12_self_learned_cheater_enc.pth', True)
|
#extract_cheater_encoder('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater\\models\\104500_generator_ema.pth', 'X:\\dlas\\experiments\\tfd12_self_learned_cheater_enc.pth', True)
|
||||||
test_cheater_model()
|
test_cheater_model()
|
||||||
#extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)
|
#extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user