forked from mrq/DL-Art-School
Remove broken evaluator
This commit is contained in:
parent
46b97049dc
commit
f6a7f12cad
|
@ -1,92 +0,0 @@
|
||||||
import torch
|
|
||||||
import torchvision
|
|
||||||
from torch.nn.functional import interpolate
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import trainer.eval.evaluator as evaluator
|
|
||||||
from data import create_dataset
|
|
||||||
from models.vqvae.kmeans_mask_producer import UResnetMaskProducer
|
|
||||||
from utils.util import opt_get
|
|
||||||
|
|
||||||
|
|
||||||
class CategorizationLossEvaluator(evaluator.Evaluator):
|
|
||||||
def __init__(self, model, opt_eval, env):
|
|
||||||
super().__init__(model, opt_eval, env, uses_all_ddp=False)
|
|
||||||
self.batch_sz = opt_eval['batch_size']
|
|
||||||
assert self.batch_sz is not None
|
|
||||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
|
||||||
std=[0.229, 0.224, 0.225])
|
|
||||||
self.dataset = create_dataset(opt_eval['dataset'])
|
|
||||||
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4)
|
|
||||||
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
|
||||||
self.masking = opt_get(opt_eval, ['masking'], False)
|
|
||||||
if self.masking:
|
|
||||||
self.mask_producer = UResnetMaskProducer(pretrained_uresnet_path= '../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth',
|
|
||||||
kmeans_centroid_path='../experiments/k_means_uresnet_imagenet_256.pth',
|
|
||||||
mask_scales=[.03125, .0625, .125, .25, .5, 1.0],
|
|
||||||
tail_dim=256).to('cuda')
|
|
||||||
|
|
||||||
def accuracy(self, output, target, topk=(1,)):
|
|
||||||
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
|
||||||
with torch.no_grad():
|
|
||||||
maxk = max(topk)
|
|
||||||
batch_size = target.size(0)
|
|
||||||
|
|
||||||
_, pred = output.topk(maxk, 1, True, True)
|
|
||||||
pred = pred.t()
|
|
||||||
correct = pred.eq(target[None])
|
|
||||||
|
|
||||||
res = []
|
|
||||||
for k in topk:
|
|
||||||
correct_k = correct[:k].flatten().sum(dtype=torch.float32)
|
|
||||||
res.append(correct_k * (100.0 / batch_size))
|
|
||||||
return res
|
|
||||||
|
|
||||||
def perform_eval(self):
|
|
||||||
counter = 0.0
|
|
||||||
ce_loss = 0.0
|
|
||||||
top_5_acc = 0.0
|
|
||||||
top_1_acc = 0.0
|
|
||||||
|
|
||||||
self.model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
for batch in tqdm(self.dataloader):
|
|
||||||
hq, labels = batch['hq'], batch['labels']
|
|
||||||
hq = hq.to(self.env['device'])
|
|
||||||
labels = labels.to(self.env['device'])
|
|
||||||
if self.masking:
|
|
||||||
masks = self.mask_producer(hq)
|
|
||||||
logits = self.model(hq, masks)
|
|
||||||
else:
|
|
||||||
logits = self.model(hq)
|
|
||||||
if not isinstance(logits, list) and not isinstance(logits, tuple):
|
|
||||||
logits = [logits]
|
|
||||||
logits = logits[self.gen_output_index]
|
|
||||||
ce_loss += torch.nn.functional.cross_entropy(logits, labels).detach()
|
|
||||||
t1, t5 = self.accuracy(logits, labels, (1, 5))
|
|
||||||
top_1_acc += t1.detach()
|
|
||||||
top_5_acc += t5.detach()
|
|
||||||
counter += len(hq) / self.batch_sz
|
|
||||||
self.model.train()
|
|
||||||
|
|
||||||
return {"val_cross_entropy": ce_loss / counter,
|
|
||||||
"top_5_accuracy": top_5_acc / counter,
|
|
||||||
"top_1_accuracy": top_1_acc / counter }
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
from torchvision.models import resnet50
|
|
||||||
model = resnet50(pretrained=True).to('cuda')
|
|
||||||
opt = {
|
|
||||||
'batch_size': 128,
|
|
||||||
'gen_index': 0,
|
|
||||||
'masking': False
|
|
||||||
}
|
|
||||||
env = {
|
|
||||||
'device': 'cuda',
|
|
||||||
|
|
||||||
}
|
|
||||||
eval = CategorizationLossEvaluator(model, opt, env)
|
|
||||||
print(eval.perform_eval())
|
|
Loading…
Reference in New Issue
Block a user