This commit is contained in:
James Betker 2022-06-01 01:01:32 -06:00
parent 1ac02acdc3
commit 64b6ae2f4a

View File

@ -208,11 +208,6 @@ class TransformerDiffusionWithQuantizer(nn.Module):
self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature
del self.m2v.up
self.codes = torch.zeros((3000000,), dtype=torch.long)
self.internal_step = 0
self.code_ind = 0
self.total_codes = 0
def update_for_step(self, step, *args):
self.internal_step = step
self.m2v.quantizer.temperature = max(
@ -226,7 +221,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
conditioning_free=conditioning_free)
def get_debug_values(self, step, __):
if self.total_codes > 0:
if self.m2v.total_codes > 0:
return {'histogram_codes': self.m2v.codes[:self.m2v.total_codes]}
else:
return {}