fix compatibility

James Betker 2022-07-13 21:28:20 -06:00
parent def70cd444
commit 4d5688be47


@@ -457,13 +457,16 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
                 p.grad *= .2
 
 class TransformerDiffusionWithCheaterLatent(nn.Module):
-    def __init__(self, freeze_encoder_until=None, checkpoint_encoder=True, **kwargs):
+    def __init__(self, freeze_encoder_until=None, checkpoint_encoder=True, res_encoder=False, **kwargs):
         super().__init__()
         self.internal_step = 0
         self.freeze_encoder_until = freeze_encoder_until
         self.diff = TransformerDiffusion(**kwargs)
-        from models.audio.music.encoders import ResEncoder16x
-        self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
+        if res_encoder:
+            from models.audio.music.encoders import ResEncoder16x
+            self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
+        else:
+            self.encoder = UpperEncoder(256, 1024, 256).eval()
 
     def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
         unused_parameters = []
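
The change gates the ResEncoder16x construction behind a new res_encoder flag that defaults to False, so configurations written before this commit fall back to the original UpperEncoder path and keep loading as they did previously. Below is a minimal, self-contained sketch of that compatibility pattern; LegacyEncoder, NewEncoder, and CheaterLatentSketch are illustrative stand-ins, not the repo's classes, and the real encoders take different constructor arguments.

# Minimal sketch of the opt-in encoder flag (illustrative stand-ins only;
# the repo's ResEncoder16x / UpperEncoder live in models.audio.music.encoders
# and have different constructor signatures).
import torch
import torch.nn as nn


class LegacyEncoder(nn.Module):
    # Stand-in for the original default encoder (UpperEncoder in the repo).
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


class NewEncoder(nn.Module):
    # Stand-in for the newer encoder (ResEncoder16x in the repo), built only on request.
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)


class CheaterLatentSketch(nn.Module):
    def __init__(self, dim=256, res_encoder=False):
        super().__init__()
        # Defaulting res_encoder to False preserves the old construction path,
        # so configs written before the flag existed build the same encoder.
        if res_encoder:
            self.encoder = NewEncoder(dim)
        else:
            self.encoder = LegacyEncoder(dim).eval()

    def forward(self, x):
        return self.encoder(x)


if __name__ == "__main__":
    old_behavior = CheaterLatentSketch()                  # legacy default path
    new_behavior = CheaterLatentSketch(res_encoder=True)  # opt-in new encoder
    x = torch.randn(2, 256)
    print(old_behavior(x).shape, new_behavior(x).shape)

Keeping the new encoder opt-in rather than the default is what restores compatibility here: callers that predate the res_encoder argument presumably expect the UpperEncoder parameters, and they now get them again without any config changes.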