diff --git a/codes/models/audio/music/unet_diffusion_waveform_gen2.py b/codes/models/audio/music/unet_diffusion_waveform_gen2.py index 8e61a737..1e334b46 100644 --- a/codes/models/audio/music/unet_diffusion_waveform_gen2.py +++ b/codes/models/audio/music/unet_diffusion_waveform_gen2.py @@ -395,8 +395,6 @@ class DiffusionTts(nn.Module): def get_grad_norm_parameter_groups(self): - if self.freeze_main_net: - return {} groups = { 'input_blocks': list(self.input_blocks.parameters()), 'output_blocks': list(self.output_blocks.parameters()),