forked from mrq/DL-Art-School
fix
This commit is contained in:
parent
1ac02acdc3
commit
64b6ae2f4a
|
@ -208,11 +208,6 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
||||||
self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature
|
self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature
|
||||||
del self.m2v.up
|
del self.m2v.up
|
||||||
|
|
||||||
self.codes = torch.zeros((3000000,), dtype=torch.long)
|
|
||||||
self.internal_step = 0
|
|
||||||
self.code_ind = 0
|
|
||||||
self.total_codes = 0
|
|
||||||
|
|
||||||
def update_for_step(self, step, *args):
|
def update_for_step(self, step, *args):
|
||||||
self.internal_step = step
|
self.internal_step = step
|
||||||
self.m2v.quantizer.temperature = max(
|
self.m2v.quantizer.temperature = max(
|
||||||
|
@ -226,7 +221,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
||||||
conditioning_free=conditioning_free)
|
conditioning_free=conditioning_free)
|
||||||
|
|
||||||
def get_debug_values(self, step, __):
|
def get_debug_values(self, step, __):
|
||||||
if self.total_codes > 0:
|
if self.m2v.total_codes > 0:
|
||||||
return {'histogram_codes': self.m2v.codes[:self.m2v.total_codes]}
|
return {'histogram_codes': self.m2v.codes[:self.m2v.total_codes]}
|
||||||
else:
|
else:
|
||||||
return {}
|
return {}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user