Fix cat eval hack

This commit is contained in:
James Betker 2021-06-09 17:05:11 -06:00
parent 9b5f4abb91
commit aea12e1b9c

View File

@ -56,12 +56,11 @@ class CategorizationLossEvaluator(evaluator.Evaluator):
hq, labels = batch['hq'], batch['labels']
hq = hq.to(self.env['device'])
labels = labels.to(self.env['device'])
coarse_labels = batch['coarse_labels'].to(self.env['device']) # Hack, remove this in the future.
if self.masking:
masks = self.mask_producer(hq)
logits = self.model(hq, masks)
else:
logits = self.model(hq, coarse_labels)
logits = self.model(hq)
if not isinstance(logits, list) and not isinstance(logits, tuple):
logits = [logits]
logits = logits[self.gen_output_index]