Improve conditioning separation logic

This commit is contained in:
James Betker 2022-07-05 10:30:28 -06:00
parent 802998674e
commit 2b128730e7

View File

@ -126,10 +126,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
num_heads=8,
dropout=0,
use_fp16=False,
segregrate_conditioning_segments=False,
# Parameters for regularization.
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
conditioning_masking=0,
):
super().__init__()
@ -137,10 +135,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
self.model_channels = model_channels
self.time_embed_dim = time_embed_dim
self.out_channels = out_channels
self.segregrate_conditioning_segments = segregrate_conditioning_segments
self.dropout = dropout
self.unconditioned_percentage = unconditioned_percentage
self.conditioning_masking = conditioning_masking
self.enable_fp16 = use_fp16
self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
@ -201,28 +197,33 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
if custom_conditioning_fetcher is not None:
cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
else:
if self.training and self.conditioning_masking > 0:
mask_prop = random.random() * self.conditioning_masking
mask_len = max(int(N * mask_prop), 16)
assert N-mask_len > 8, f"Use longer inputs or shorter conditioning_masking proportion. {N-mask_len}"
seg_start = random.randint(8, (N-mask_len)) + cond_start
# Readjust mask_len to ensure at least 8 sequence elements on the end as well
mask_len = min(mask_len, conditioning_input.shape[-1]-seg_start-8)
assert mask_len > 0, f"Error with mask_len: {conditioning_input.shape[-1], seg_start, mask_len, mask_prop, N}"
seg_end = seg_start+mask_len
conditioning_input[:,:,seg_start:seg_end] = 0
else:
seg_start = cond_start + N // 2
seg_end = seg_start
if self.segregrate_conditioning_segments:
cond_enc1 = self.conditioning_encoder(conditioning_input[:,:,:seg_start], time_emb)
cs = cond_enc1[:,:,cond_start]
cond_enc2 = self.conditioning_encoder(conditioning_input[:,:,seg_end:], time_emb)
ce = cond_enc2[:,:,(N+cond_start)-seg_end]
else:
cond_enc = self.conditioning_encoder(conditioning_input, time_emb)
cs = cond_enc[:,:,cond_start]
ce = cond_enc[:,:,N+cond_start]
assert conditioning_input.shape[-1] - cond_start - N > 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
cond_pre = conditioning_input[:,:,:cond_start]
cond_aligned = conditioning_input[:,:,cond_start:N+cond_start]
cond_post = conditioning_input[:,:,N+cond_start:]
# Break up conditioning input into two random segments aligned with the input.
MIN_MARGIN = 8
assert N > (MIN_MARGIN*2+4), f"Input size too small. Was {N} but requires at least {MIN_MARGIN*2+4}"
break_pt = random.randint(2, N-MIN_MARGIN*2) + MIN_MARGIN
cond_left = cond_aligned[:,:,:break_pt]
cond_right = cond_aligned[:,:,break_pt:]
# Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
to_remove = random.randint(0, cond_left.shape[-1]-MIN_MARGIN)
cond_left = cond_left[:,:,:-to_remove]
to_remove = random.randint(0, cond_right.shape[-1]-MIN_MARGIN)
cond_right = cond_right[:,:,to_remove:]
# Concatenate the _pre and _post back on.
cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
cond_right_full = torch.cat([cond_right, cond_post], dim=-1)
# Propagate through the encoder.
cond_left_enc = self.conditioning_encoder(cond_left_full, time_emb)
cs = cond_left_enc[:,:,cond_start]
cond_right_enc = self.conditioning_encoder(cond_right_full, time_emb)
ce = cond_right_enc[:,:,cond_right.shape[-1]-1]
cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
cond = F.interpolate(cond_enc, size=(N,), mode='linear').permute(0,2,1)
return cond
@ -286,8 +287,7 @@ def test_cheater_model():
# For music:
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
unconditioned_percentage=.4, conditioning_masking=.4,
segregrate_conditioning_segments=True)
unconditioned_percentage=.4)
print_network(model)
for k in range(100):
o = model(clip, ts, cl)