diff --git a/codes/models/audio/music/unet_diffusion_waveform_gen3.py b/codes/models/audio/music/unet_diffusion_waveform_gen3.py index 857a849e..3610f882 100644 --- a/codes/models/audio/music/unet_diffusion_waveform_gen3.py +++ b/codes/models/audio/music/unet_diffusion_waveform_gen3.py @@ -312,7 +312,7 @@ class DiffusionWaveformGen(nn.Module): groups = { 'input_blocks': list(self.input_blocks.parameters()), 'output_blocks': list(self.output_blocks.parameters()), - 'middle_transformer': list(self.middle_block.parameters()), + 'middle_rrdb': list(self.middle_block.parameters()), } return groups