Improve conditioning separation logic

Replaces the conditioning_masking / segregrate_conditioning_segments options with a single always-on scheme: the aligned region of the conditioning input is split at a random break point into left and right segments, a random span is dropped from the inner edge of each, and each half is encoded separately to produce the start (cs) and end (ce) conditioning points.
parent 802998674e
commit 2b128730e7
@@ -126,10 +126,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
                  num_heads=8,
                  dropout=0,
                  use_fp16=False,
-                 segregrate_conditioning_segments=False,
                  # Parameters for regularization.
                  unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
-                 conditioning_masking=0,
                  ):
         super().__init__()
 
@@ -137,10 +135,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         self.model_channels = model_channels
         self.time_embed_dim = time_embed_dim
         self.out_channels = out_channels
-        self.segregrate_conditioning_segments = segregrate_conditioning_segments
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
-        self.conditioning_masking = conditioning_masking
         self.enable_fp16 = use_fp16
 
         self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
@@ -201,28 +197,33 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         if custom_conditioning_fetcher is not None:
             cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
         else:
-            if self.training and self.conditioning_masking > 0:
-                mask_prop = random.random() * self.conditioning_masking
-                mask_len = max(int(N * mask_prop), 16)
-                assert N-mask_len > 8, f"Use longer inputs or shorter conditioning_masking proportion. {N-mask_len}"
-                seg_start = random.randint(8, (N-mask_len)) + cond_start
-                # Readjust mask_len to ensure at least 8 sequence elements on the end as well
-                mask_len = min(mask_len, conditioning_input.shape[-1]-seg_start-8)
-                assert mask_len > 0, f"Error with mask_len: {conditioning_input.shape[-1], seg_start, mask_len, mask_prop, N}"
-                seg_end = seg_start+mask_len
-                conditioning_input[:,:,seg_start:seg_end] = 0
-            else:
-                seg_start = cond_start + N // 2
-                seg_end = seg_start
-            if self.segregrate_conditioning_segments:
-                cond_enc1 = self.conditioning_encoder(conditioning_input[:,:,:seg_start], time_emb)
-                cs = cond_enc1[:,:,cond_start]
-                cond_enc2 = self.conditioning_encoder(conditioning_input[:,:,seg_end:], time_emb)
-                ce = cond_enc2[:,:,(N+cond_start)-seg_end]
-            else:
-                cond_enc = self.conditioning_encoder(conditioning_input, time_emb)
-                cs = cond_enc[:,:,cond_start]
-                ce = cond_enc[:,:,N+cond_start]
+            assert conditioning_input.shape[-1] - cond_start - N > 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
+            cond_pre = conditioning_input[:,:,:cond_start]
+            cond_aligned = conditioning_input[:,:,cond_start:N+cond_start]
+            cond_post = conditioning_input[:,:,N+cond_start:]
+
+            # Break up conditioning input into two random segments aligned with the input.
+            MIN_MARGIN = 8
+            assert N > (MIN_MARGIN*2+4), f"Input size too small. Was {N} but requires at least {MIN_MARGIN*2+4}"
+            break_pt = random.randint(2, N-MIN_MARGIN*2) + MIN_MARGIN
+            cond_left = cond_aligned[:,:,:break_pt]
+            cond_right = cond_aligned[:,:,break_pt:]
+
+            # Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
+            to_remove = random.randint(0, cond_left.shape[-1]-MIN_MARGIN)
+            cond_left = cond_left[:,:,:-to_remove]
+            to_remove = random.randint(0, cond_right.shape[-1]-MIN_MARGIN)
+            cond_right = cond_right[:,:,to_remove:]
+
+            # Concatenate the _pre and _post back on.
+            cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
+            cond_right_full = torch.cat([cond_right, cond_post], dim=-1)
+
+            # Propagate through the encoder.
+            cond_left_enc = self.conditioning_encoder(cond_left_full, time_emb)
+            cs = cond_left_enc[:,:,cond_start]
+            cond_right_enc = self.conditioning_encoder(cond_right_full, time_emb)
+            ce = cond_right_enc[:,:,cond_right.shape[-1]-1]
         cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
         cond = F.interpolate(cond_enc, size=(N,), mode='linear').permute(0,2,1)
         return cond
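For readers following the new logic, below is a minimal standalone sketch of the split-and-drop scheme; the `split_conditioning` helper name and the dummy tensor shapes are ours for illustration, while the real code runs inline inside the model and feeds each half through `self.conditioning_encoder`. One edge case worth noting: in the diff, when `to_remove` is drawn as 0 the slice `cond_left[:,:,:-to_remove]` evaluates to `[:,:,:0]` and produces an empty tensor, so the sketch guards that branch.

```python
import random

import torch

MIN_MARGIN = 8  # same margin constant the diff introduces

def split_conditioning(conditioning_input, cond_start, N):
    """Split a (B, C, T) conditioning tensor into two randomly trimmed halves.

    Hypothetical standalone helper mirroring the new logic above.
    """
    # There must be conditioning data both before and after the aligned region.
    assert conditioning_input.shape[-1] - cond_start - N > 0
    cond_pre = conditioning_input[:, :, :cond_start]
    cond_aligned = conditioning_input[:, :, cond_start:N + cond_start]
    cond_post = conditioning_input[:, :, N + cond_start:]

    # Pick a break point, keeping at least MIN_MARGIN elements on either side.
    assert N > MIN_MARGIN * 2 + 4, f"Input size too small: {N}"
    break_pt = random.randint(2, N - MIN_MARGIN * 2) + MIN_MARGIN
    cond_left = cond_aligned[:, :, :break_pt]
    cond_right = cond_aligned[:, :, break_pt:]

    # Drop a random span from the inner edge of each half; the network must
    # learn to reconstruct the missing middle.
    to_remove = random.randint(0, cond_left.shape[-1] - MIN_MARGIN)
    if to_remove > 0:  # guard: [:, :, :-0] would yield an empty tensor
        cond_left = cond_left[:, :, :-to_remove]
    to_remove = random.randint(0, cond_right.shape[-1] - MIN_MARGIN)
    cond_right = cond_right[:, :, to_remove:]

    # Reattach the unaligned context on the outer edges.
    cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
    cond_right_full = torch.cat([cond_right, cond_post], dim=-1)
    return cond_left_full, cond_right_full, cond_right.shape[-1]

left, right, right_len = split_conditioning(torch.randn(2, 256, 512), cond_start=100, N=300)
print(left.shape, right.shape, right_len)
```

The index arithmetic in the diff follows from the concatenation order: `cond_pre` precedes the left half, so `cond_left_enc[:,:,cond_start]` reads the first aligned frame (`cs`), and `cond_post` follows the right half, so `cond_right_enc[:,:,cond_right.shape[-1]-1]` reads the last aligned frame (`ce`).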
@@ -286,8 +287,7 @@ def test_cheater_model():
     # For music:
     model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
                                                       contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
-                                                      unconditioned_percentage=.4, conditioning_masking=.4,
-                                                      segregrate_conditioning_segments=True)
+                                                      unconditioned_percentage=.4)
     print_network(model)
     for k in range(100):
         o = model(clip, ts, cl)
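Downstream of this change (the unchanged context at the end of the third hunk), the two point embeddings are stretched into a per-position conditioning signal by linear interpolation. A tiny runnable illustration of that step; the batch, channel, and sequence sizes here are arbitrary stand-ins, not values taken from the repo:

```python
import torch
import torch.nn.functional as F

B, C, N = 2, 1024, 200          # arbitrary example sizes
cs = torch.randn(B, C)          # embedding taken at the start point
ce = torch.randn(B, C)          # embedding taken at the end point

# (B, C, 2): the two endpoints laid out along the sequence axis.
cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)

# Linearly blend from cs to ce across all N target positions.
cond = F.interpolate(cond_enc, size=(N,), mode='linear').permute(0, 2, 1)
print(cond.shape)  # torch.Size([2, 200, 1024])
```

This is why the encoders only need to yield two points: every intermediate position receives a deterministic blend of `cs` and `ce`.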