forked from mrq/DL-Art-School
Rework tfdpc_v5 further..
This commit is contained in:
parent
47f04ff5c2
commit
e5859acff7
|
@ -126,6 +126,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
||||||
num_heads=8,
|
num_heads=8,
|
||||||
dropout=0,
|
dropout=0,
|
||||||
use_fp16=False,
|
use_fp16=False,
|
||||||
|
segregrate_conditioning_segments=False,
|
||||||
# Parameters for regularization.
|
# Parameters for regularization.
|
||||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||||
conditioning_masking=0,
|
conditioning_masking=0,
|
||||||
|
@ -136,6 +137,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
||||||
self.model_channels = model_channels
|
self.model_channels = model_channels
|
||||||
self.time_embed_dim = time_embed_dim
|
self.time_embed_dim = time_embed_dim
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
self.segregrate_conditioning_segments = segregrate_conditioning_segments
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.unconditioned_percentage = unconditioned_percentage
|
self.unconditioned_percentage = unconditioned_percentage
|
||||||
self.conditioning_masking = conditioning_masking
|
self.conditioning_masking = conditioning_masking
|
||||||
|
@ -195,6 +197,33 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
||||||
}
|
}
|
||||||
return groups
|
return groups
|
||||||
|
|
||||||
|
def process_conditioning(self, conditioning_input, time_emb, N, cond_start, custom_conditioning_fetcher):
    """Build the conditioning signal for a window of N positions.

    Two "point" embeddings are obtained - `cs` for the start of the window and `ce` for its
    end - and are linearly interpolated across the N output positions.

    Args:
        conditioning_input: Tensor that is sliced on its last axis as `[:, :, t]`. It must be
            indexable at `N + cond_start`. It is not read when `custom_conditioning_fetcher`
            is supplied. This method never modifies the caller's tensor.
        time_emb: Forwarded unchanged to the conditioning encoder (or the custom fetcher).
        N: Number of positions in the output conditioning signal.
        cond_start: Offset on the last axis of `conditioning_input` at which the window starts.
        custom_conditioning_fetcher: Optional callable invoked as
            `fetcher(self.conditioning_encoder, time_emb)` that returns `(cs, ce)` directly,
            bypassing masking and encoding of `conditioning_input`.

    Returns:
        The interpolated conditioning, permuted with `(0, 2, 1)`; with `cs`/`ce` of shape
        (batch, channels) this is (batch, N, channels).
        # assumes the conditioning encoder returns a (batch, channels, time) tensor - TODO confirm

    Raises:
        AssertionError: In training with masking enabled, when `N - mask_len <= 8`.
    """
    if custom_conditioning_fetcher is not None:
        cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
    else:
        if self.training and self.conditioning_masking > 0:
            mask_prop = random.random() * self.conditioning_masking
            # NOTE(review): min() caps the masked span at 4 positions regardless of
            # conditioning_masking; the assertion message below suggests max() (a floor of 4)
            # may have been intended. Left unchanged - confirm with the author.
            mask_len = min(int(N * mask_prop), 4)
            assert N-mask_len > 8, f"Use longer inputs or shorter conditioning_masking proportion. {N-mask_len}"
            # The masked span starts at least 8 positions into the window.
            seg_start = random.randint(8, (N-mask_len)) + cond_start
            seg_end = seg_start + mask_len
            if mask_len > 0:
                # Mask a private copy: zeroing in place would corrupt the caller's tensor,
                # which accumulates across calls when the same input is reused.
                conditioning_input = conditioning_input.clone()
                conditioning_input[:, :, seg_start:seg_end] = 0
        else:
            # No masking: split point is the middle of the window, with an empty gap.
            seg_start = cond_start + N // 2
            seg_end = seg_start

        if self.segregrate_conditioning_segments:
            # Encode the parts before and after the [seg_start, seg_end) gap separately.
            cond_enc1 = self.conditioning_encoder(conditioning_input[:, :, :seg_start], time_emb)
            cs = cond_enc1[:, :, cond_start]
            cond_enc2 = self.conditioning_encoder(conditioning_input[:, :, seg_end:], time_emb)
            # Re-base the end-of-window index onto the second segment, which starts at seg_end.
            ce = cond_enc2[:, :, (N+cond_start)-seg_end]
        else:
            cond_enc = self.conditioning_encoder(conditioning_input, time_emb)
            cs = cond_enc[:, :, cond_start]
            ce = cond_enc[:, :, N+cond_start]

    # Stack the two point embeddings on a new last axis and stretch them to N positions.
    cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
    cond = F.interpolate(cond_enc, size=(N,), mode='linear').permute(0, 2, 1)
    return cond
|
||||||
|
|
||||||
def forward(self, x, timesteps, conditioning_input=None, conditioning_free=False, cond_start=0, custom_conditioning_fetcher=None):
|
def forward(self, x, timesteps, conditioning_input=None, conditioning_free=False, cond_start=0, custom_conditioning_fetcher=None):
|
||||||
unused_params = []
|
unused_params = []
|
||||||
|
|
||||||
|
@ -204,21 +233,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
||||||
cond = self.unconditioned_embedding
|
cond = self.unconditioned_embedding
|
||||||
cond = cond.repeat(1,x.shape[-1],1)
|
cond = cond.repeat(1,x.shape[-1],1)
|
||||||
else:
|
else:
|
||||||
if custom_conditioning_fetcher is not None:
|
cond = self.process_conditioning(conditioning_input, time_emb, x.shape[-1], cond_start, custom_conditioning_fetcher)
|
||||||
cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
|
|
||||||
else:
|
|
||||||
if self.training and self.conditioning_masking > 0:
|
|
||||||
cond_op_len = x.shape[-1]
|
|
||||||
mask_prop = random.random() * self.conditioning_masking
|
|
||||||
mask_len = int(cond_op_len * mask_prop)
|
|
||||||
if mask_len > 0:
|
|
||||||
start = random.randint(0, (cond_op_len-mask_len)) + cond_start
|
|
||||||
conditioning_input[:,:,start:(start+mask_len)] = 0
|
|
||||||
cond_enc = self.conditioning_encoder(conditioning_input, time_emb)
|
|
||||||
cs = cond_enc[:,:,cond_start]
|
|
||||||
ce = cond_enc[:,:,x.shape[-1]+cond_start]
|
|
||||||
cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
|
|
||||||
cond = F.interpolate(cond_enc, size=(x.shape[-1],), mode='linear').permute(0,2,1)
|
|
||||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||||
if self.training and self.unconditioned_percentage > 0:
|
if self.training and self.unconditioned_percentage > 0:
|
||||||
unconditioned_batches = torch.rand((cond.shape[0], 1, 1),
|
unconditioned_batches = torch.rand((cond.shape[0], 1, 1),
|
||||||
|
@ -261,15 +276,17 @@ def register_tfdpc5(opt_net, opt):
|
||||||
|
|
||||||
|
|
||||||
def test_cheater_model():
|
def test_cheater_model():
|
||||||
clip = torch.randn(2, 256, 400)
|
clip = torch.randn(2, 256, 200)
|
||||||
cl = torch.randn(2, 256, 400)
|
cl = torch.randn(2, 256, 500)
|
||||||
ts = torch.LongTensor([600, 600])
|
ts = torch.LongTensor([600, 600])
|
||||||
|
|
||||||
# For music:
|
# For music:
|
||||||
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
|
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
|
||||||
contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
|
contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
|
||||||
unconditioned_percentage=.4, conditioning_masking=.5)
|
unconditioned_percentage=.4, conditioning_masking=.5,
|
||||||
|
segregrate_conditioning_segments=True)
|
||||||
print_network(model)
|
print_network(model)
|
||||||
|
for k in range(100):
|
||||||
o = model(clip, ts, cl)
|
o = model(clip, ts, cl)
|
||||||
pg = model.get_grad_norm_parameter_groups()
|
pg = model.get_grad_norm_parameter_groups()
|
||||||
def prmsz(lp):
|
def prmsz(lp):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user