@ -312,7 +312,7 @@ class DiffusionWaveformGen(nn.Module):
groups = {
'input_blocks': list(self.input_blocks.parameters()),
'output_blocks': list(self.output_blocks.parameters()),
'middle_transformer': list(self.middle_block.parameters()),
'middle_rrdb': list(self.middle_block.parameters()),
}
return groups