asdf
This commit is contained in:
parent
e6c537824a
commit
6c8c8087d5
|
@ -6,6 +6,7 @@ from torchvision import transforms
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import trainer.eval.evaluator as evaluator
|
import trainer.eval.evaluator as evaluator
|
||||||
|
from data import create_dataset
|
||||||
from models.vqvae.kmeans_mask_producer import UResnetMaskProducer
|
from models.vqvae.kmeans_mask_producer import UResnetMaskProducer
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
@ -17,14 +18,7 @@ class CategorizationLossEvaluator(evaluator.Evaluator):
|
||||||
assert self.batch_sz is not None
|
assert self.batch_sz is not None
|
||||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||||
std=[0.229, 0.224, 0.225])
|
std=[0.229, 0.224, 0.225])
|
||||||
self.dataset = torchvision.datasets.ImageFolder(
|
self.dataset = create_dataset(opt_eval['dataset'])
|
||||||
'E:\\4k6k\\datasets\\images\\imagenet_2017\\val',
|
|
||||||
transforms.Compose([
|
|
||||||
transforms.Resize(256),
|
|
||||||
transforms.CenterCrop(224),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
normalize,
|
|
||||||
]))
|
|
||||||
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4)
|
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4)
|
||||||
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
||||||
self.masking = opt_get(opt_eval, ['masking'], False)
|
self.masking = opt_get(opt_eval, ['masking'], False)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user