diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py index 0abcfacb..372b7fc3 100644 --- a/codes/models/audio/music/transformer_diffusion12.py +++ b/codes/models/audio/music/transformer_diffusion12.py @@ -168,7 +168,7 @@ class TransformerDiffusion(nn.Module): for p in self.parameters(): p.DO_NOT_TRAIN = True p.requires_grad = False - for m in [self.input_converter and self.code_converter]: + for m in [self.ar_input and self.ar_prior_intg]: for p in m.parameters(): del p.DO_NOT_TRAIN p.requires_grad = True