Rework tfdpc_v5 further..
This commit is contained in:
parent
47f04ff5c2
commit
e5859acff7
|
@ -126,6 +126,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
|||
num_heads=8,
|
||||
dropout=0,
|
||||
use_fp16=False,
|
||||
segregrate_conditioning_segments=False,
|
||||
# Parameters for regularization.
|
||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||
conditioning_masking=0,
|
||||
|
@ -136,6 +137,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
|||
self.model_channels = model_channels
|
||||
self.time_embed_dim = time_embed_dim
|
||||
self.out_channels = out_channels
|
||||
self.segregrate_conditioning_segments = segregrate_conditioning_segments
|
||||
self.dropout = dropout
|
||||
self.unconditioned_percentage = unconditioned_percentage
|
||||
self.conditioning_masking = conditioning_masking
|
||||
|
@ -195,6 +197,33 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
|||
}
|
||||
return groups
|
||||
|
||||
def process_conditioning(self, conditioning_input, time_emb, N, cond_start, custom_conditioning_fetcher):
    """Build the conditioning tensor that spans an input of length ``N``.

    Two conditioning vectors are produced -- ``cs`` (taken at position
    ``cond_start``) and ``ce`` (taken at position ``N + cond_start``) -- and
    then linearly interpolated across ``N`` positions.

    Args:
        conditioning_input: Tensor indexed here as ``[:, :, position]``.
            It is NOT modified by this method; masking is applied to a copy.
        time_emb: Passed through unchanged to ``self.conditioning_encoder``
            (or to ``custom_conditioning_fetcher``).
        N: Number of positions in the output; also the span between the
            start and end conditioning points.
        cond_start: Offset into ``conditioning_input`` where the span begins.
        custom_conditioning_fetcher: Optional callable invoked as
            ``fetcher(self.conditioning_encoder, time_emb)`` that returns
            ``(cs, ce)`` directly. When given, ``conditioning_input`` is
            ignored and no masking is performed.

    Returns:
        Tensor produced by interpolating ``[cs, ce]`` to length ``N`` and
        permuting with ``(0, 2, 1)``, i.e. position becomes dimension 1.

    Raises:
        AssertionError: If masking is active and ``N - mask_len <= 8``.
    """
    if custom_conditioning_fetcher is not None:
        cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
    else:
        if self.training and self.conditioning_masking > 0:
            mask_prop = random.random() * self.conditioning_masking
            # NOTE(review): `min` caps the mask at 4 positions regardless of
            # conditioning_masking; the assert message below hints that `max`
            # may have been intended -- confirm before changing.
            mask_len = min(int(N * mask_prop), 4)
            assert N-mask_len > 8, f"Use longer inputs or shorter conditioning_masking proportion. {N-mask_len}"
            # random.randint is inclusive on both ends, so the masked segment
            # can end exactly at N + cond_start but never covers that index.
            seg_start = random.randint(8, (N-mask_len)) + cond_start
            seg_end = seg_start+mask_len
            # Mask a copy: zeroing the caller's tensor in place would make the
            # masked regions accumulate when the same tensor is reused across
            # forward passes.
            conditioning_input = conditioning_input.clone()
            conditioning_input[:,:,seg_start:seg_end] = 0
        else:
            # No masking: split at the midpoint of the span, empty gap.
            seg_start = cond_start + N // 2
            seg_end = seg_start

        if self.segregrate_conditioning_segments:
            # Encode the parts before and after the (possibly empty) masked
            # gap independently of each other.
            cond_enc1 = self.conditioning_encoder(conditioning_input[:,:,:seg_start], time_emb)
            cs = cond_enc1[:,:,cond_start]
            cond_enc2 = self.conditioning_encoder(conditioning_input[:,:,seg_end:], time_emb)
            # The second segment starts at seg_end, so re-base the end index.
            ce = cond_enc2[:,:,(N+cond_start)-seg_end]
        else:
            cond_enc = self.conditioning_encoder(conditioning_input, time_emb)
            cs = cond_enc[:,:,cond_start]
            ce = cond_enc[:,:,N+cond_start]

    # Stack the two endpoint vectors on a new last axis and stretch them
    # linearly to N positions.
    cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
    cond = F.interpolate(cond_enc, size=(N,), mode='linear').permute(0,2,1)
    return cond
|
||||
|
||||
def forward(self, x, timesteps, conditioning_input=None, conditioning_free=False, cond_start=0, custom_conditioning_fetcher=None):
|
||||
unused_params = []
|
||||
|
||||
|
@ -204,21 +233,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
|||
cond = self.unconditioned_embedding
|
||||
cond = cond.repeat(1,x.shape[-1],1)
|
||||
else:
|
||||
if custom_conditioning_fetcher is not None:
|
||||
cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
|
||||
else:
|
||||
if self.training and self.conditioning_masking > 0:
|
||||
cond_op_len = x.shape[-1]
|
||||
mask_prop = random.random() * self.conditioning_masking
|
||||
mask_len = int(cond_op_len * mask_prop)
|
||||
if mask_len > 0:
|
||||
start = random.randint(0, (cond_op_len-mask_len)) + cond_start
|
||||
conditioning_input[:,:,start:(start+mask_len)] = 0
|
||||
cond_enc = self.conditioning_encoder(conditioning_input, time_emb)
|
||||
cs = cond_enc[:,:,cond_start]
|
||||
ce = cond_enc[:,:,x.shape[-1]+cond_start]
|
||||
cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
|
||||
cond = F.interpolate(cond_enc, size=(x.shape[-1],), mode='linear').permute(0,2,1)
|
||||
cond = self.process_conditioning(conditioning_input, time_emb, x.shape[-1], cond_start, custom_conditioning_fetcher)
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((cond.shape[0], 1, 1),
|
||||
|
@ -261,16 +276,18 @@ def register_tfdpc5(opt_net, opt):
|
|||
|
||||
|
||||
def test_cheater_model():
|
||||
clip = torch.randn(2, 256, 400)
|
||||
cl = torch.randn(2, 256, 400)
|
||||
clip = torch.randn(2, 256, 200)
|
||||
cl = torch.randn(2, 256, 500)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
|
||||
# For music:
|
||||
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
|
||||
contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
|
||||
unconditioned_percentage=.4, conditioning_masking=.5)
|
||||
unconditioned_percentage=.4, conditioning_masking=.5,
|
||||
segregrate_conditioning_segments=True)
|
||||
print_network(model)
|
||||
o = model(clip, ts, cl)
|
||||
for k in range(100):
|
||||
o = model(clip, ts, cl)
|
||||
pg = model.get_grad_norm_parameter_groups()
|
||||
def prmsz(lp):
|
||||
sz = 0
|
||||
|
|
Loading…
Reference in New Issue
Block a user