From d98b8953074ce2bc4b9295ae6f8e88267605082e Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 9 Jun 2022 21:56:47 -0600
Subject: [PATCH] loss aware fix and report gumbel temperature

---
 codes/models/audio/music/transformer_diffusion8.py     | 3 ++-
 codes/trainer/injectors/gaussian_diffusion_injector.py | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/codes/models/audio/music/transformer_diffusion8.py b/codes/models/audio/music/transformer_diffusion8.py
index 2acb3f7f..d22130ad 100644
--- a/codes/models/audio/music/transformer_diffusion8.py
+++ b/codes/models/audio/music/transformer_diffusion8.py
@@ -236,7 +236,8 @@ class TransformerDiffusionWithQuantizer(nn.Module):
 
     def get_debug_values(self, step, __):
         if self.quantizer.total_codes > 0:
-            return {'histogram_codes': self.quantizer.codes[:self.quantizer.total_codes]}
+            return {'histogram_codes': self.quantizer.codes[:self.quantizer.total_codes],
+                    'gumbel_temperature': self.quantizer.quantizer.temperature}
         else:
             return {}
 
diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py
index 0e35c361..0f1e1a54 100644
--- a/codes/trainer/injectors/gaussian_diffusion_injector.py
+++ b/codes/trainer/injectors/gaussian_diffusion_injector.py
@@ -71,7 +71,7 @@ class GaussianDiffusionInjector(Injector):
             t, weights = sampler.sample(hq.shape[0], hq.device)
             diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs, channel_balancing_fn=self.channel_balancing_fn)
             if isinstance(sampler, LossAwareSampler):
-                sampler.update_with_local_losses(t, diffusion_outputs['losses'])
+                sampler.update_with_local_losses(t, diffusion_outputs['loss'])
             if len(self.extra_model_output_keys) > 0:
                 assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
                 out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}