Fix bug in conditioning segment fetching

This commit is contained in:
James Betker 2022-07-04 08:16:14 -06:00
parent e5859acff7
commit 455943779b

View File

@ -203,9 +203,12 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
else:
if self.training and self.conditioning_masking > 0:
mask_prop = random.random() * self.conditioning_masking
mask_len = min(int(N * mask_prop), 4)
mask_len = max(int(N * mask_prop), 4)
assert N-mask_len > 8, f"Use longer inputs or shorter conditioning_masking proportion. {N-mask_len}"
seg_start = random.randint(8, (N-mask_len)) + cond_start
# Readjust mask_len to ensure at least 8 sequence elements on the end as well
mask_len = min(mask_len, conditioning_input.shape[-1]-seg_start-8)
assert mask_len > 0, f"Error with mask_len: {conditioning_input.shape[-1], seg_start, mask_len, mask_prop, N}"
seg_end = seg_start+mask_len
conditioning_input[:,:,seg_start:seg_end] = 0
else:
@ -276,14 +279,14 @@ def register_tfdpc5(opt_net, opt):
def test_cheater_model():
clip = torch.randn(2, 256, 200)
cl = torch.randn(2, 256, 500)
clip = torch.randn(2, 256, 350)
cl = torch.randn(2, 256, 646)
ts = torch.LongTensor([600, 600])
# For music:
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
unconditioned_percentage=.4, conditioning_masking=.5,
unconditioned_percentage=.4, conditioning_masking=.4,
segregrate_conditioning_segments=True)
print_network(model)
for k in range(100):