forked from mrq/DL-Art-School
Fix cat eval hack
This commit is contained in:
parent
9b5f4abb91
commit
aea12e1b9c
|
@ -56,12 +56,11 @@ class CategorizationLossEvaluator(evaluator.Evaluator):
|
|||
hq, labels = batch['hq'], batch['labels']
|
||||
hq = hq.to(self.env['device'])
|
||||
labels = labels.to(self.env['device'])
|
||||
coarse_labels = batch['coarse_labels'].to(self.env['device']) # Hack, remove this in the future.
|
||||
if self.masking:
|
||||
masks = self.mask_producer(hq)
|
||||
logits = self.model(hq, masks)
|
||||
else:
|
||||
logits = self.model(hq, coarse_labels)
|
||||
logits = self.model(hq)
|
||||
if not isinstance(logits, list) and not isinstance(logits, tuple):
|
||||
logits = [logits]
|
||||
logits = logits[self.gen_output_index]
|
||||
|
|
Loading…
Reference in New Issue
Block a user