rename perplexity->log perplexity

This commit is contained in:
James Betker 2022-07-28 09:48:40 -06:00
parent 1d68624828
commit 27a9b1b750
2 changed files with 4 additions and 4 deletions

View File

@ -659,7 +659,7 @@ class GaussianDiffusion:
img = model_driven_out + guidance_driven_out img = model_driven_out + guidance_driven_out
return img return img
def p_sample_loop_for_perplexity( def p_sample_loop_for_log_perplexity(
self, self,
model, model,
truth, truth,

View File

@ -137,7 +137,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
# s = q9.clamp(1, 9999999999) # s = q9.clamp(1, 9999999999)
# x = x.clamp(-s, s) / s # x = x.clamp(-s, s) / s
# return x # return x
perp = self.diffuser.p_sample_loop_for_perplexity(self.model, mel_norm, perp = self.diffuser.p_sample_loop_for_log_perplexity(self.model, mel_norm,
model_kwargs = {'truth_mel': mel_norm}) model_kwargs = {'truth_mel': mel_norm})
sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
@ -317,7 +317,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
self.local_modules[k] = mod.cpu() self.local_modules[k] = mod.cpu()
self.spec_decoder = self.spec_decoder.cpu() self.spec_decoder = self.spec_decoder.cpu()
return {"frechet_distance": frechet_distance, "perplexity": perplexity} return {"frechet_distance": frechet_distance, "log_perplexity": perplexity}
if __name__ == '__main__': if __name__ == '__main__':