forked from mrq/DL-Art-School
Fixes and mods to support training classifiers on imagenet
This commit is contained in:
parent
f129eaa39e
commit
45bc76ba92
|
@ -4,6 +4,9 @@ import torchvision.transforms as T
|
|||
from torchvision import datasets
|
||||
|
||||
# Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer.
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
class TorchDataset(Dataset):
|
||||
def __init__(self, opt):
|
||||
DATASET_MAP = {
|
||||
|
@ -14,7 +17,7 @@ class TorchDataset(Dataset):
|
|||
"imagefolder": datasets.ImageFolder
|
||||
}
|
||||
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
if opt['train']:
|
||||
if opt_get(opt, ['random_crop'], False):
|
||||
transforms = [
|
||||
T.RandomResizedCrop(opt['image_size']),
|
||||
T.RandomHorizontalFlip(),
|
||||
|
@ -23,8 +26,9 @@ class TorchDataset(Dataset):
|
|||
]
|
||||
else:
|
||||
transforms = [
|
||||
T.Resize(opt['val_resize']),
|
||||
T.Resize(opt['image_size']),
|
||||
T.CenterCrop(opt['image_size']),
|
||||
T.RandomHorizontalFlip(),
|
||||
T.ToTensor(),
|
||||
normalize,
|
||||
]
|
||||
|
|
|
@ -35,6 +35,7 @@ class ZipFileDataset(torch.utils.data.Dataset):
|
|||
return tensor
|
||||
|
||||
def __getitem__(self, i):
|
||||
try:
|
||||
fname = self.all_files[i]
|
||||
out = {
|
||||
'hq': self.load_image(fname),
|
||||
|
@ -47,6 +48,9 @@ class ZipFileDataset(torch.utils.data.Dataset):
|
|||
else:
|
||||
aname = fname.replace('1.jpg', '0.jpg')
|
||||
out['alt_hq'] = self.load_image(aname)
|
||||
except:
|
||||
print(f"Error loading {fname} from zipfile. Attempting to recover by loading next element.")
|
||||
return self[i+1]
|
||||
return out
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -14,17 +14,19 @@ def main():
|
|||
split_img = False
|
||||
opt = {}
|
||||
opt['n_thread'] = 5
|
||||
opt['compression_level'] = 95 # JPEG compression quality rating.
|
||||
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
|
||||
# compression time. If read raw images during training, use 0 for faster IO speed.
|
||||
|
||||
opt['dest'] = 'file'
|
||||
opt['input_folder'] = ['E:\\4k6k\\datasets\\images\\lsun\\lsun\\cats']
|
||||
opt['save_folder'] = 'E:\\4k6k\\datasets\\images\\lsun\\lsun\\cats\\256_4_by_3'
|
||||
opt['imgsize'] = (256,192)
|
||||
opt['input_folder'] = ['E:\\4k6k\datasets\\ns_images\\imagesets\\imageset_256_masked']
|
||||
opt['save_folder'] = 'E:\\4k6k\datasets\\ns_images\\imagesets\\imageset_128_masked'
|
||||
opt['imgsize'] = (128,128)
|
||||
opt['bottom_crop'] = 0
|
||||
opt['keep_folder'] = False
|
||||
|
||||
#opt['format'] = 'jpg'
|
||||
#opt['cv2_write_options'] = [cv2.IMWRITE_JPEG_QUALITY, 95]
|
||||
opt['format'] = 'png'
|
||||
opt['cv2_write_options'] = [cv2.IMWRITE_PNG_COMPRESSION, 9]
|
||||
|
||||
save_folder = opt['save_folder']
|
||||
if not osp.exists(save_folder):
|
||||
os.makedirs(save_folder)
|
||||
|
@ -93,7 +95,7 @@ class TiledDataset(data.Dataset):
|
|||
pts = os.path.split(pts[0])
|
||||
output_folder = osp.join(self.opt['save_folder'], pts[-1])
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
cv2.imwrite(osp.join(output_folder, basename.replace('.webp', '.jpg')), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
|
||||
cv2.imwrite(osp.join(output_folder, basename.replace('.webp', self.opt['format'])), img, self.opt['cv2_write_options'])
|
||||
return None
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
@ -295,7 +295,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_imagenet_yt.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imagenet_resnet50_yt_pretrained.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -371,7 +371,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
if load_path is not None:
|
||||
if self.rank <= 0:
|
||||
logger.info('Loading model for [%s]' % (load_path,))
|
||||
self.load_network(load_path, net, self.opt['path']['strict_load'])
|
||||
self.load_network(load_path, net, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
|
||||
if hasattr(net.module, 'network_loaded'):
|
||||
net.module.network_loaded()
|
||||
|
||||
|
|
|
@ -97,7 +97,7 @@ class BaseModel():
|
|||
save_path, os.path.join(self.opt['remote_path'], 'models', save_filename))
|
||||
return save_path
|
||||
|
||||
def load_network(self, load_path, network, strict=True):
|
||||
def load_network(self, load_path, network, strict=True, pretrain_base_path=None):
|
||||
#if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
||||
network = network.module
|
||||
load_net = torch.load(load_path)
|
||||
|
@ -105,9 +105,15 @@ class BaseModel():
|
|||
# Support loading torch.save()s for whole models as well as just state_dicts.
|
||||
if 'state_dict' in load_net:
|
||||
load_net = load_net['state_dict']
|
||||
|
||||
is_srflow = False
|
||||
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
||||
|
||||
if pretrain_base_path is not None:
|
||||
t = load_net
|
||||
load_net = {}
|
||||
for k, v in t.items():
|
||||
if k.startswith(pretrain_base_path):
|
||||
load_net[k[len(pretrain_base_path):]] = v
|
||||
|
||||
for k, v in load_net.items():
|
||||
if k.startswith('module.'):
|
||||
load_net_clean[k.replace('module.', '')] = v
|
||||
|
|
|
@ -18,7 +18,7 @@ class CategorizationLossEvaluator(evaluator.Evaluator):
|
|||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
self.dataset = torchvision.datasets.ImageFolder(
|
||||
'F:\\4k6k\\datasets\\images\\imagenet_2017\\val',
|
||||
'E:\\4k6k\\datasets\\images\\imagenet_2017\\val',
|
||||
transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
|
@ -27,7 +27,7 @@ class CategorizationLossEvaluator(evaluator.Evaluator):
|
|||
]))
|
||||
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4)
|
||||
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
||||
self.masking = opt_get(opt_eval, ['masking'], True)
|
||||
self.masking = opt_get(opt_eval, ['masking'], False)
|
||||
if self.masking:
|
||||
self.mask_producer = UResnetMaskProducer(pretrained_uresnet_path= '../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth',
|
||||
kmeans_centroid_path='../experiments/k_means_uresnet_imagenet_256.pth',
|
||||
|
|
Loading…
Reference in New Issue
Block a user