diff --git a/codes/trainer/eval/categorization_loss_eval.py b/codes/trainer/eval/categorization_loss_eval.py index fcd009e3..cca6eb1d 100644 --- a/codes/trainer/eval/categorization_loss_eval.py +++ b/codes/trainer/eval/categorization_loss_eval.py @@ -56,11 +56,12 @@ class CategorizationLossEvaluator(evaluator.Evaluator): hq, labels = batch['hq'], batch['labels'] hq = hq.to(self.env['device']) labels = labels.to(self.env['device']) + coarse_labels = batch['coarse_labels'].to(self.env['device']) # Hack, remove this in the future. if self.masking: masks = self.mask_producer(hq) logits = self.model(hq, masks) else: - logits = self.model(hq) + logits = self.model(hq, coarse_labels) if not isinstance(logits, list) and not isinstance(logits, tuple): logits = [logits] logits = logits[self.gen_output_index]