debug gumbel temperature

This commit is contained in:
James Betker 2022-06-08 12:12:08 -06:00
parent 91be38cba3
commit 43f225c35c

View File

@ -26,7 +26,7 @@ class ConditioningEncoder(nn.Module):
self.dim = embedding_dim
def forward(self, x):
h = checkpoint(self.init, x)
h = self.init(x)
h = self.attn(h)
return h.mean(dim=2)
@ -150,7 +150,8 @@ class GptMusicLower(nn.Module):
def get_debug_values(self, step, __):
if self.upper_quantizer.total_codes > 0:
return {'histogram_upper_codes': self.upper_quantizer.codes[:self.upper_quantizer.total_codes]}
return {'histogram_upper_codes': self.upper_quantizer.codes[:self.upper_quantizer.total_codes],
'gumbel_temperature': self.upper_quantizer.quantizer.temperature}
else:
return {}