Fix code logging

This commit is contained in:
James Betker 2021-11-18 00:34:37 -07:00
parent f36bab95dd
commit f3db41f125

View File

@ -147,6 +147,7 @@ class DiscreteVAE(nn.Module):
if record_codes:
self.codes = torch.zeros((1228800,), dtype=torch.long)
self.code_ind = 0
self.total_codes = 0
self.internal_step = 0
def norm(self, images):
@ -163,7 +164,7 @@ class DiscreteVAE(nn.Module):
def get_debug_values(self, step, __):
if self.record_codes:
# Report annealing schedule
return {'histogram_codes': self.codes}
return {'histogram_codes': self.codes[:self.total_codes]}
else:
return {}
@ -243,6 +244,7 @@ class DiscreteVAE(nn.Module):
self.code_ind = self.code_ind + l
if self.code_ind >= self.codes.shape[0]:
self.code_ind = 0
self.total_codes += 1
self.internal_step += 1