few more tfd13 things

This commit is contained in:
James Betker 2022-07-24 17:39:33 -06:00
parent f3d967dbf5
commit cc62ba9cba
2 changed files with 38 additions and 29 deletions

View File

@ -147,31 +147,6 @@ class TransformerDiffusion(nn.Module):
self.debug_codes = {}
def get_grad_norm_parameter_groups(self):
attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers]))
attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers]))
ff1 = list(itertools.chain.from_iterable([lyr.block1.ff1.parameters() for lyr in self.layers] +
[lyr.block1.ff2.parameters() for lyr in self.layers]))
ff2 = list(itertools.chain.from_iterable([lyr.block2.ff1.parameters() for lyr in self.layers] +
[lyr.block2.ff2.parameters() for lyr in self.layers]))
blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers]))
groups = {
'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])),
'blk1_attention_layers': attn1,
'blk2_attention_layers': attn2,
'attention_layers': attn1 + attn2,
'blk1_ff_layers': ff1,
'blk2_ff_layers': ff2,
'ff_layers': ff1 + ff2,
'block_out_layers': blkout_layers,
'out': list(self.out.parameters()),
'x_proj': list(self.inp_block.parameters()),
'layers': list(self.layers.parameters()),
'time_embed': list(self.time_embed.parameters()),
'resolution_embed': list(self.resolution_embed.parameters()),
}
return groups
def input_to_random_resolution_and_window(self, x, ts, diffuser):
"""
This function MUST be applied to the target *before* noising. It returns the reduced, re-scoped target as well
@ -271,6 +246,41 @@ class TransformerDiffusion(nn.Module):
return out
def get_grad_norm_parameter_groups(self):
attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers]))
attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers]))
ff1 = list(itertools.chain.from_iterable([lyr.block1.ff1.parameters() for lyr in self.layers] +
[lyr.block1.ff2.parameters() for lyr in self.layers]))
ff2 = list(itertools.chain.from_iterable([lyr.block2.ff1.parameters() for lyr in self.layers] +
[lyr.block2.ff2.parameters() for lyr in self.layers]))
blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers]))
groups = {
'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])),
'blk1_attention_layers': attn1,
'blk2_attention_layers': attn2,
'attention_layers': attn1 + attn2,
'blk1_ff_layers': ff1,
'blk2_ff_layers': ff2,
'ff_layers': ff1 + ff2,
'block_out_layers': blkout_layers,
'out': list(self.out.parameters()),
'x_proj': list(self.inp_block.parameters()),
'layers': list(self.layers.parameters()),
'time_embed': list(self.time_embed.parameters()),
'prior_time_embed': list(self.prior_time_embed.parameters()),
'resolution_embed': list(self.resolution_embed.parameters()),
}
return groups
def before_step(self, step):
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers]))
# Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes
# higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
# directly fiddling with the gradients.
for p in scaled_grad_parameters:
if hasattr(p, 'grad') and p.grad is not None:
p.grad *= .2
@register_model
def register_transformer_diffusion13(opt_net, opt):

View File

@ -314,11 +314,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
if __name__ == '__main__':
"""
# For multilevel SR:
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr.yml', 'generator',
also_load_savepoint=False, strict_load=False,
load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\56000_generator.pth'
load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\18000_generator.pth'
).cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
@ -328,7 +327,6 @@ if __name__ == '__main__':
}
"""
# For TFD+cheater trainer
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater.yml', 'generator',
also_load_savepoint=False, strict_load=False,
@ -340,8 +338,9 @@ if __name__ == '__main__':
'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': True,
'diffusion_schedule': 'cosine', 'diffusion_type': 'from_codes_quant',
}
"""
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 10, 'device': 'cuda', 'opt': {}}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 11, 'device': 'cuda', 'opt': {}}
eval = MusicDiffusionFid(diffusion, opt_eval, env)
fds = []
for i in range(2):