Rework tfdpc_v5 further.

This commit is contained in:
James Betker 2022-07-03 18:19:01 -06:00
parent 47f04ff5c2
commit e5859acff7

View File

@ -126,6 +126,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
num_heads=8,
dropout=0,
use_fp16=False,
segregrate_conditioning_segments=False,
# Parameters for regularization.
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
conditioning_masking=0,
@ -136,6 +137,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
self.model_channels = model_channels
self.time_embed_dim = time_embed_dim
self.out_channels = out_channels
self.segregrate_conditioning_segments = segregrate_conditioning_segments
self.dropout = dropout
self.unconditioned_percentage = unconditioned_percentage
self.conditioning_masking = conditioning_masking
@ -195,6 +197,33 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
}
return groups
def process_conditioning(self, conditioning_input, time_emb, N, cond_start, custom_conditioning_fetcher):
    """Build a per-frame conditioning signal for an N-frame generation window.

    Two anchor embeddings are taken from the encoded conditioning input:
    ``cs`` at frame ``cond_start`` (window start) and ``ce`` at frame
    ``N + cond_start`` (window end). The two anchors are then linearly
    interpolated across N steps, yielding one conditioning vector per frame.

    Args (shapes assumed from indexing below -- confirm against callers):
        conditioning_input: tensor indexed as (batch, channels, time); may be
            mutated in place (masking) during training.
        time_emb: time embedding passed through to the conditioning encoder.
        N: length of the generation window, in frames.
        cond_start: frame offset of the window inside conditioning_input.
        custom_conditioning_fetcher: optional callable; when provided it is
            given (self.conditioning_encoder, time_emb) and must return the
            (cs, ce) anchor pair directly, bypassing all the logic below.

    Returns:
        Conditioning tensor, permuted to (batch, N, channels) assuming the
        encoder emits (batch, channels, time).
    """
    if custom_conditioning_fetcher is not None:
        cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
    else:
        if self.training and self.conditioning_masking > 0:
            # Training-time regularizer: zero a random span inside the window.
            # Note: draws from the global `random` state twice (random.random,
            # then random.randint); call order matters for reproducibility.
            mask_prop = random.random() * self.conditioning_masking
            # NOTE(review): min() caps the masked span at 4 frames, which makes
            # the assert below nearly vacuous; the assert's error message
            # ("Use longer inputs or shorter conditioning_masking proportion")
            # suggests max(int(N * mask_prop), 4) may have been intended --
            # confirm intent before changing.
            mask_len = min(int(N * mask_prop), 4)
            assert N-mask_len > 8, f"Use longer inputs or shorter conditioning_masking proportion. {N-mask_len}"
            seg_start = random.randint(8, (N-mask_len)) + cond_start
            seg_end = seg_start+mask_len
            # In-place mutation: the caller's conditioning_input is modified.
            conditioning_input[:,:,seg_start:seg_end] = 0
        else:
            # Eval / no masking: split point at the middle of the window with
            # an empty masked span (seg_start == seg_end).
            seg_start = cond_start + N // 2
            seg_end = seg_start
        if self.segregrate_conditioning_segments:
            # Encode the pre-span and post-span segments independently
            # (presumably to prevent information flow across the masked span).
            cond_enc1 = self.conditioning_encoder(conditioning_input[:,:,:seg_start], time_emb)
            cs = cond_enc1[:,:,cond_start]
            cond_enc2 = self.conditioning_encoder(conditioning_input[:,:,seg_end:], time_emb)
            # (N+cond_start)-seg_end re-bases the end-anchor index into the
            # coordinate frame of the second (sliced) segment.
            ce = cond_enc2[:,:,(N+cond_start)-seg_end]
        else:
            # Single encoder pass over the full conditioning input.
            cond_enc = self.conditioning_encoder(conditioning_input, time_emb)
            cs = cond_enc[:,:,cond_start]
            ce = cond_enc[:,:,N+cond_start]
    # Stack the two anchors as a length-2 time axis, then linearly interpolate
    # to one conditioning vector per generated frame.
    cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
    cond = F.interpolate(cond_enc, size=(N,), mode='linear').permute(0,2,1)
    return cond
def forward(self, x, timesteps, conditioning_input=None, conditioning_free=False, cond_start=0, custom_conditioning_fetcher=None):
unused_params = []
@ -204,21 +233,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
cond = self.unconditioned_embedding
cond = cond.repeat(1,x.shape[-1],1)
else:
if custom_conditioning_fetcher is not None:
cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
else:
if self.training and self.conditioning_masking > 0:
cond_op_len = x.shape[-1]
mask_prop = random.random() * self.conditioning_masking
mask_len = int(cond_op_len * mask_prop)
if mask_len > 0:
start = random.randint(0, (cond_op_len-mask_len)) + cond_start
conditioning_input[:,:,start:(start+mask_len)] = 0
cond_enc = self.conditioning_encoder(conditioning_input, time_emb)
cs = cond_enc[:,:,cond_start]
ce = cond_enc[:,:,x.shape[-1]+cond_start]
cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
cond = F.interpolate(cond_enc, size=(x.shape[-1],), mode='linear').permute(0,2,1)
cond = self.process_conditioning(conditioning_input, time_emb, x.shape[-1], cond_start, custom_conditioning_fetcher)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((cond.shape[0], 1, 1),
@ -261,16 +276,18 @@ def register_tfdpc5(opt_net, opt):
def test_cheater_model():
clip = torch.randn(2, 256, 400)
cl = torch.randn(2, 256, 400)
clip = torch.randn(2, 256, 200)
cl = torch.randn(2, 256, 500)
ts = torch.LongTensor([600, 600])
# For music:
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
unconditioned_percentage=.4, conditioning_masking=.5)
unconditioned_percentage=.4, conditioning_masking=.5,
segregrate_conditioning_segments=True)
print_network(model)
o = model(clip, ts, cl)
for k in range(100):
o = model(clip, ts, cl)
pg = model.get_grad_norm_parameter_groups()
def prmsz(lp):
sz = 0