tfdpcv5 updates

James Betker 2022-07-12 21:48:18 -06:00
parent ce82eb6022
commit f46d6645da
3 changed files with 51 additions and 34 deletions

@@ -114,9 +114,6 @@ class ConditioningEncoder(nn.Module):
class TransformerDiffusionWithPointConditioning(nn.Module):
"""
A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
"""
def __init__(
self,
in_channels=256,
@@ -129,9 +126,11 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
input_cond_dim=1024,
num_heads=8,
dropout=0,
time_proj=False,
time_proj=True,
new_cond=False,
use_fp16=False,
checkpoint_conditioning=True, # This will need to be false for DDP training. :(
regularization=False,
# Parameters for regularization.
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
):
@@ -144,6 +143,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
self.dropout = dropout
self.unconditioned_percentage = unconditioned_percentage
self.enable_fp16 = use_fp16
self.regularization = regularization
self.new_cond = new_cond
self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning, time_proj=time_proj)
@@ -166,13 +167,11 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
cond_projection=(k % 3 == 0),
use_conv=(k % 3 != 0),
) for k in range(num_layers)])
self.out = nn.Sequential(
normalization(model_channels),
nn.SiLU(),
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
)
self.debug_codes = {}
def get_grad_norm_parameter_groups(self):
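A side note on the (unchanged) output head above: zero_module zero-initializes the final convolution so the untrained network emits zeros, the usual guided-diffusion trick for stable diffusion-model training starts. A sketch of that helper, assuming it matches the common definition:

```python
import torch.nn as nn

def zero_module(module: nn.Module) -> nn.Module:
    """Zero a module's parameters in place so its initial output is zero."""
    for p in module.parameters():
        p.detach().zero_()
    return module
```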
@@ -199,10 +198,24 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
}
return groups
def process_conditioning(self, conditioning_input, time_emb, N, cond_start, custom_conditioning_fetcher):
if custom_conditioning_fetcher is not None:
cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
else:
def process_conditioning(self, conditioning_input, time_emb, N, cond_start, cond_left, cond_right):
if self.training and self.regularization:
# frequency regularization
fstart = random.randint(0, conditioning_input.shape[1] - 1)
fclip = random.randint(1, min(conditioning_input.shape[1]-fstart, 16))
conditioning_input[:,fstart:fstart+fclip] = 0
# time regularization
for k in range(1, random.randint(2, 4)):
tstart = random.randint(0, conditioning_input.shape[-1] - 1)
tclip = random.randint(1, min(conditioning_input.shape[-1]-tstart, 10))
conditioning_input[:,:,tstart:tstart+tclip] = 0
if cond_left is None and self.new_cond:
cond_left = conditioning_input[:,:,:max(cond_start, 20)]
left_pt = cond_start
cond_right = conditioning_input[:,:,min(N+cond_start, conditioning_input.shape[-1]-20):]
right_pt = cond_right.shape[-1] - (conditioning_input.shape[-1] - (N+cond_start))
elif cond_left is None:
assert conditioning_input.shape[-1] - cond_start - N >= 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
cond_pre = conditioning_input[:,:,:cond_start]
cond_aligned = conditioning_input[:,:,cond_start:N+cond_start]
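The training-only regularization added above zeroes one random frequency band and one to three random time spans of the conditioning spectrogram, SpecAugment-style. A standalone sketch of that masking, assuming a (batch, freq, time) tensor:

```python
import random
import torch

def mask_conditioning(cond: torch.Tensor) -> torch.Tensor:
    """SpecAugment-style masking over a (batch, freq, time) conditioning
    tensor, mirroring the logic added to process_conditioning. Masks in place."""
    # Frequency regularization: zero one random band of up to 16 bins.
    fstart = random.randint(0, cond.shape[1] - 1)
    fclip = random.randint(1, min(cond.shape[1] - fstart, 16))
    cond[:, fstart:fstart + fclip] = 0
    # Time regularization: zero 1-3 random spans of up to 10 frames each.
    for _ in range(1, random.randint(2, 4)):
        tstart = random.randint(0, cond.shape[-1] - 1)
        tclip = random.randint(1, min(cond.shape[-1] - tstart, 10))
        cond[:, :, tstart:tstart + tclip] = 0
    return cond
```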
@@ -223,19 +236,25 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
cond_right = cond_right[:,:,to_remove_right:]
# Concatenate the _pre and _post back on.
cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
cond_right_full = torch.cat([cond_right, cond_post], dim=-1)
left_pt = cond_start
right_pt = cond_right.shape[-1]
cond_left = torch.cat([cond_pre, cond_left], dim=-1)
cond_right = torch.cat([cond_right, cond_post], dim=-1)
else:
left_pt = -1
right_pt = 0
# Propagate through the encoder.
cond_left_enc = self.conditioning_encoder(cond_left, time_emb)
cs = cond_left_enc[:,:,left_pt]
cond_right_enc = self.conditioning_encoder(cond_right, time_emb)
ce = cond_right_enc[:,:,right_pt]
# Propagate through the encoder.
cond_left_enc = self.conditioning_encoder(cond_left_full, time_emb)
cs = cond_left_enc[:,:,cond_start]
cond_right_enc = self.conditioning_encoder(cond_right_full, time_emb)
ce = cond_right_enc[:,:,cond_right.shape[-1]-1]
cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
cond = F.interpolate(cond_enc, size=(N,), mode='linear', align_corners=True).permute(0,2,1)
return cond
def forward(self, x, timesteps, conditioning_input=None, conditioning_free=False, cond_start=0, custom_conditioning_fetcher=None):
def forward(self, x, timesteps, conditioning_input=None, cond_left=None, cond_right=None, conditioning_free=False, cond_start=0):
unused_params = []
time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
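Conditioning ultimately reduces to two encoded endpoint vectors, cs at the alignment start and ce at the end; F.interpolate then blends them linearly across all N output frames, so with align_corners=True frame 0 sees exactly cs and frame N-1 exactly ce. A self-contained sketch of that step, shapes assumed:

```python
import torch
import torch.nn.functional as F

B, C, N = 2, 1024, 400
cs = torch.randn(B, C)  # encoding at the left conditioning point
ce = torch.randn(B, C)  # encoding at the right conditioning point

# Stack the two endpoints on a length-2 "time" axis, then linearly
# interpolate to N frames so every position gets a blend of cs and ce.
cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)  # (B, C, 2)
cond = F.interpolate(cond_enc, size=(N,), mode='linear', align_corners=True)
cond = cond.permute(0, 2, 1)  # (B, N, C)
assert cond.shape == (B, N, C)
```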
@@ -244,7 +263,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
cond = self.unconditioned_embedding
cond = cond.repeat(1,x.shape[-1],1)
else:
cond = self.process_conditioning(conditioning_input, time_emb, x.shape[-1], cond_start, custom_conditioning_fetcher)
cond = self.process_conditioning(conditioning_input, time_emb, x.shape[-1], cond_start, cond_left, cond_right)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((cond.shape[0], 1, 1),
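During training, unconditioned_percentage replaces the conditioning of a random subset of whole batch elements with a learned unconditioned embedding, which is what enables classifier-free guidance at sampling time. A hedged sketch of that masking, with names and shapes assumed:

```python
import torch

def drop_conditioning(cond: torch.Tensor, uncond_emb: torch.Tensor,
                      p: float = 0.1) -> torch.Tensor:
    """For roughly a p-fraction of batch elements, swap the conditioning
    (B, N, C) for a learned unconditioned embedding (1, 1, C). Illustrative."""
    mask = torch.rand(cond.shape[0], 1, 1, device=cond.device) < p
    return torch.where(mask, uncond_emb.to(cond.dtype), cond)
```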
@@ -276,9 +295,10 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
# Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitude
# higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
# directly fiddling with the gradients.
for p in scaled_grad_parameters:
if hasattr(p, 'grad') and p.grad is not None:
p.grad *= .2
if not self.new_cond: # Not really related, I just don't want to add a new config.
for p in scaled_grad_parameters:
if hasattr(p, 'grad') and p.grad is not None:
p.grad *= .2
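The rescaling above runs in an after-backward hook rather than via optimizer parameter groups, since ZeroRedundancyOptimizer makes per-group settings awkward. A minimal sketch of the call pattern, with scaled_grad_parameters assumed to be the list the model collects:

```python
def scale_grads(params, factor=0.2):
    """Down-scale gradients of parameters that otherwise receive ~100x larger
    gradients; call between loss.backward() and optimizer.step()."""
    for p in params:
        if p.grad is not None:
            p.grad *= factor

# loss.backward()
# scale_grads(model.scaled_grad_parameters)  # hypothetical attribute name
# optimizer.step()
```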
@register_model
@@ -293,11 +313,12 @@ def test_cheater_model():
# For music:
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
unconditioned_percentage=.4, checkpoint_conditioning=False)
contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
unconditioned_percentage=.4, checkpoint_conditioning=False,
regularization=True, new_cond=True)
print_network(model)
for k in range(100):
o = model(clip, ts, cl)
for cs in range(276,cl.shape[-1]-clip.shape[-1]):
o = model(clip, ts, cl, cond_start=cs)
pg = model.get_grad_norm_parameter_groups()
def prmsz(lp):
sz = 0
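For context, a hedged usage sketch of the call signature exercised by the updated test loop; all shapes here are assumptions chosen so that cond_start + N never exceeds the conditioning length, satisfying the alignment assert in process_conditioning:

```python
import torch

# Illustrative shapes only: batch of 2, 256-channel latents.
clip = torch.randn(2, 256, 350)    # x: the noised sequence being denoised (N=350)
cl = torch.randn(2, 256, 646)      # conditioning_input spectrogram
ts = torch.LongTensor([600, 600])  # diffusion timesteps

# Sweep the alignment point as the test does; every cond_start keeps
# cond_start + 350 <= 646, so the conditioning stays in bounds.
for cs in range(276, cl.shape[-1] - clip.shape[-1]):
    o = model(clip, ts, cl, cond_start=cs)  # model: the instance built above
```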

@@ -617,8 +617,6 @@ class GaussianDiffusion:
mask,
noise=None,
clip_denoised=True,
causal=False,
causal_slope=1,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
@@ -640,8 +638,6 @@ class GaussianDiffusion:
img,
t,
clip_denoised=clip_denoised,
causal=causal,
causal_slope=causal_slope,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,

@@ -436,18 +436,18 @@ class MusicDiffusionFid(evaluator.Evaluator):
if __name__ == '__main__':
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen.yml', 'generator',
also_load_savepoint=False,
load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\53000_generator_ema.pth'
load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\80500_generator_ema.pth'
).cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
'diffusion_steps': 220, # basis: 192
'diffusion_steps': 256, # basis: 192
'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen',
# Slope 1: 1.03x, 2: 1.06, 4: 1.135, 8: 1.27, 16: 1.54
'causal': True, 'causal_slope': 3, # DONT FORGET TO INCREMENT THE STEP!
'causal': True, 'causal_slope': 4, # DONT FORGET TO INCREMENT THE STEP!
#'partial_low': 128, 'partial_high': 192
}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 3, 'device': 'cuda', 'opt': {}}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 104, 'device': 'cuda', 'opt': {}}
eval = MusicDiffusionFid(diffusion, opt_eval, env)
fds = []
for i in range(2):