From aea12e1b9c7db3888990232ffd359a8b9a45138a Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 9 Jun 2021 17:05:11 -0600 Subject: [PATCH] Fix cat eval hack --- codes/trainer/eval/categorization_loss_eval.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/codes/trainer/eval/categorization_loss_eval.py b/codes/trainer/eval/categorization_loss_eval.py index cca6eb1d..fcd009e3 100644 --- a/codes/trainer/eval/categorization_loss_eval.py +++ b/codes/trainer/eval/categorization_loss_eval.py @@ -56,12 +56,11 @@ class CategorizationLossEvaluator(evaluator.Evaluator): hq, labels = batch['hq'], batch['labels'] hq = hq.to(self.env['device']) labels = labels.to(self.env['device']) - coarse_labels = batch['coarse_labels'].to(self.env['device']) # Hack, remove this in the future. if self.masking: masks = self.mask_producer(hq) logits = self.model(hq, masks) else: - logits = self.model(hq, coarse_labels) + logits = self.model(hq) if not isinstance(logits, list) and not isinstance(logits, tuple): logits = [logits] logits = logits[self.gen_output_index]