uhh2.0
This commit is contained in:
parent
ebfe72d502
commit
e23c322089
|
@ -214,7 +214,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
|||
cond_left = conditioning_input[:,:,:max(cond_start, 20)]
|
||||
left_pt = cond_start-1
|
||||
cond_right = conditioning_input[:,:,min(N+cond_start, conditioning_input.shape[-1]-20):]
|
||||
right_pt = cond_right.shape[-1] - (conditioning_input.shape[-1] - (N+cond_start))
|
||||
right_pt = min(cond_right.shape[-1]-1, cond_right.shape[-1] - (conditioning_input.shape[-1] - (N+cond_start)))
|
||||
elif cond_left is None:
|
||||
assert conditioning_input.shape[-1] - cond_start - N >= 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
|
||||
cond_pre = conditioning_input[:,:,:cond_start]
|
||||
|
@ -313,12 +313,12 @@ def test_cheater_model():
|
|||
|
||||
# For music:
|
||||
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
|
||||
contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
|
||||
contraction_dim=512, num_heads=8, num_layers=40, dropout=0,
|
||||
unconditioned_percentage=.4, checkpoint_conditioning=False,
|
||||
regularization=True, new_cond=True)
|
||||
print_network(model)
|
||||
for cs in range(276,cl.shape[-1]-clip.shape[-1]):
|
||||
o = model(clip, ts, cl, cond_start=cs)
|
||||
#for cs in range(276,cl.shape[-1]-clip.shape[-1]):
|
||||
# o = model(clip, ts, cl, cond_start=cs)
|
||||
pg = model.get_grad_norm_parameter_groups()
|
||||
def prmsz(lp):
|
||||
sz = 0
|
||||
|
|
Loading…
Reference in New Issue
Block a user