forked from mrq/DL-Art-School
Cap grad booster: clamp the per-component gradient boost multiplier to MAX_MULTIPLIER (100)
This commit is contained in:
parent
2d1cb83c1d
commit
77c18b53b3
|
@ -426,6 +426,7 @@ class DiffusionTts(nn.Module):
|
|||
if not self.component_gradient_boosting:
|
||||
return
|
||||
MIN_PROPORTIONAL_BOOST_LEVEL = .5
|
||||
MAX_MULTIPLIER = 100
|
||||
components = [list(self.contextual_embedder.parameters()), list(self.middle_block.parameters()), list(self.conditioning_encoder.parameters()),
|
||||
list(self.unaligned_encoder.parameters())]
|
||||
input_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in self.input_blocks.parameters()]), 2)
|
||||
|
@ -436,6 +437,7 @@ class DiffusionTts(nn.Module):
|
|||
norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in component]), 2)
|
||||
if norm < min_norm:
|
||||
mult = min_norm / (norm + 1e-8)
|
||||
mult = min(mult, MAX_MULTIPLIER)
|
||||
for p in component:
|
||||
p.grad.data.mul_(mult)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user