# DL-Art-School/codes/trainer/eval/categorization_loss_eval.py
# 2021-06-14 09:51:44 -06:00
#
# 93 lines
# 3.6 KiB
# Python

import torch
import torchvision
from torch.nn.functional import interpolate
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import trainer.eval.evaluator as evaluator
from data import create_dataset
from models.vqvae.kmeans_mask_producer import UResnetMaskProducer
from utils.util import opt_get
class CategorizationLossEvaluator(evaluator.Evaluator):
    """Evaluator reporting cross-entropy loss and top-1/top-5 accuracy.

    The evaluation data comes from ``create_dataset(opt_eval['dataset'])`` and
    is iterated in order (no shuffling) in batches of ``opt_eval['batch_size']``.
    Each batch is expected to be a mapping with an ``'hq'`` entry (model input)
    and a ``'labels'`` entry (cross-entropy targets).

    Recognised ``opt_eval`` keys:
        batch_size: required, number of samples per evaluation batch.
        dataset: required, configuration handed to ``create_dataset``.
        gen_index: optional (default 0); which element of the model output to
            score when the model returns a list or tuple.
        masking: optional (default False); when truthy, a ``UResnetMaskProducer``
            is built and its output is passed to the model as a second argument.
    """

    def __init__(self, model, opt_eval, env):
        """Build the dataset, dataloader and (optionally) the mask producer.

        Args:
            model: network to evaluate; forwarded to the base ``Evaluator``.
            opt_eval: evaluator options, see the class docstring.
            env: environment mapping; ``env['device']`` is the device inputs
                (and the mask producer, if enabled) are placed on.
        """
        super().__init__(model, opt_eval, env, uses_all_ddp=False)
        self.batch_sz = opt_eval['batch_size']
        assert self.batch_sz is not None
        self.dataset = create_dataset(opt_eval['dataset'])
        self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4)
        self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval else 0
        self.masking = opt_get(opt_eval, ['masking'], False)
        if self.masking:
            # Keep the mask producer on the same device the inputs are moved to
            # in perform_eval(); otherwise the two would mismatch whenever
            # env['device'] is not the default CUDA device.
            self.mask_producer = UResnetMaskProducer(
                pretrained_uresnet_path='../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth',
                kmeans_centroid_path='../experiments/k_means_uresnet_imagenet_256.pth',
                mask_scales=[.03125, .0625, .125, .25, .5, 1.0],
                tail_dim=256,
            ).to(env['device'])

    def accuracy(self, output, target, topk=(1,)):
        """Computes the accuracy over the k top predictions for the specified values of k.

        Args:
            output: 2-D score tensor with one row per sample and one column
                per class.
            target: 1-D tensor of class indices, one per sample.
            topk: iterable of ``k`` values to report.

        Returns:
            A list with one scalar float tensor per entry of ``topk``: the
            percentage (0-100) of samples whose target is among the ``k``
            highest-scoring classes.
        """
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            # Indices of the maxk best classes per sample, transposed to
            # (maxk, batch) so that row r holds every sample's rank-r guess.
            _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
            pred = pred.t()
            correct = pred.eq(target[None])

            res = []
            for k in topk:
                correct_k = correct[:k].flatten().sum(dtype=torch.float32)
                res.append(correct_k * (100.0 / batch_size))
            return res

    def perform_eval(self):
        """Run the model over the whole dataloader and average the metrics.

        Every batch contributes proportionally to the number of samples it
        holds, so a short final batch does not skew the averages.

        The model is switched to eval mode for the duration of the pass and
        put back into train mode afterwards, even if evaluation raises.

        Returns:
            dict with keys ``"val_cross_entropy"``, ``"top_5_accuracy"`` and
            ``"top_1_accuracy"``; accuracies are percentages.

        Raises:
            RuntimeError: if the dataloader yields no samples.
        """
        device = self.env['device']
        total_samples = 0
        ce_loss_sum = 0.0
        top_1_sum = 0.0
        top_5_sum = 0.0

        self.model.eval()
        try:
            with torch.no_grad():
                for batch in tqdm(self.dataloader):
                    hq = batch['hq'].to(device)
                    labels = batch['labels'].to(device)

                    if self.masking:
                        masks = self.mask_producer(hq)
                        logits = self.model(hq, masks)
                    else:
                        logits = self.model(hq)
                    if not isinstance(logits, (list, tuple)):
                        logits = [logits]
                    logits = logits[self.gen_output_index]

                    # cross_entropy() and accuracy() both return per-batch
                    # means; scale by the batch's sample count so the final
                    # division yields a true per-sample average.
                    n_samples = len(hq)
                    batch_ce = torch.nn.functional.cross_entropy(logits, labels)
                    t1, t5 = self.accuracy(logits, labels, (1, 5))
                    ce_loss_sum += batch_ce * n_samples
                    top_1_sum += t1 * n_samples
                    top_5_sum += t5 * n_samples
                    total_samples += n_samples
        finally:
            self.model.train()

        if total_samples == 0:
            raise RuntimeError("CategorizationLossEvaluator: the evaluation dataloader produced no samples.")

        return {"val_cross_entropy": ce_loss_sum / total_samples,
                "top_5_accuracy": top_5_sum / total_samples,
                "top_1_accuracy": top_1_sum / total_samples}
if __name__ == '__main__':
    # Ad-hoc manual check: score a pretrained torchvision ResNet-50 with this
    # evaluator and print the resulting metrics dict.
    from torchvision.models import resnet50

    device = 'cuda'
    model = resnet50(pretrained=True).to(device)
    # NOTE(review): CategorizationLossEvaluator.__init__ reads
    # opt_eval['dataset'], which is not provided here, so this script raises
    # KeyError until a dataset configuration is added to `opt`.
    opt = {
        'batch_size': 128,
        'gen_index': 0,
        'masking': False,
    }
    env = {
        'device': device,
    }
    # Named to avoid shadowing the builtin `eval` and the imported
    # `evaluator` module.
    loss_evaluator = CategorizationLossEvaluator(model, opt, env)
    print(loss_evaluator.perform_eval())