forked from mrq/DL-Art-School
Fix bug in conditioning segment fetching
This commit is contained in:
parent
e5859acff7
commit
455943779b
|
@ -203,9 +203,12 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
|||
else:
|
||||
if self.training and self.conditioning_masking > 0:
|
||||
mask_prop = random.random() * self.conditioning_masking
|
||||
mask_len = min(int(N * mask_prop), 4)
|
||||
mask_len = max(int(N * mask_prop), 4)
|
||||
assert N-mask_len > 8, f"Use longer inputs or shorter conditioning_masking proportion. {N-mask_len}"
|
||||
seg_start = random.randint(8, (N-mask_len)) + cond_start
|
||||
# Readjust mask_len to ensure at least 8 sequence elements on the end as well
|
||||
mask_len = min(mask_len, conditioning_input.shape[-1]-seg_start-8)
|
||||
assert mask_len > 0, f"Error with mask_len: {conditioning_input.shape[-1], seg_start, mask_len, mask_prop, N}"
|
||||
seg_end = seg_start+mask_len
|
||||
conditioning_input[:,:,seg_start:seg_end] = 0
|
||||
else:
|
||||
|
@ -276,14 +279,14 @@ def register_tfdpc5(opt_net, opt):
|
|||
|
||||
|
||||
def test_cheater_model():
|
||||
clip = torch.randn(2, 256, 200)
|
||||
cl = torch.randn(2, 256, 500)
|
||||
clip = torch.randn(2, 256, 350)
|
||||
cl = torch.randn(2, 256, 646)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
|
||||
# For music:
|
||||
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
|
||||
contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
|
||||
unconditioned_percentage=.4, conditioning_masking=.5,
|
||||
unconditioned_percentage=.4, conditioning_masking=.4,
|
||||
segregrate_conditioning_segments=True)
|
||||
print_network(model)
|
||||
for k in range(100):
|
||||
|
|
Loading…
Reference in New Issue
Block a user