This commit is contained in:
James Betker 2022-07-12 22:48:46 -06:00
parent ebfe72d502
commit e23c322089

View File

@ -214,7 +214,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
cond_left = conditioning_input[:,:,:max(cond_start, 20)]
left_pt = cond_start-1
cond_right = conditioning_input[:,:,min(N+cond_start, conditioning_input.shape[-1]-20):]
right_pt = cond_right.shape[-1] - (conditioning_input.shape[-1] - (N+cond_start))
right_pt = min(cond_right.shape[-1]-1, cond_right.shape[-1] - (conditioning_input.shape[-1] - (N+cond_start)))
elif cond_left is None:
assert conditioning_input.shape[-1] - cond_start - N >= 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
cond_pre = conditioning_input[:,:,:cond_start]
@ -313,12 +313,12 @@ def test_cheater_model():
# For music:
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
contraction_dim=512, num_heads=8, num_layers=40, dropout=0,
unconditioned_percentage=.4, checkpoint_conditioning=False,
regularization=True, new_cond=True)
print_network(model)
for cs in range(276,cl.shape[-1]-clip.shape[-1]):
o = model(clip, ts, cl, cond_start=cs)
#for cs in range(276,cl.shape[-1]-clip.shape[-1]):
# o = model(clip, ts, cl, cond_start=cs)
pg = model.get_grad_norm_parameter_groups()
def prmsz(lp):
sz = 0