loss aware fix and report gumbel temperature

This commit is contained in:
James Betker 2022-06-09 21:56:47 -06:00
parent 6e57eaa186
commit d98b895307
2 changed files with 3 additions and 2 deletions

View File

@ -236,7 +236,8 @@ class TransformerDiffusionWithQuantizer(nn.Module):
def get_debug_values(self, step, __):
if self.quantizer.total_codes > 0:
return {'histogram_codes': self.quantizer.codes[:self.quantizer.total_codes]}
return {'histogram_codes': self.quantizer.codes[:self.quantizer.total_codes],
'gumbel_temperature': self.quantizer.quantizer.temperature}
else:
return {}

View File

@ -71,7 +71,7 @@ class GaussianDiffusionInjector(Injector):
t, weights = sampler.sample(hq.shape[0], hq.device)
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs, channel_balancing_fn=self.channel_balancing_fn)
if isinstance(sampler, LossAwareSampler):
sampler.update_with_local_losses(t, diffusion_outputs['losses'])
sampler.update_with_local_losses(t, diffusion_outputs['loss'])
if len(self.extra_model_output_keys) > 0:
assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}