fix unused parameters

This commit is contained in:
James Betker 2022-05-30 16:31:40 -06:00
parent f7d237a50a
commit 71cf654957

View File

@ -214,6 +214,10 @@ class TransformerDiffusionWithQuantizer(nn.Module):
self.total_codes = 0
del self.m2v.m2v.encoder
del self.m2v.reconstruction_net
del self.m2v.m2v.projector.projection
del self.m2v.project_hid
del self.m2v.project_q
def update_for_step(self, step, *args):
self.internal_step = step
@ -224,7 +228,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
proj = self.m2v.m2v.input_blocks(truth_mel).permute(0,2,1)
_, proj = self.m2v.m2v.projector(proj)
proj = self.m2v.m2v.projector.layer_norm(proj)
vectors, _, probs = self.m2v.quantizer(proj, return_probs=True)
self.log_codes(probs)
return self.diff(x, timesteps, codes=vectors, conditioning_input=conditioning_input, conditioning_free=conditioning_free)