From 16cd92acd59a78829765b06f5a3595c92aaac1e2 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 5 Jun 2021 14:23:41 -0600 Subject: [PATCH] hack --- codes/trainer/eval/categorization_loss_eval.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/codes/trainer/eval/categorization_loss_eval.py b/codes/trainer/eval/categorization_loss_eval.py index fcd009e3..cca6eb1d 100644 --- a/codes/trainer/eval/categorization_loss_eval.py +++ b/codes/trainer/eval/categorization_loss_eval.py @@ -56,11 +56,12 @@ class CategorizationLossEvaluator(evaluator.Evaluator): hq, labels = batch['hq'], batch['labels'] hq = hq.to(self.env['device']) labels = labels.to(self.env['device']) + coarse_labels = batch['coarse_labels'].to(self.env['device']) # Hack, remove this in the future. if self.masking: masks = self.mask_producer(hq) logits = self.model(hq, masks) else: - logits = self.model(hq) + logits = self.model(hq, coarse_labels) if not isinstance(logits, list) and not isinstance(logits, tuple): logits = [logits] logits = logits[self.gen_output_index]