Add experimental gradient boosting into tts7

James Betker 2022-03-03 21:51:40 -07:00
parent 7ea84f1ac3
commit 6af5d129ce

@@ -204,6 +204,8 @@ class DiffusionTts(nn.Module):
             enabled_unaligned_inputs=False,
             num_unaligned_tokens=164,
             unaligned_encoder_depth=8,
+            # Experimental parameters
+            component_gradient_boosting=False,
     ):
         super().__init__()
@@ -229,6 +231,7 @@ class DiffusionTts(nn.Module):
         self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
+        self.component_gradient_boosting = component_gradient_boosting
         padding = 1 if kernel_size == 3 else 2

         time_embed_dim = model_channels * time_embed_dim_multiplier
@@ -419,6 +422,22 @@ class DiffusionTts(nn.Module):
             groups['unaligned_encoder'] = list(self.unaligned_encoder.parameters())
         return groups

+    def before_step(self, it):
+        if not self.component_gradient_boosting:
+            return
+        MIN_PROPORTIONAL_BOOST_LEVEL = .5
+        components = [list(self.contextual_embedder.parameters()), list(self.middle_block.parameters()), list(self.conditioning_encoder.parameters()),
+                      list(self.unaligned_encoder.parameters())]
+        input_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in self.input_blocks.parameters()]), 2)
+        output_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in self.output_blocks.parameters()]), 2)
+        diffusion_norm = (input_norm + output_norm) / 2
+        min_norm = diffusion_norm * MIN_PROPORTIONAL_BOOST_LEVEL
+        for component in components:
+            norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in component]), 2)
+            if norm < min_norm:
+                mult = min_norm / (norm + 1e-8)
+                for p in component:
+                    p.grad.data.mul_(mult)
+
     def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None, unaligned_input=None, conditioning_free=False):
         """
@@ -532,8 +551,10 @@ if __name__ == '__main__':
                          kernel_size=3,
                          scale_factor=2,
                          time_embed_dim_multiplier=4, super_sampling=False,
-                         enabled_unaligned_inputs=True)
-    model(clip, ts, tok, cond, lr, un)
-    model(clip, ts, None, cond, lr)
+                         enabled_unaligned_inputs=True,
+                         component_gradient_boosting=True)
+    o = model(clip, ts, tok, cond, lr, un)
+    o.sum().backward()
+    model.before_step(0)
     torch.save(model.state_dict(), 'test_out.pth')
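The __main__ test above exercises the hook manually: a forward pass, backward() to populate gradients, then before_step() before any optimizer update. A hypothetical training-loop wiring under the same assumption (the loader, optimizer, and summed-output loss are stand-ins, not part of the commit):

# Hypothetical sketch: before_step() runs after backward() and before optimizer.step(),
# mirroring the manual forward -> backward -> before_step sequence in the test above.
def train_epoch(model, loader, optimizer):
    for it, (clip, ts, tok, cond, lr, un) in enumerate(loader):
        optimizer.zero_grad()
        out = model(clip, ts, tok, cond, lr, un)
        out.sum().backward()      # placeholder objective, as in the test
        model.before_step(it)     # boost under-weighted component gradients
        optimizer.step()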