forked from mrq/DL-Art-School

Commit 77c18b53b3 ("Cap grad booster")
Parent: 2d1cb83c1d
@@ -426,6 +426,7 @@ class DiffusionTts(nn.Module):
         if not self.component_gradient_boosting:
             return
         MIN_PROPORTIONAL_BOOST_LEVEL = .5
+        MAX_MULTIPLIER = 100
         components = [list(self.contextual_embedder.parameters()), list(self.middle_block.parameters()), list(self.conditioning_encoder.parameters()),
                       list(self.unaligned_encoder.parameters())]
         input_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in self.input_blocks.parameters()]), 2)
@@ -436,6 +437,7 @@ class DiffusionTts(nn.Module):
             norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in component]), 2)
             if norm < min_norm:
                 mult = min_norm / (norm + 1e-8)
+                mult = min(mult, MAX_MULTIPLIER)
                 for p in component:
                     p.grad.data.mul_(mult)
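
For reference, below is a minimal, self-contained sketch of the boosting routine as it reads after this commit. The free-function name boost_component_grads, the toy modules in the usage lines, and the line min_norm = input_norm * MIN_PROPORTIONAL_BOOST_LEVEL are assumptions: the diff elides file lines 429-435 where min_norm is actually computed, and proportional scaling of input_norm is only the natural reading of the constant's name.

import torch
import torch.nn as nn

MIN_PROPORTIONAL_BOOST_LEVEL = .5
MAX_MULTIPLIER = 100  # new in this commit: hard cap on the boost factor

def boost_component_grads(reference_params, components):
    # Hypothetical stand-in for the DiffusionTts method in the diff.
    # Global L2 norm over the reference gradients (input_blocks above).
    input_norm = torch.norm(torch.stack(
        [torch.norm(p.grad.detach(), 2) for p in reference_params]), 2)
    # ASSUMPTION: the diff hides how min_norm is derived; a proportional
    # floor on input_norm matches the constant's name.
    min_norm = input_norm * MIN_PROPORTIONAL_BOOST_LEVEL
    for component in components:
        norm = torch.norm(torch.stack(
            [torch.norm(p.grad.detach(), 2) for p in component]), 2)
        if norm < min_norm:
            # Boost the lagging component's gradients up to the floor,
            # but never by more than MAX_MULTIPLIER; this prevents huge
            # multipliers when norm is near zero (the 1e-8 guard alone
            # still allows mult to blow up).
            mult = min(min_norm / (norm + 1e-8), MAX_MULTIPLIER)
            for p in component:
                p.grad.data.mul_(mult)

# Usage sketch: call after loss.backward(), before optimizer.step().
# 'ref' and 'weak' are toy modules standing in for input_blocks and a
# boosted component.
ref, weak = nn.Linear(4, 4), nn.Linear(4, 4)
(ref(torch.randn(2, 4)).sum() + 1e-3 * weak(torch.randn(2, 4)).sum()).backward()
boost_component_grads(list(ref.parameters()), [list(weak.parameters())])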