forked from mrq/DL-Art-School
debug gumbel temperature
This commit is contained in:
parent
91be38cba3
commit
43f225c35c
|
@ -26,7 +26,7 @@ class ConditioningEncoder(nn.Module):
|
||||||
self.dim = embedding_dim
|
self.dim = embedding_dim
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h = checkpoint(self.init, x)
|
h = self.init(x)
|
||||||
h = self.attn(h)
|
h = self.attn(h)
|
||||||
return h.mean(dim=2)
|
return h.mean(dim=2)
|
||||||
|
|
||||||
|
@ -150,7 +150,8 @@ class GptMusicLower(nn.Module):
|
||||||
|
|
||||||
def get_debug_values(self, step, __):
|
def get_debug_values(self, step, __):
|
||||||
if self.upper_quantizer.total_codes > 0:
|
if self.upper_quantizer.total_codes > 0:
|
||||||
return {'histogram_upper_codes': self.upper_quantizer.codes[:self.upper_quantizer.total_codes]}
|
return {'histogram_upper_codes': self.upper_quantizer.codes[:self.upper_quantizer.total_codes],
|
||||||
|
'gumbel_temperature': self.upper_quantizer.quantizer.temperature}
|
||||||
else:
|
else:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user