From 6af5d129ce90dcc3578a37737cfc02a2e07b7a25 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 3 Mar 2022 21:51:40 -0700
Subject: [PATCH] Add experimental gradient boosting into tts7

---
 codes/models/gpt_voice/unet_diffusion_tts7.py | 27 ++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/codes/models/gpt_voice/unet_diffusion_tts7.py b/codes/models/gpt_voice/unet_diffusion_tts7.py
index 5c13c7ae..b17fc056 100644
--- a/codes/models/gpt_voice/unet_diffusion_tts7.py
+++ b/codes/models/gpt_voice/unet_diffusion_tts7.py
@@ -204,6 +204,8 @@ class DiffusionTts(nn.Module):
                  enabled_unaligned_inputs=False,
                  num_unaligned_tokens=164,
                  unaligned_encoder_depth=8,
+                 # Experimental parameters
+                 component_gradient_boosting=False,
                  ):
         super().__init__()
 
@@ -229,6 +231,7 @@
         self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
+        self.component_gradient_boosting = component_gradient_boosting
 
         padding = 1 if kernel_size == 3 else 2
         time_embed_dim = model_channels * time_embed_dim_multiplier
@@ -419,6 +422,22 @@
             groups['unaligned_encoder'] = list(self.unaligned_encoder.parameters())
         return groups
 
+    def before_step(self, it):
+        if not self.component_gradient_boosting:
+            return
+        MIN_PROPORTIONAL_BOOST_LEVEL = .5
+        components = [list(self.contextual_embedder.parameters()), list(self.middle_block.parameters()), list(self.conditioning_encoder.parameters()),
+                      list(self.unaligned_encoder.parameters())]
+        input_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in self.input_blocks.parameters()]), 2)
+        output_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in self.output_blocks.parameters()]), 2)
+        diffusion_norm = (input_norm + output_norm) / 2
+        min_norm = diffusion_norm * MIN_PROPORTIONAL_BOOST_LEVEL
+        for component in components:
+            norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in component]), 2)
+            if norm < min_norm:
+                mult = min_norm / (norm + 1e-8)
+                for p in component:
+                    p.grad.data.mul_(mult)
 
     def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None, unaligned_input=None, conditioning_free=False):
         """
@@ -532,8 +551,10 @@ if __name__ == '__main__':
                          kernel_size=3,
                          scale_factor=2,
                          time_embed_dim_multiplier=4,
                          super_sampling=False,
-                         enabled_unaligned_inputs=True)
-    model(clip, ts, tok, cond, lr, un)
-    model(clip, ts, None, cond, lr)
+                         enabled_unaligned_inputs=True,
+                         component_gradient_boosting=True)
+    o = model(clip, ts, tok, cond, lr, un)
+    o.sum().backward()
+    model.before_step(0)
     torch.save(model.state_dict(), 'test_out.pth')
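
Note on the technique (appended for context, not part of the patch): before_step(it)
is a hook the training loop is presumably expected to call after loss.backward() and
before optimizer.step(), as the __main__ test exercises it. The rule it implements:
compute a reference gradient level as the mean of the global L2 gradient norms of the
diffusion trunk's input_blocks and output_blocks, then for every listed conditioning
component whose own global gradient norm falls below MIN_PROPORTIONAL_BOOST_LEVEL
(0.5) of that reference, multiply its gradients so they reach that floor. Below is a
minimal standalone sketch of the same idea, detached from DiffusionTts; the function
name boost_component_gradients and the toy trunk/head modules are hypothetical names
for illustration, not part of this repo.

import torch
import torch.nn as nn

def boost_component_gradients(reference_modules, boosted_modules, min_proportion=0.5):
    # Global L2 norm over all parameter gradients of one module. Unlike the
    # patched method, this guards against params with grad=None (e.g. a
    # component unused in the current forward pass).
    def grad_norm(module):
        norms = [torch.norm(p.grad.detach(), 2)
                 for p in module.parameters() if p.grad is not None]
        return torch.norm(torch.stack(norms), 2)

    # Reference level: mean gradient norm across the "trunk" modules.
    reference = sum(grad_norm(m) for m in reference_modules) / len(reference_modules)
    min_norm = reference * min_proportion
    for m in boosted_modules:
        norm = grad_norm(m)
        if norm < min_norm:
            # Scale this component's gradients up to the minimum level.
            mult = min_norm / (norm + 1e-8)
            for p in m.parameters():
                if p.grad is not None:
                    p.grad.data.mul_(mult)

# Toy usage: 'head' barely contributes to the loss, so its gradients are tiny
# relative to 'trunk' until boosted.
trunk = nn.Linear(16, 16)
head = nn.Linear(16, 16)
x = torch.randn(4, 16)
loss = (trunk(x) + 1e-3 * head(x)).sum()
loss.backward()
boost_component_gradients([trunk], [head])

One design point the sketch makes explicit: the boost only ever scales gradients up
(norm < min_norm), so components already training at or above the trunk's level are
left untouched.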