Improve conditioning separation logic

This commit is contained in:
James Betker 2022-07-05 10:30:28 -06:00
parent 802998674e
commit 2b128730e7

View File

@ -126,10 +126,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
num_heads=8, num_heads=8,
dropout=0, dropout=0,
use_fp16=False, use_fp16=False,
segregrate_conditioning_segments=False,
# Parameters for regularization. # Parameters for regularization.
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
conditioning_masking=0,
): ):
super().__init__() super().__init__()
@ -137,10 +135,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
self.model_channels = model_channels self.model_channels = model_channels
self.time_embed_dim = time_embed_dim self.time_embed_dim = time_embed_dim
self.out_channels = out_channels self.out_channels = out_channels
self.segregrate_conditioning_segments = segregrate_conditioning_segments
self.dropout = dropout self.dropout = dropout
self.unconditioned_percentage = unconditioned_percentage self.unconditioned_percentage = unconditioned_percentage
self.conditioning_masking = conditioning_masking
self.enable_fp16 = use_fp16 self.enable_fp16 = use_fp16
self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1) self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
@ -201,28 +197,33 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
if custom_conditioning_fetcher is not None: if custom_conditioning_fetcher is not None:
cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb) cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
else: else:
if self.training and self.conditioning_masking > 0: assert conditioning_input.shape[-1] - cond_start - N > 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
mask_prop = random.random() * self.conditioning_masking cond_pre = conditioning_input[:,:,:cond_start]
mask_len = max(int(N * mask_prop), 16) cond_aligned = conditioning_input[:,:,cond_start:N+cond_start]
assert N-mask_len > 8, f"Use longer inputs or shorter conditioning_masking proportion. {N-mask_len}" cond_post = conditioning_input[:,:,N+cond_start:]
seg_start = random.randint(8, (N-mask_len)) + cond_start
# Readjust mask_len to ensure at least 8 sequence elements on the end as well # Break up conditioning input into two random segments aligned with the input.
mask_len = min(mask_len, conditioning_input.shape[-1]-seg_start-8) MIN_MARGIN = 8
assert mask_len > 0, f"Error with mask_len: {conditioning_input.shape[-1], seg_start, mask_len, mask_prop, N}" assert N > (MIN_MARGIN*2+4), f"Input size too small. Was {N} but requires at least {MIN_MARGIN*2+4}"
seg_end = seg_start+mask_len break_pt = random.randint(2, N-MIN_MARGIN*2) + MIN_MARGIN
conditioning_input[:,:,seg_start:seg_end] = 0 cond_left = cond_aligned[:,:,:break_pt]
else: cond_right = cond_aligned[:,:,break_pt:]
seg_start = cond_start + N // 2
seg_end = seg_start # Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
if self.segregrate_conditioning_segments: to_remove = random.randint(0, cond_left.shape[-1]-MIN_MARGIN)
cond_enc1 = self.conditioning_encoder(conditioning_input[:,:,:seg_start], time_emb) cond_left = cond_left[:,:,:-to_remove]
cs = cond_enc1[:,:,cond_start] to_remove = random.randint(0, cond_right.shape[-1]-MIN_MARGIN)
cond_enc2 = self.conditioning_encoder(conditioning_input[:,:,seg_end:], time_emb) cond_right = cond_right[:,:,to_remove:]
ce = cond_enc2[:,:,(N+cond_start)-seg_end]
else: # Concatenate the _pre and _post back on.
cond_enc = self.conditioning_encoder(conditioning_input, time_emb) cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
cs = cond_enc[:,:,cond_start] cond_right_full = torch.cat([cond_right, cond_post], dim=-1)
ce = cond_enc[:,:,N+cond_start]
# Propagate through the encoder.
cond_left_enc = self.conditioning_encoder(cond_left_full, time_emb)
cs = cond_left_enc[:,:,cond_start]
cond_right_enc = self.conditioning_encoder(cond_right_full, time_emb)
ce = cond_right_enc[:,:,cond_right.shape[-1]-1]
cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1) cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
cond = F.interpolate(cond_enc, size=(N,), mode='linear').permute(0,2,1) cond = F.interpolate(cond_enc, size=(N,), mode='linear').permute(0,2,1)
return cond return cond
@ -286,8 +287,7 @@ def test_cheater_model():
# For music: # For music:
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024, model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
contraction_dim=512, num_heads=8, num_layers=15, dropout=0, contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
unconditioned_percentage=.4, conditioning_masking=.4, unconditioned_percentage=.4)
segregrate_conditioning_segments=True)
print_network(model) print_network(model)
for k in range(100): for k in range(100):
o = model(clip, ts, cl) o = model(clip, ts, cl)