forked from mrq/DL-Art-School
fix compatibility
This commit is contained in:
parent
def70cd444
commit
4d5688be47
|
@ -457,13 +457,16 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
|
||||||
p.grad *= .2
|
p.grad *= .2
|
||||||
|
|
||||||
class TransformerDiffusionWithCheaterLatent(nn.Module):
|
class TransformerDiffusionWithCheaterLatent(nn.Module):
|
||||||
def __init__(self, freeze_encoder_until=None, checkpoint_encoder=True, **kwargs):
|
def __init__(self, freeze_encoder_until=None, checkpoint_encoder=True, res_encoder=False, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.internal_step = 0
|
self.internal_step = 0
|
||||||
self.freeze_encoder_until = freeze_encoder_until
|
self.freeze_encoder_until = freeze_encoder_until
|
||||||
self.diff = TransformerDiffusion(**kwargs)
|
self.diff = TransformerDiffusion(**kwargs)
|
||||||
|
if res_encoder:
|
||||||
from models.audio.music.encoders import ResEncoder16x
|
from models.audio.music.encoders import ResEncoder16x
|
||||||
self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
|
self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
|
||||||
|
else:
|
||||||
|
self.encoder = UpperEncoder(256, 1024, 256).eval()
|
||||||
|
|
||||||
def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
|
def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
|
||||||
unused_parameters = []
|
unused_parameters = []
|
||||||
|
|
Loading…
Reference in New Issue
Block a user