From 43f225c35c02bcde459db2c89cdd2fe00537cc4d Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 8 Jun 2022 12:12:08 -0600 Subject: [PATCH] debug gumbel temperature --- codes/models/audio/music/gpt_music.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/codes/models/audio/music/gpt_music.py b/codes/models/audio/music/gpt_music.py index 3b213974..515093b0 100644 --- a/codes/models/audio/music/gpt_music.py +++ b/codes/models/audio/music/gpt_music.py @@ -26,7 +26,7 @@ class ConditioningEncoder(nn.Module): self.dim = embedding_dim def forward(self, x): - h = checkpoint(self.init, x) + h = self.init(x) h = self.attn(h) return h.mean(dim=2) @@ -150,7 +150,8 @@ class GptMusicLower(nn.Module): def get_debug_values(self, step, __): if self.upper_quantizer.total_codes > 0: - return {'histogram_upper_codes': self.upper_quantizer.codes[:self.upper_quantizer.total_codes]} + return {'histogram_upper_codes': self.upper_quantizer.codes[:self.upper_quantizer.total_codes], + 'gumbel_temperature': self.upper_quantizer.quantizer.temperature} else: return {}