DL-Art-School/codes/scripts/cifar100_untangle.py
2021-06-07 11:33:33 -06:00

42 lines
1.3 KiB
Python

import itertools

import numpy
import torch
from torch.utils.data import DataLoader

from data.torch_dataset import TorchDataset
from models.classifiers.cifar_resnet_branched import ResNet
from models.classifiers.cifar_resnet_branched import BasicBlock
def collect_selector_bins(model, loader, num_bins=8, max_batches=12):
    """Group sample labels by the selector positions the model switches on.

    Runs ``model`` over at most ``max_batches`` batches taken from ``loader``
    with autograd disabled. For sample ``k`` of a batch and selector position
    ``j``, the sample's label is appended to ``bins[j]`` whenever
    ``selector[k][j]`` is truthy.

    Args:
        model: Callable invoked as
            ``model(batch['hq'], coarse_label=None, return_selector=True)``;
            it must return a ``(logits, selector)`` pair. ``selector`` is
            iterated per sample, then per position.
        loader: Iterable of batches. Each batch is indexed with ``'hq'``
            (model input) and ``'labels'`` (indexed per sample; each entry
            must provide ``.item()``).
        num_bins: Number of bins to allocate. A truthy selector entry at a
            position ``>= num_bins`` raises ``IndexError``.
        max_batches: Maximum number of batches to consume, or ``None`` to
            consume the whole loader.

    Returns:
        A list of ``num_bins`` lists holding the collected label values.
    """
    bins = [[] for _ in range(num_bins)]
    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        # islice stops after max_batches without pulling an extra batch.
        for batch in itertools.islice(loader, max_batches):
            _, selector = model(batch['hq'], coarse_label=None, return_selector=True)
            labels = batch['labels']
            for sample_idx, sample_selector in enumerate(selector):
                for bin_idx, active in enumerate(sample_selector):
                    if active:
                        bins[bin_idx].append(labels[sample_idx].item())
    return bins


if __name__ == '__main__':
    # Options handed to TorchDataset to build the CIFAR-100 dataset.
    dopt = {
        'flip': True,
        'crop_sz': None,
        'dataset': 'cifar100',
        'image_size': 32,
        'normalize': False,
        'kwargs': {
            'root': 'E:\\4k6k\\datasets\\images\\cifar100',
            'download': True
        }
    }
    dataset = TorchDataset(dopt)
    loader = DataLoader(dataset, num_workers=0, batch_size=32)

    model = ResNet(BasicBlock, [2, 2, 2, 2])
    # The model is never moved off the CPU in this script, so map the
    # checkpoint tensors to CPU; this also lets a checkpoint saved from a GPU
    # load on a machine without CUDA.
    state_dict = torch.load('C:\\Users\\jbetk\\Downloads\\cifar_hardsw_85000.pth',
                            map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    # Sample the first 12 batches and bucket their labels into 8 bins.
    bins = collect_selector_bins(model, loader, num_bins=8, max_batches=12)

    # Imported here so matplotlib is only required when the script is run.
    import matplotlib.pyplot as plt
    # 3x3 grid filled column by column; the ninth axes stays empty.
    _, axs = plt.subplots(3, 3)
    for i, bin_labels in enumerate(bins):
        axs[i % 3, i // 3].hist(numpy.asarray(bin_labels))
    plt.show()
    print('hi')