forked from mrq/DL-Art-School
uhh2.0
This commit is contained in:
parent
ebfe72d502
commit
e23c322089
|
@ -214,7 +214,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
||||||
cond_left = conditioning_input[:,:,:max(cond_start, 20)]
|
cond_left = conditioning_input[:,:,:max(cond_start, 20)]
|
||||||
left_pt = cond_start-1
|
left_pt = cond_start-1
|
||||||
cond_right = conditioning_input[:,:,min(N+cond_start, conditioning_input.shape[-1]-20):]
|
cond_right = conditioning_input[:,:,min(N+cond_start, conditioning_input.shape[-1]-20):]
|
||||||
right_pt = cond_right.shape[-1] - (conditioning_input.shape[-1] - (N+cond_start))
|
right_pt = min(cond_right.shape[-1]-1, cond_right.shape[-1] - (conditioning_input.shape[-1] - (N+cond_start)))
|
||||||
elif cond_left is None:
|
elif cond_left is None:
|
||||||
assert conditioning_input.shape[-1] - cond_start - N >= 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
|
assert conditioning_input.shape[-1] - cond_start - N >= 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
|
||||||
cond_pre = conditioning_input[:,:,:cond_start]
|
cond_pre = conditioning_input[:,:,:cond_start]
|
||||||
|
@ -313,12 +313,12 @@ def test_cheater_model():
|
||||||
|
|
||||||
# For music:
|
# For music:
|
||||||
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
|
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
|
||||||
contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
|
contraction_dim=512, num_heads=8, num_layers=40, dropout=0,
|
||||||
unconditioned_percentage=.4, checkpoint_conditioning=False,
|
unconditioned_percentage=.4, checkpoint_conditioning=False,
|
||||||
regularization=True, new_cond=True)
|
regularization=True, new_cond=True)
|
||||||
print_network(model)
|
print_network(model)
|
||||||
for cs in range(276,cl.shape[-1]-clip.shape[-1]):
|
#for cs in range(276,cl.shape[-1]-clip.shape[-1]):
|
||||||
o = model(clip, ts, cl, cond_start=cs)
|
# o = model(clip, ts, cl, cond_start=cs)
|
||||||
pg = model.get_grad_norm_parameter_groups()
|
pg = model.get_grad_norm_parameter_groups()
|
||||||
def prmsz(lp):
|
def prmsz(lp):
|
||||||
sz = 0
|
sz = 0
|
||||||
|
|
Loading…
Reference in New Issue
Block a user