Improve conditioning separation logic
commit 2b128730e7
parent 802998674e
@@ -126,10 +126,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
             num_heads=8,
             dropout=0,
             use_fp16=False,
-            segregrate_conditioning_segments=False,
             # Parameters for regularization.
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
-            conditioning_masking=0,
     ):
         super().__init__()

@@ -137,10 +135,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         self.model_channels = model_channels
         self.time_embed_dim = time_embed_dim
         self.out_channels = out_channels
-        self.segregrate_conditioning_segments = segregrate_conditioning_segments
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
-        self.conditioning_masking = conditioning_masking
         self.enable_fp16 = use_fp16

         self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
@@ -201,28 +197,33 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         if custom_conditioning_fetcher is not None:
             cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
         else:
-            if self.training and self.conditioning_masking > 0:
-                mask_prop = random.random() * self.conditioning_masking
-                mask_len = max(int(N * mask_prop), 16)
-                assert N-mask_len > 8, f"Use longer inputs or shorter conditioning_masking proportion. {N-mask_len}"
-                seg_start = random.randint(8, (N-mask_len)) + cond_start
-                # Readjust mask_len to ensure at least 8 sequence elements on the end as well
-                mask_len = min(mask_len, conditioning_input.shape[-1]-seg_start-8)
-                assert mask_len > 0, f"Error with mask_len: {conditioning_input.shape[-1], seg_start, mask_len, mask_prop, N}"
-                seg_end = seg_start+mask_len
-                conditioning_input[:,:,seg_start:seg_end] = 0
-            else:
-                seg_start = cond_start + N // 2
-                seg_end = seg_start
-            if self.segregrate_conditioning_segments:
-                cond_enc1 = self.conditioning_encoder(conditioning_input[:,:,:seg_start], time_emb)
-                cs = cond_enc1[:,:,cond_start]
-                cond_enc2 = self.conditioning_encoder(conditioning_input[:,:,seg_end:], time_emb)
-                ce = cond_enc2[:,:,(N+cond_start)-seg_end]
-            else:
-                cond_enc = self.conditioning_encoder(conditioning_input, time_emb)
-                cs = cond_enc[:,:,cond_start]
-                ce = cond_enc[:,:,N+cond_start]
+            assert conditioning_input.shape[-1] - cond_start - N > 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
+            cond_pre = conditioning_input[:,:,:cond_start]
+            cond_aligned = conditioning_input[:,:,cond_start:N+cond_start]
+            cond_post = conditioning_input[:,:,N+cond_start:]
+
+            # Break up conditioning input into two random segments aligned with the input.
+            MIN_MARGIN = 8
+            assert N > (MIN_MARGIN*2+4), f"Input size too small. Was {N} but requires at least {MIN_MARGIN*2+4}"
+            break_pt = random.randint(2, N-MIN_MARGIN*2) + MIN_MARGIN
+            cond_left = cond_aligned[:,:,:break_pt]
+            cond_right = cond_aligned[:,:,break_pt:]
+
+            # Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
+            to_remove = random.randint(0, cond_left.shape[-1]-MIN_MARGIN)
+            cond_left = cond_left[:,:,:-to_remove]
+            to_remove = random.randint(0, cond_right.shape[-1]-MIN_MARGIN)
+            cond_right = cond_right[:,:,to_remove:]
+
+            # Concatenate the _pre and _post back on.
+            cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
+            cond_right_full = torch.cat([cond_right, cond_post], dim=-1)
+
+            # Propagate through the encoder.
+            cond_left_enc = self.conditioning_encoder(cond_left_full, time_emb)
+            cs = cond_left_enc[:,:,cond_start]
+            cond_right_enc = self.conditioning_encoder(cond_right_full, time_emb)
+            ce = cond_right_enc[:,:,cond_right.shape[-1]-1]
         cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
         cond = F.interpolate(cond_enc, size=(N,), mode='linear').permute(0,2,1)
         return cond
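The new branch is ordinary tensor slicing, so it can be exercised outside the model. Below is a minimal, self-contained sketch of the split-and-trim scheme; the shapes, `N`, and `cond_start` are made up for illustration and are not taken from the file. One caveat worth noting: Python slicing with `[:-0]` yields an empty tensor, so the sketch guards the left-side trim against `to_remove == 0`, which the committed code does not.

```python
import random
import torch

MIN_MARGIN = 8
N = 100                                   # assumed length of the aligned region
cond_start = 20                           # assumed offset of the aligned region
conditioning_input = torch.randn(2, 256, 160)  # assumed (batch, channels, time)

cond_pre = conditioning_input[:, :, :cond_start]
cond_aligned = conditioning_input[:, :, cond_start:N + cond_start]
cond_post = conditioning_input[:, :, N + cond_start:]

# Split the aligned region at a random point, keeping MIN_MARGIN on each side.
break_pt = random.randint(2, N - MIN_MARGIN * 2) + MIN_MARGIN
cond_left = cond_aligned[:, :, :break_pt]
cond_right = cond_aligned[:, :, break_pt:]

# Trim a random amount from the inner edges so the network must reconstruct
# the dropped span. Guard against to_remove == 0, since x[:, :, :-0] is empty.
to_remove = random.randint(0, cond_left.shape[-1] - MIN_MARGIN)
if to_remove > 0:
    cond_left = cond_left[:, :, :-to_remove]
to_remove = random.randint(0, cond_right.shape[-1] - MIN_MARGIN)
cond_right = cond_right[:, :, to_remove:]

# Reattach the unaligned head and tail, as the commit does before encoding.
cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
cond_right_full = torch.cat([cond_right, cond_post], dim=-1)
print(cond_left_full.shape, cond_right_full.shape)
```

In the commit, `cond_left_full` and `cond_right_full` are then encoded separately, and single frames (`cs` at `cond_start`, `ce` at the last kept frame of the right segment) are interpolated across the input length to form the conditioning signal.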
@@ -286,8 +287,7 @@ def test_cheater_model():
     # For music:
     model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
                                                       contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
-                                                      unconditioned_percentage=.4, conditioning_masking=.4,
-                                                      segregrate_conditioning_segments=True)
+                                                      unconditioned_percentage=.4)
     print_network(model)
     for k in range(100):
         o = model(clip, ts, cl)
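`clip`, `ts`, and `cl` are defined earlier in `test_cheater_model()` and are not part of this hunk. A plausible reconstruction for running the loop above, with all shapes assumed rather than taken from the file, would be:

```python
import torch

# Hypothetical test inputs; shapes are assumptions, not from this diff.
clip = torch.randn(2, 256, 400)    # (batch, in_channels, sequence)
ts = torch.LongTensor([600, 600])  # diffusion timesteps, one per batch item
cl = torch.randn(2, 256, 600)      # conditioning input, longer than the clip
```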