diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py index 38c0a09b..8d327baa 100644 --- a/codes/models/audio/music/tfdpc_v5.py +++ b/codes/models/audio/music/tfdpc_v5.py @@ -214,7 +214,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module): cond_left = conditioning_input[:,:,:max(cond_start, 20)] left_pt = cond_start-1 cond_right = conditioning_input[:,:,min(N+cond_start, conditioning_input.shape[-1]-20):] - right_pt = cond_right.shape[-1] - (conditioning_input.shape[-1] - (N+cond_start)) + right_pt = min(cond_right.shape[-1]-1, cond_right.shape[-1] - (conditioning_input.shape[-1] - (N+cond_start))) elif cond_left is None: assert conditioning_input.shape[-1] - cond_start - N >= 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}' cond_pre = conditioning_input[:,:,:cond_start] @@ -313,12 +313,12 @@ def test_cheater_model(): # For music: model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024, - contraction_dim=512, num_heads=8, num_layers=15, dropout=0, + contraction_dim=512, num_heads=8, num_layers=40, dropout=0, unconditioned_percentage=.4, checkpoint_conditioning=False, regularization=True, new_cond=True) print_network(model) - for cs in range(276,cl.shape[-1]-clip.shape[-1]): - o = model(clip, ts, cl, cond_start=cs) + #for cs in range(276,cl.shape[-1]-clip.shape[-1]): + # o = model(clip, ts, cl, cond_start=cs) pg = model.get_grad_norm_parameter_groups() def prmsz(lp): sz = 0