diff --git a/codes/trainer/eval/categorization_loss_eval.py b/codes/trainer/eval/categorization_loss_eval.py index cca6eb1d..fcd009e3 100644 --- a/codes/trainer/eval/categorization_loss_eval.py +++ b/codes/trainer/eval/categorization_loss_eval.py @@ -56,12 +56,11 @@ class CategorizationLossEvaluator(evaluator.Evaluator): hq, labels = batch['hq'], batch['labels'] hq = hq.to(self.env['device']) labels = labels.to(self.env['device']) - coarse_labels = batch['coarse_labels'].to(self.env['device']) # Hack, remove this in the future. if self.masking: masks = self.mask_producer(hq) logits = self.model(hq, masks) else: - logits = self.model(hq, coarse_labels) + logits = self.model(hq) if not isinstance(logits, list) and not isinstance(logits, tuple): logits = [logits] logits = logits[self.gen_output_index]