Fix cat eval hack

This commit is contained in:
James Betker 2021-06-09 17:05:11 -06:00
parent 9b5f4abb91
commit aea12e1b9c

View File

@ -56,12 +56,11 @@ class CategorizationLossEvaluator(evaluator.Evaluator):
hq, labels = batch['hq'], batch['labels'] hq, labels = batch['hq'], batch['labels']
hq = hq.to(self.env['device']) hq = hq.to(self.env['device'])
labels = labels.to(self.env['device']) labels = labels.to(self.env['device'])
coarse_labels = batch['coarse_labels'].to(self.env['device']) # Hack, remove this in the future.
if self.masking: if self.masking:
masks = self.mask_producer(hq) masks = self.mask_producer(hq)
logits = self.model(hq, masks) logits = self.model(hq, masks)
else: else:
logits = self.model(hq, coarse_labels) logits = self.model(hq)
if not isinstance(logits, list) and not isinstance(logits, tuple): if not isinstance(logits, list) and not isinstance(logits, tuple):
logits = [logits] logits = [logits]
logits = logits[self.gen_output_index] logits = logits[self.gen_output_index]