# DL-Art-School/codes/trainer/eval/categorization_loss_eval.py
# (source listing: 93 lines, 3.6 KiB, Python)
# blame: 2021-01-12 03:09:16 +00:00
import torch
import torchvision
from torch.nn.functional import interpolate
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import trainer.eval.evaluator as evaluator
# blame: 2021-06-05 03:24:48 +00:00
from data import create_dataset
# blame: 2021-01-12 03:09:16 +00:00
from models.vqvae.kmeans_mask_producer import UResnetMaskProducer
from utils.util import opt_get
class CategorizationLossEvaluator(evaluator.Evaluator):
    """Evaluate a classifier's cross-entropy and top-1/top-5 accuracy on a dataset.

    Recognised ``opt_eval`` keys:
        batch_size (required): DataLoader batch size.
        dataset (required): options passed to ``create_dataset``.
        gen_index (default 0): which element to score when the model returns a
            list/tuple of outputs.
        masking (default False): if truthy, masks from a ``UResnetMaskProducer``
            are passed to the model as a second positional argument.
        num_workers (default 4): DataLoader worker count.
        mask_uresnet_path / mask_centroid_path: optional overrides for the mask
            producer's checkpoint paths (defaults are the historical paths).
    """

    def __init__(self, model, opt_eval, env):
        super().__init__(model, opt_eval, env)
        self.batch_sz = opt_eval['batch_size']
        if self.batch_sz is None:
            raise ValueError("opt_eval['batch_size'] must be set")
        self.dataset = create_dataset(opt_eval['dataset'])
        # Deterministic order (no shuffling) for evaluation.
        self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False,
                                     num_workers=opt_get(opt_eval, ['num_workers'], 4))
        self.gen_output_index = opt_eval.get('gen_index', 0)
        self.masking = opt_get(opt_eval, ['masking'], False)
        if self.masking:
            self.mask_producer = UResnetMaskProducer(
                pretrained_uresnet_path=opt_get(
                    opt_eval, ['mask_uresnet_path'],
                    '../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth'),
                kmeans_centroid_path=opt_get(
                    opt_eval, ['mask_centroid_path'],
                    '../experiments/k_means_uresnet_imagenet_256.pth'),
                mask_scales=[.03125, .0625, .125, .25, .5, 1.0],
                # Same device the inputs are moved to in perform_eval.
                tail_dim=256).to(env['device'])

    def accuracy(self, output, target, topk=(1,)):
        """Compute top-k accuracy, in percent, for each k in ``topk``.

        Args:
            output: scores whose class dimension is dim 1 (``topk`` is taken
                over dim 1). Presumably logits of shape (batch, classes) — confirm.
            target: class indices compared element-wise against the predictions;
                its first dimension is treated as the batch size.
            topk: the k values to report.

        Returns:
            A list with one 0-d float32 tensor per k: the percentage of samples
            whose target is among the k highest-scoring classes.
        """
        with torch.no_grad():
            # Tensor.topk raises if k exceeds the class count; clamping makes any
            # such k count every class as a guess (i.e. 100%).
            maxk = min(max(topk), output.size(1))
            batch_size = target.size(0)
            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()  # (maxk, batch): row i holds every sample's i-th guess.
            correct = pred.eq(target[None])
            res = []
            for k in topk:
                correct_k = correct[:k].flatten().sum(dtype=torch.float32)
                res.append(correct_k * (100.0 / batch_size))
            return res

    def perform_eval(self):
        """Run the model over the whole evaluation set and average the metrics.

        Metrics are averaged per sample, so a short final batch carries only the
        weight of the samples it contains.

        Returns:
            dict with ``"val_cross_entropy"``, ``"top_5_accuracy"`` and
            ``"top_1_accuracy"`` (accuracies in percent), as 0-d tensors.

        Raises:
            RuntimeError: if the dataloader yields no samples.
        """
        device = self.env['device']
        num_samples = 0
        ce_loss_sum = 0.0
        top_1_sum = 0.0
        top_5_sum = 0.0
        self.model.eval()
        try:
            with torch.no_grad():
                for batch in tqdm(self.dataloader):
                    hq = batch['hq'].to(device)
                    labels = batch['labels'].to(device)
                    if self.masking:
                        logits = self.model(hq, self.mask_producer(hq))
                    else:
                        logits = self.model(hq)
                    # Some models return several outputs; pick the configured one.
                    if not isinstance(logits, (list, tuple)):
                        logits = [logits]
                    logits = logits[self.gen_output_index]
                    n = len(hq)
                    # Accumulate per-sample sums (not per-batch means) so that the
                    # final partial batch is not over-weighted in the average.
                    ce_loss_sum += torch.nn.functional.cross_entropy(logits, labels, reduction='sum')
                    t1, t5 = self.accuracy(logits, labels, (1, 5))
                    top_1_sum += t1 * n
                    top_5_sum += t5 * n
                    num_samples += n
        finally:
            self.model.train()
        if num_samples == 0:
            raise RuntimeError("CategorizationLossEvaluator: evaluation dataset is empty")
        return {"val_cross_entropy": ce_loss_sum / num_samples,
                "top_5_accuracy": top_5_sum / num_samples,
                "top_1_accuracy": top_1_sum / num_samples}
if __name__ == '__main__':
    # Ad-hoc manual run: evaluate an ImageNet-pretrained torchvision ResNet-50.
    from torchvision.models import resnet50

    env = {
        'device': 'cuda',
    }
    model = resnet50(pretrained=True).to(env['device'])
    opt = {
        'batch_size': 128,
        'gen_index': 0,
        'masking': False,
        # NOTE(review): CategorizationLossEvaluator.__init__ reads
        # opt_eval['dataset'] and passes it to create_dataset(); without that
        # key construction raises KeyError. Supply dataset options before running.
    }
    # Named so it shadows neither the builtin `eval` nor the `evaluator` module.
    categorization_eval = CategorizationLossEvaluator(model, opt, env)
    print(categorization_eval.perform_eval())