Add experimental component gradient boosting to tts7
This commit is contained in:
parent
7ea84f1ac3
commit
6af5d129ce
|
@ -204,6 +204,8 @@ class DiffusionTts(nn.Module):
|
||||||
enabled_unaligned_inputs=False,
|
enabled_unaligned_inputs=False,
|
||||||
num_unaligned_tokens=164,
|
num_unaligned_tokens=164,
|
||||||
unaligned_encoder_depth=8,
|
unaligned_encoder_depth=8,
|
||||||
|
# Experimental parameters
|
||||||
|
component_gradient_boosting=False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -229,6 +231,7 @@ class DiffusionTts(nn.Module):
|
||||||
self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
|
self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
|
||||||
self.unconditioned_percentage = unconditioned_percentage
|
self.unconditioned_percentage = unconditioned_percentage
|
||||||
self.enable_fp16 = use_fp16
|
self.enable_fp16 = use_fp16
|
||||||
|
self.component_gradient_boosting = component_gradient_boosting
|
||||||
padding = 1 if kernel_size == 3 else 2
|
padding = 1 if kernel_size == 3 else 2
|
||||||
|
|
||||||
time_embed_dim = model_channels * time_embed_dim_multiplier
|
time_embed_dim = model_channels * time_embed_dim_multiplier
|
||||||
|
@ -419,6 +422,22 @@ class DiffusionTts(nn.Module):
|
||||||
groups['unaligned_encoder'] = list(self.unaligned_encoder.parameters())
|
groups['unaligned_encoder'] = list(self.unaligned_encoder.parameters())
|
||||||
return groups
|
return groups
|
||||||
|
|
||||||
|
def before_step(self, it):
    """Boost the gradients of auxiliary sub-modules before the optimizer step.

    When ``component_gradient_boosting`` is enabled, the gradient norm of each
    auxiliary component (contextual embedder, middle block, conditioning
    encoder, unaligned encoder) is rescaled in place so it is at least a fixed
    fraction of the mean gradient norm of the main diffusion input/output
    blocks. This keeps the auxiliary components training even when the
    diffusion path dominates the loss.

    Args:
        it: training iteration index (currently unused; kept for the trainer's
            before-step hook signature).
    """
    if not self.component_gradient_boosting:
        return

    # Each component's gradient norm is boosted to at least this fraction of
    # the mean input/output-block gradient norm.
    MIN_PROPORTIONAL_BOOST_LEVEL = .5

    def _grad_norm(params):
        # Joint L2 norm over all gradients in `params`. Parameters without a
        # grad (frozen, or unused in this backward pass) are skipped instead
        # of crashing on `None.grad`; returns None when nothing has a grad.
        grads = [p.grad.detach() for p in params if p.grad is not None]
        if not grads:
            return None
        return torch.norm(torch.stack([torch.norm(g, 2) for g in grads]), 2)

    components = [
        list(self.contextual_embedder.parameters()),
        list(self.middle_block.parameters()),
        list(self.conditioning_encoder.parameters()),
        list(self.unaligned_encoder.parameters()),
    ]

    input_norm = _grad_norm(self.input_blocks.parameters())
    output_norm = _grad_norm(self.output_blocks.parameters())
    if input_norm is None or output_norm is None:
        # No reference gradients this step; nothing meaningful to scale against.
        return
    diffusion_norm = (input_norm + output_norm) / 2
    min_norm = diffusion_norm * MIN_PROPORTIONAL_BOOST_LEVEL
    for component in components:
        norm = _grad_norm(component)
        if norm is not None and norm < min_norm:
            mult = min_norm / (norm + 1e-8)  # epsilon guards divide-by-zero
            for p in component:
                if p.grad is not None:
                    # In-place scale of the grad tensor (`.data` is deprecated).
                    p.grad.mul_(mult)
|
||||||
|
|
||||||
def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None, unaligned_input=None, conditioning_free=False):
|
def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None, unaligned_input=None, conditioning_free=False):
|
||||||
"""
|
"""
|
||||||
|
@ -532,8 +551,10 @@ if __name__ == '__main__':
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
scale_factor=2,
|
scale_factor=2,
|
||||||
time_embed_dim_multiplier=4, super_sampling=False,
|
time_embed_dim_multiplier=4, super_sampling=False,
|
||||||
enabled_unaligned_inputs=True)
|
enabled_unaligned_inputs=True,
|
||||||
model(clip, ts, tok, cond, lr, un)
|
component_gradient_boosting=True)
|
||||||
model(clip, ts, None, cond, lr)
|
o = model(clip, ts, tok, cond, lr, un)
|
||||||
|
o.sum().backward()
|
||||||
|
model.before_step(0)
|
||||||
torch.save(model.state_dict(), 'test_out.pth')
|
torch.save(model.state_dict(), 'test_out.pth')
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user