Cap grad booster

James Betker 2022-03-04 10:40:24 -07:00
parent 2d1cb83c1d
commit 77c18b53b3


@@ -426,6 +426,7 @@ class DiffusionTts(nn.Module):
         if not self.component_gradient_boosting:
             return
         MIN_PROPORTIONAL_BOOST_LEVEL = .5
+        MAX_MULTIPLIER = 100
         components = [list(self.contextual_embedder.parameters()), list(self.middle_block.parameters()), list(self.conditioning_encoder.parameters()),
                       list(self.unaligned_encoder.parameters())]
         input_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in self.input_blocks.parameters()]), 2)
@@ -436,6 +437,7 @@ class DiffusionTts(nn.Module):
             norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in component]), 2)
             if norm < min_norm:
                 mult = min_norm / (norm + 1e-8)
+                mult = min(mult, MAX_MULTIPLIER)
                 for p in component:
                     p.grad.data.mul_(mult)
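
For context, here is a minimal, self-contained sketch of the capped booster this commit produces. It assumes min_norm is derived from input_norm via MIN_PROPORTIONAL_BOOST_LEVEL in the lines the hunks skip; the grad_norm and boost_component_gradients helpers and the two-layer toy model are illustrative names, not code from the repository.

    import torch
    import torch.nn as nn

    MIN_PROPORTIONAL_BOOST_LEVEL = .5
    MAX_MULTIPLIER = 100

    def grad_norm(params):
        # L2 norm over the stacked per-parameter gradient norms, as in the diff.
        return torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in params]), 2)

    def boost_component_gradients(reference_params, components):
        # Illustrative helper: any component whose gradient norm falls below a
        # proportion of the reference norm gets its gradients scaled up, capped
        # at MAX_MULTIPLIER. Assumes min_norm = reference norm * proportion,
        # matching how the diff presumably derives min_norm from input_norm.
        min_norm = grad_norm(reference_params) * MIN_PROPORTIONAL_BOOST_LEVEL
        for component in components:
            norm = grad_norm(component)
            if norm < min_norm:
                mult = min_norm / (norm + 1e-8)
                mult = min(mult, MAX_MULTIPLIER)
                for p in component:
                    p.grad.data.mul_(mult)

    # Toy usage: the second layer contributes weakly to the loss, so its
    # gradients get boosted toward the first layer's gradient norm.
    ref, weak = nn.Linear(8, 8), nn.Linear(8, 8)
    x = torch.randn(4, 8)
    loss = ref(x).sum() + 1e-4 * weak(x).sum()
    loss.backward()
    boost_component_gradients(list(ref.parameters()), [list(weak.parameters())])

The cap is the point of the commit: without it, a component whose gradient norm is near zero would be multiplied by roughly min_norm / 1e-8, an effectively unbounded boost, while clamping to MAX_MULTIPLIER bounds the rescaling at 100x.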