forked from mrq/DL-Art-School
Fix code logging
This commit is contained in:
parent
f36bab95dd
commit
f3db41f125
|
@ -147,6 +147,7 @@ class DiscreteVAE(nn.Module):
|
||||||
if record_codes:
|
if record_codes:
|
||||||
self.codes = torch.zeros((1228800,), dtype=torch.long)
|
self.codes = torch.zeros((1228800,), dtype=torch.long)
|
||||||
self.code_ind = 0
|
self.code_ind = 0
|
||||||
|
self.total_codes = 0
|
||||||
self.internal_step = 0
|
self.internal_step = 0
|
||||||
|
|
||||||
def norm(self, images):
|
def norm(self, images):
|
||||||
|
@ -163,7 +164,7 @@ class DiscreteVAE(nn.Module):
|
||||||
def get_debug_values(self, step, __):
|
def get_debug_values(self, step, __):
|
||||||
if self.record_codes:
|
if self.record_codes:
|
||||||
# Report annealing schedule
|
# Report annealing schedule
|
||||||
return {'histogram_codes': self.codes}
|
return {'histogram_codes': self.codes[:self.total_codes]}
|
||||||
else:
|
else:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@ -243,6 +244,7 @@ class DiscreteVAE(nn.Module):
|
||||||
self.code_ind = self.code_ind + l
|
self.code_ind = self.code_ind + l
|
||||||
if self.code_ind >= self.codes.shape[0]:
|
if self.code_ind >= self.codes.shape[0]:
|
||||||
self.code_ind = 0
|
self.code_ind = 0
|
||||||
|
self.total_codes += 1
|
||||||
self.internal_step += 1
|
self.internal_step += 1
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user