Add experimental gradient boosting into tts7
This commit is contained in:
parent
7ea84f1ac3
commit
6af5d129ce
|
@ -204,6 +204,8 @@ class DiffusionTts(nn.Module):
|
|||
enabled_unaligned_inputs=False,
|
||||
num_unaligned_tokens=164,
|
||||
unaligned_encoder_depth=8,
|
||||
# Experimental parameters
|
||||
component_gradient_boosting=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -229,6 +231,7 @@ class DiffusionTts(nn.Module):
|
|||
self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
|
||||
self.unconditioned_percentage = unconditioned_percentage
|
||||
self.enable_fp16 = use_fp16
|
||||
self.component_gradient_boosting = component_gradient_boosting
|
||||
padding = 1 if kernel_size == 3 else 2
|
||||
|
||||
time_embed_dim = model_channels * time_embed_dim_multiplier
|
||||
|
@ -419,6 +422,22 @@ class DiffusionTts(nn.Module):
|
|||
groups['unaligned_encoder'] = list(self.unaligned_encoder.parameters())
|
||||
return groups
|
||||
|
||||
def before_step(self, it):
|
||||
if not self.component_gradient_boosting:
|
||||
return
|
||||
MIN_PROPORTIONAL_BOOST_LEVEL = .5
|
||||
components = [list(self.contextual_embedder.parameters()), list(self.middle_block.parameters()), list(self.conditioning_encoder.parameters()),
|
||||
list(self.unaligned_encoder.parameters())]
|
||||
input_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in self.input_blocks.parameters()]), 2)
|
||||
output_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in self.output_blocks.parameters()]), 2)
|
||||
diffusion_norm = (input_norm + output_norm) / 2
|
||||
min_norm = diffusion_norm * MIN_PROPORTIONAL_BOOST_LEVEL
|
||||
for component in components:
|
||||
norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in component]), 2)
|
||||
if norm < min_norm:
|
||||
mult = min_norm / (norm + 1e-8)
|
||||
for p in component:
|
||||
p.grad.data.mul_(mult)
|
||||
|
||||
def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None, unaligned_input=None, conditioning_free=False):
|
||||
"""
|
||||
|
@ -532,8 +551,10 @@ if __name__ == '__main__':
|
|||
kernel_size=3,
|
||||
scale_factor=2,
|
||||
time_embed_dim_multiplier=4, super_sampling=False,
|
||||
enabled_unaligned_inputs=True)
|
||||
model(clip, ts, tok, cond, lr, un)
|
||||
model(clip, ts, None, cond, lr)
|
||||
enabled_unaligned_inputs=True,
|
||||
component_gradient_boosting=True)
|
||||
o = model(clip, ts, tok, cond, lr, un)
|
||||
o.sum().backward()
|
||||
model.before_step(0)
|
||||
torch.save(model.state_dict(), 'test_out.pth')
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user