fix
This commit is contained in:
parent
1ac02acdc3
commit
64b6ae2f4a
|
@ -208,11 +208,6 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
|||
self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature
|
||||
del self.m2v.up
|
||||
|
||||
self.codes = torch.zeros((3000000,), dtype=torch.long)
|
||||
self.internal_step = 0
|
||||
self.code_ind = 0
|
||||
self.total_codes = 0
|
||||
|
||||
def update_for_step(self, step, *args):
|
||||
self.internal_step = step
|
||||
self.m2v.quantizer.temperature = max(
|
||||
|
@ -226,7 +221,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
|||
conditioning_free=conditioning_free)
|
||||
|
||||
def get_debug_values(self, step, __):
|
||||
if self.total_codes > 0:
|
||||
if self.m2v.total_codes > 0:
|
||||
return {'histogram_codes': self.m2v.codes[:self.m2v.total_codes]}
|
||||
else:
|
||||
return {}
|
||||
|
|
Loading…
Reference in New Issue
Block a user