Fix loss-aware sampler update and report Gumbel temperature
This commit is contained in:
parent
6e57eaa186
commit
d98b895307
|
@@ -236,7 +236,8 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
|||
|
||||
def get_debug_values(self, step, __):
|
||||
if self.quantizer.total_codes > 0:
|
||||
- return {'histogram_codes': self.quantizer.codes[:self.quantizer.total_codes]}
|
||||
+ return {'histogram_codes': self.quantizer.codes[:self.quantizer.total_codes],
|
||||
+ 'gumbel_temperature': self.quantizer.quantizer.temperature}
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
|
|
@@ -71,7 +71,7 @@ class GaussianDiffusionInjector(Injector):
|
|||
t, weights = sampler.sample(hq.shape[0], hq.device)
|
||||
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs, channel_balancing_fn=self.channel_balancing_fn)
|
||||
if isinstance(sampler, LossAwareSampler):
|
||||
- sampler.update_with_local_losses(t, diffusion_outputs['losses'])
|
||||
+ sampler.update_with_local_losses(t, diffusion_outputs['loss'])
|
||||
if len(self.extra_model_output_keys) > 0:
|
||||
assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
|
||||
out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}
|
||||
|
|
Loading…
Reference in New Issue
Block a user