tfdpcv5 updates
parent ce82eb6022
commit f46d6645da
@@ -114,9 +114,6 @@ class ConditioningEncoder(nn.Module):


 class TransformerDiffusionWithPointConditioning(nn.Module):
-    """
-    A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
-    """
     def __init__(
             self,
             in_channels=256,
@@ -129,9 +126,11 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
             input_cond_dim=1024,
             num_heads=8,
             dropout=0,
-            time_proj=False,
+            time_proj=True,
+            new_cond=False,
             use_fp16=False,
             checkpoint_conditioning=True,  # This will need to be false for DDP training. :(
+            regularization=False,
             # Parameters for regularization.
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
     ):
@@ -144,6 +143,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
+        self.regularization = regularization
+        self.new_cond = new_cond

         self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
         self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning, time_proj=time_proj)
@@ -166,13 +167,11 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
                 cond_projection=(k % 3 == 0),
                 use_conv=(k % 3 != 0),
             ) for k in range(num_layers)])
-
         self.out = nn.Sequential(
             normalization(model_channels),
             nn.SiLU(),
             zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
         )
-
         self.debug_codes = {}

     def get_grad_norm_parameter_groups(self):
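For orientation, the `(k % 3)` flags in the hunk above interleave two layer variants: every third block receives a conditioning projection, while the remaining blocks use a convolutional path. A minimal sketch of the pattern (illustrative only, not code from the repo):

```python
# Sketch: how the (k % 3) flags distribute roles across the 15-layer stack.
num_layers = 15
for k in range(num_layers):
    cond_projection = (k % 3 == 0)  # layers 0, 3, 6, ... ingest the conditioning signal
    use_conv = (k % 3 != 0)         # all other layers use the conv path instead
    print(f"layer {k}: cond_projection={cond_projection}, use_conv={use_conv}")
```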
@@ -199,10 +198,24 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         }
         return groups

-    def process_conditioning(self, conditioning_input, time_emb, N, cond_start, custom_conditioning_fetcher):
-        if custom_conditioning_fetcher is not None:
-            cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
-        else:
+    def process_conditioning(self, conditioning_input, time_emb, N, cond_start, cond_left, cond_right):
+        if self.training and self.regularization:
+            # frequency regularization
+            fstart = random.randint(0, conditioning_input.shape[1] - 1)
+            fclip = random.randint(1, min(conditioning_input.shape[1]-fstart, 16))
+            conditioning_input[:,fstart:fstart+fclip] = 0
+            # time regularization
+            for k in range(1, random.randint(2, 4)):
+                tstart = random.randint(0, conditioning_input.shape[-1] - 1)
+                tclip = random.randint(1, min(conditioning_input.shape[-1]-tstart, 10))
+                conditioning_input[:,:,tstart:tstart+tclip] = 0
+
+        if cond_left is None and self.new_cond:
+            cond_left = conditioning_input[:,:,:max(cond_start, 20)]
+            left_pt = cond_start
+            cond_right = conditioning_input[:,:,min(N+cond_start, conditioning_input.shape[-1]-20):]
+            right_pt = cond_right.shape[-1] - (conditioning_input.shape[-1] - (N+cond_start))
+        elif cond_left is None:
             assert conditioning_input.shape[-1] - cond_start - N >= 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
             cond_pre = conditioning_input[:,:,:cond_start]
             cond_aligned = conditioning_input[:,:,cond_start:N+cond_start]
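The new `regularization` branch is a SpecAugment-style dropout on the conditioning spectrogram: one random band of up to 16 frequency channels and one to three random spans of up to 10 time frames are zeroed during training. A standalone sketch of the same masking, assuming a `(batch, freq, time)` tensor:

```python
import random
import torch

def mask_conditioning(conditioning_input: torch.Tensor) -> torch.Tensor:
    b, f, t = conditioning_input.shape
    # Zero one random band of up to 16 frequency channels...
    fstart = random.randint(0, f - 1)
    fclip = random.randint(1, min(f - fstart, 16))
    conditioning_input[:, fstart:fstart + fclip] = 0
    # ...and one to three random spans of up to 10 time frames.
    for _ in range(1, random.randint(2, 4)):
        tstart = random.randint(0, t - 1)
        tclip = random.randint(1, min(t - tstart, 10))
        conditioning_input[:, :, tstart:tstart + tclip] = 0
    return conditioning_input

masked = mask_conditioning(torch.randn(2, 256, 400))
```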
@@ -223,19 +236,25 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
             cond_right = cond_right[:,:,to_remove_right:]

             # Concatenate the _pre and _post back on.
-            cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
-            cond_right_full = torch.cat([cond_right, cond_post], dim=-1)
+            left_pt = cond_start
+            right_pt = cond_right.shape[-1]
+            cond_left = torch.cat([cond_pre, cond_left], dim=-1)
+            cond_right = torch.cat([cond_right, cond_post], dim=-1)
+        else:
+            left_pt = -1
+            right_pt = 0

-        # Propagate through the encoder.
-        cond_left_enc = self.conditioning_encoder(cond_left_full, time_emb)
-        cs = cond_left_enc[:,:,cond_start]
-        cond_right_enc = self.conditioning_encoder(cond_right_full, time_emb)
-        ce = cond_right_enc[:,:,cond_right.shape[-1]-1]
+        # Propagate through the encoder.
+        cond_left_enc = self.conditioning_encoder(cond_left, time_emb)
+        cs = cond_left_enc[:,:,left_pt]
+        cond_right_enc = self.conditioning_encoder(cond_right, time_emb)
+        ce = cond_right_enc[:,:,right_pt]
         cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
         cond = F.interpolate(cond_enc, size=(N,), mode='linear', align_corners=True).permute(0,2,1)
         return cond

-    def forward(self, x, timesteps, conditioning_input=None, conditioning_free=False, cond_start=0, custom_conditioning_fetcher=None):
+    def forward(self, x, timesteps, conditioning_input=None, cond_left=None, cond_right=None, conditioning_free=False, cond_start=0):
         unused_params = []

         time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
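Net effect of the rewritten tail of `process_conditioning`: the left and right conditioning clips are each encoded, one embedding vector is picked from each (at `left_pt`/`right_pt`), and the conditioning for all `N` output frames is a linear interpolation between those two anchors. A self-contained sketch of that last step (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

b, c, n = 2, 1024, 100
cs = torch.randn(b, c)  # embedding picked from the left-context encoding
ce = torch.randn(b, c)  # embedding picked from the right-context encoding

# Stack the two anchors and linearly interpolate across the N output frames.
cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)  # (b, c, 2)
cond = F.interpolate(cond_enc, size=(n,), mode='linear',
                     align_corners=True).permute(0, 2, 1)           # (b, n, c)

# align_corners=True pins frame 0 to cs and frame n-1 to ce exactly.
assert torch.allclose(cond[:, 0], cs) and torch.allclose(cond[:, -1], ce)
```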
@@ -244,7 +263,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
             cond = self.unconditioned_embedding
             cond = cond.repeat(1,x.shape[-1],1)
         else:
-            cond = self.process_conditioning(conditioning_input, time_emb, x.shape[-1], cond_start, custom_conditioning_fetcher)
+            cond = self.process_conditioning(conditioning_input, time_emb, x.shape[-1], cond_start, cond_left, cond_right)
         # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
         if self.training and self.unconditioned_percentage > 0:
             unconditioned_batches = torch.rand((cond.shape[0], 1, 1),
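The masking referenced in the context lines above drops the conditioning for a random subset of batch elements so the model also learns an unconditional mode, as in classifier-free guidance training. A sketch of the mechanism under assumed shapes (the real code draws the mask on the model's device and uses its learned `unconditioned_embedding`):

```python
import torch

b, n, c = 4, 100, 1024
unconditioned_percentage = 0.1
cond = torch.randn(b, n, c)                     # per-frame conditioning
unconditioned_embedding = torch.randn(1, 1, c)  # stand-in for the learned parameter

# Replace the conditioning of randomly chosen batch elements wholesale.
unconditioned_batches = torch.rand(b, 1, 1) < unconditioned_percentage
cond = torch.where(unconditioned_batches, unconditioned_embedding, cond)
```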
@@ -276,9 +295,10 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes
         # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
         # directly fiddling with the gradients.
-        for p in scaled_grad_parameters:
-            if hasattr(p, 'grad') and p.grad is not None:
-                p.grad *= .2
+        if not self.new_cond:  # Not really related, I just don't want to add a new config.
+            for p in scaled_grad_parameters:
+                if hasattr(p, 'grad') and p.grad is not None:
+                    p.grad *= .2


 @register_model
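The rescaling above happens after `backward()` rather than through optimizer parameter groups because ZeroRedundancyOptimizer complicates per-group settings; this commit additionally gates it behind `not self.new_cond`. A generic sketch of the pattern with stand-in modules:

```python
import torch
import torch.nn as nn

net = nn.Linear(8, 8)
scaled_grad_parameters = [net.weight]  # stand-in for the blkout/prenorm parameters
opt = torch.optim.AdamW(net.parameters(), lr=1e-4)

loss = net(torch.randn(4, 8)).square().mean()
loss.backward()
# Damp the gradients of the selected parameters before the optimizer step.
for p in scaled_grad_parameters:
    if p.grad is not None:
        p.grad *= .2
opt.step()
```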
@@ -293,11 +313,12 @@ def test_cheater_model():

     # For music:
     model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
-                                                      contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
-                                                      unconditioned_percentage=.4, checkpoint_conditioning=False)
+                                                      contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
+                                                      unconditioned_percentage=.4, checkpoint_conditioning=False,
+                                                      regularization=True, new_cond=True)
     print_network(model)
-    for k in range(100):
-        o = model(clip, ts, cl)
+    for cs in range(276,cl.shape[-1]-clip.shape[-1]):
+        o = model(clip, ts, cl, cond_start=cs)
     pg = model.get_grad_norm_parameter_groups()
     def prmsz(lp):
         sz = 0
@@ -617,8 +617,6 @@ class GaussianDiffusion:
         mask,
         noise=None,
         clip_denoised=True,
-        causal=False,
-        causal_slope=1,
         denoised_fn=None,
         cond_fn=None,
         model_kwargs=None,
@@ -640,8 +638,6 @@ class GaussianDiffusion:
             img,
             t,
             clip_denoised=clip_denoised,
-            causal=causal,
-            causal_slope=causal_slope,
             denoised_fn=denoised_fn,
             cond_fn=cond_fn,
             model_kwargs=model_kwargs,
@@ -436,18 +436,18 @@ class MusicDiffusionFid(evaluator.Evaluator):
 if __name__ == '__main__':
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\53000_generator_ema.pth'
+                                       load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\80500_generator_ema.pth'
                                        ).cuda()
     opt_eval = {'path': 'Y:\\split\\yt-music-eval',  # eval music, mostly electronica. :)
                 #'path': 'E:\\music_eval',  # this is music from the training dataset, including a lot more variety.
-                'diffusion_steps': 220,  # basis: 192
+                'diffusion_steps': 256,  # basis: 192
                 'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen',
                 # Slope 1: 1.03x, 2: 1.06, 4: 1.135, 8: 1.27, 16: 1.54
-                'causal': True, 'causal_slope': 3,  # DONT FORGET TO INCREMENT THE STEP!
+                'causal': True, 'causal_slope': 4,  # DONT FORGET TO INCREMENT THE STEP!
                 #'partial_low': 128, 'partial_high': 192
                 }
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 3, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 104, 'device': 'cuda', 'opt': {}}
     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     fds = []
     for i in range(2):