From 4d5688be4763faebccbb8e5bab8b37d144e4cf20 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 13 Jul 2022 21:28:20 -0600
Subject: [PATCH] fix compatibility

---
 codes/models/audio/music/transformer_diffusion12.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index f750d96c..9ce7dcbd 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -457,13 +457,16 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
                 p.grad *= .2
 
 class TransformerDiffusionWithCheaterLatent(nn.Module):
-    def __init__(self, freeze_encoder_until=None, checkpoint_encoder=True, **kwargs):
+    def __init__(self, freeze_encoder_until=None, checkpoint_encoder=True, res_encoder=False, **kwargs):
         super().__init__()
         self.internal_step = 0
         self.freeze_encoder_until = freeze_encoder_until
         self.diff = TransformerDiffusion(**kwargs)
-        from models.audio.music.encoders import ResEncoder16x
-        self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
+        if res_encoder:
+            from models.audio.music.encoders import ResEncoder16x
+            self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
+        else:
+            self.encoder = UpperEncoder(256, 1024, 256).eval()
 
     def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
         unused_parameters = []
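
A minimal usage sketch of the new res_encoder flag, assuming the usual TransformerDiffusion **kwargs forwarded by the constructor are supplied from a real training config (they are omitted here, so this is illustrative only, not code from the patch):

# Hypothetical illustration of the two encoder paths selected by this patch.
from models.audio.music.transformer_diffusion12 import TransformerDiffusionWithCheaterLatent

# Default (res_encoder=False): builds UpperEncoder(256, 1024, 256) in eval mode,
# restoring the behavior expected by existing configs.
model = TransformerDiffusionWithCheaterLatent(freeze_encoder_until=None, checkpoint_encoder=True)

# Opt-in (res_encoder=True): keeps the ResEncoder16x path, with gradient
# checkpointing controlled by checkpoint_encoder, that the class used before this patch.
model_res = TransformerDiffusionWithCheaterLatent(freeze_encoder_until=None, checkpoint_encoder=True,
                                                  res_encoder=True)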