diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py index f750d96c..9ce7dcbd 100644 --- a/codes/models/audio/music/transformer_diffusion12.py +++ b/codes/models/audio/music/transformer_diffusion12.py @@ -457,13 +457,16 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module): p.grad *= .2 class TransformerDiffusionWithCheaterLatent(nn.Module): - def __init__(self, freeze_encoder_until=None, checkpoint_encoder=True, **kwargs): + def __init__(self, freeze_encoder_until=None, checkpoint_encoder=True, res_encoder=False, **kwargs): super().__init__() self.internal_step = 0 self.freeze_encoder_until = freeze_encoder_until self.diff = TransformerDiffusion(**kwargs) - from models.audio.music.encoders import ResEncoder16x - self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder) + if res_encoder: + from models.audio.music.encoders import ResEncoder16x + self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder) + else: + self.encoder = UpperEncoder(256, 1024, 256).eval() def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False): unused_parameters = []