diff --git a/codes/models/gpt_voice/unet_diffusion_tts7.py b/codes/models/gpt_voice/unet_diffusion_tts7.py index b17fc056..fc3e620a 100644 --- a/codes/models/gpt_voice/unet_diffusion_tts7.py +++ b/codes/models/gpt_voice/unet_diffusion_tts7.py @@ -426,6 +426,7 @@ class DiffusionTts(nn.Module): if not self.component_gradient_boosting: return MIN_PROPORTIONAL_BOOST_LEVEL = .5 + MAX_MULTIPLIER = 100 components = [list(self.contextual_embedder.parameters()), list(self.middle_block.parameters()), list(self.conditioning_encoder.parameters()), list(self.unaligned_encoder.parameters())] input_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in self.input_blocks.parameters()]), 2) @@ -436,6 +437,7 @@ class DiffusionTts(nn.Module): norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in component]), 2) if norm < min_norm: mult = min_norm / (norm + 1e-8) + mult = min(mult, MAX_MULTIPLIER) for p in component: p.grad.data.mul_(mult)