This commit is contained in:
James Betker 2021-06-05 14:23:41 -06:00
parent af52751d6b
commit 16cd92acd5

View File

@ -56,11 +56,12 @@ class CategorizationLossEvaluator(evaluator.Evaluator):
hq, labels = batch['hq'], batch['labels']
hq = hq.to(self.env['device'])
labels = labels.to(self.env['device'])
coarse_labels = batch['coarse_labels'].to(self.env['device']) # Hack, remove this in the future.
if self.masking:
masks = self.mask_producer(hq)
logits = self.model(hq, masks)
else:
logits = self.model(hq)
logits = self.model(hq, coarse_labels)
if not isinstance(logits, list) and not isinstance(logits, tuple):
logits = [logits]
logits = logits[self.gen_output_index]