From 45bc76ba92d94f616f4cfca3ba7cf1406ea16ba7 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 1 Jun 2021 17:25:24 -0600 Subject: [PATCH] Fixes and mods to support training classifiers on imagenet --- codes/data/torch_dataset.py | 8 ++++-- codes/data/zip_file_dataset.py | 28 +++++++++++-------- codes/scripts/extract_square_images.py | 16 ++++++----- codes/train.py | 2 +- codes/trainer/ExtensibleTrainer.py | 2 +- codes/trainer/base_model.py | 12 ++++++-- .../trainer/eval/categorization_loss_eval.py | 4 +-- 7 files changed, 44 insertions(+), 28 deletions(-) diff --git a/codes/data/torch_dataset.py b/codes/data/torch_dataset.py index 7015eef9..920a1392 100644 --- a/codes/data/torch_dataset.py +++ b/codes/data/torch_dataset.py @@ -4,6 +4,9 @@ import torchvision.transforms as T from torchvision import datasets # Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer. +from utils.util import opt_get + + class TorchDataset(Dataset): def __init__(self, opt): DATASET_MAP = { @@ -14,7 +17,7 @@ class TorchDataset(Dataset): "imagefolder": datasets.ImageFolder } normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - if opt['train']: + if opt_get(opt, ['random_crop'], False): transforms = [ T.RandomResizedCrop(opt['image_size']), T.RandomHorizontalFlip(), @@ -23,8 +26,9 @@ class TorchDataset(Dataset): ] else: transforms = [ - T.Resize(opt['val_resize']), + T.Resize(opt['image_size']), T.CenterCrop(opt['image_size']), + T.RandomHorizontalFlip(), T.ToTensor(), normalize, ] diff --git a/codes/data/zip_file_dataset.py b/codes/data/zip_file_dataset.py index 513391c3..ef9ae50a 100644 --- a/codes/data/zip_file_dataset.py +++ b/codes/data/zip_file_dataset.py @@ -35,18 +35,22 @@ class ZipFileDataset(torch.utils.data.Dataset): return tensor def __getitem__(self, i): - fname = self.all_files[i] - out = { - 'hq': self.load_image(fname), - 'HQ_path': fname, - 'has_alt': self.paired_mode - } - if self.paired_mode: - if fname.endswith('0.jpg'): - aname = fname.replace('0.jpg', '1.jpg') - else: - aname = fname.replace('1.jpg', '0.jpg') - out['alt_hq'] = self.load_image(aname) + try: + fname = self.all_files[i] + out = { + 'hq': self.load_image(fname), + 'HQ_path': fname, + 'has_alt': self.paired_mode + } + if self.paired_mode: + if fname.endswith('0.jpg'): + aname = fname.replace('0.jpg', '1.jpg') + else: + aname = fname.replace('1.jpg', '0.jpg') + out['alt_hq'] = self.load_image(aname) + except: + print(f"Error loading {fname} from zipfile. Attempting to recover by loading next element.") + return self[i+1] return out if __name__ == '__main__': diff --git a/codes/scripts/extract_square_images.py b/codes/scripts/extract_square_images.py index 0eeae7ed..a8f9eff8 100644 --- a/codes/scripts/extract_square_images.py +++ b/codes/scripts/extract_square_images.py @@ -14,17 +14,19 @@ def main(): split_img = False opt = {} opt['n_thread'] = 5 - opt['compression_level'] = 95 # JPEG compression quality rating. - # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer - # compression time. If read raw images during training, use 0 for faster IO speed. opt['dest'] = 'file' - opt['input_folder'] = ['E:\\4k6k\\datasets\\images\\lsun\\lsun\\cats'] - opt['save_folder'] = 'E:\\4k6k\\datasets\\images\\lsun\\lsun\\cats\\256_4_by_3' - opt['imgsize'] = (256,192) + opt['input_folder'] = ['E:\\4k6k\datasets\\ns_images\\imagesets\\imageset_256_masked'] + opt['save_folder'] = 'E:\\4k6k\datasets\\ns_images\\imagesets\\imageset_128_masked' + opt['imgsize'] = (128,128) opt['bottom_crop'] = 0 opt['keep_folder'] = False + #opt['format'] = 'jpg' + #opt['cv2_write_options'] = [cv2.IMWRITE_JPEG_QUALITY, 95] + opt['format'] = 'png' + opt['cv2_write_options'] = [cv2.IMWRITE_PNG_COMPRESSION, 9] + save_folder = opt['save_folder'] if not osp.exists(save_folder): os.makedirs(save_folder) @@ -93,7 +95,7 @@ class TiledDataset(data.Dataset): pts = os.path.split(pts[0]) output_folder = osp.join(self.opt['save_folder'], pts[-1]) os.makedirs(output_folder, exist_ok=True) - cv2.imwrite(osp.join(output_folder, basename.replace('.webp', '.jpg')), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']]) + cv2.imwrite(osp.join(output_folder, basename.replace('.webp', self.opt['format'])), img, self.opt['cv2_write_options']) return None def __len__(self): diff --git a/codes/train.py b/codes/train.py index d61576ec..d55728d6 100644 --- a/codes/train.py +++ b/codes/train.py @@ -295,7 +295,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_imagenet_yt.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imagenet_resnet50_yt_pretrained.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 946258e9..bc79e840 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -371,7 +371,7 @@ class ExtensibleTrainer(BaseModel): if load_path is not None: if self.rank <= 0: logger.info('Loading model for [%s]' % (load_path,)) - self.load_network(load_path, net, self.opt['path']['strict_load']) + self.load_network(load_path, net, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}'])) if hasattr(net.module, 'network_loaded'): net.module.network_loaded() diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py index 4d0a6ede..6a1f3ceb 100644 --- a/codes/trainer/base_model.py +++ b/codes/trainer/base_model.py @@ -97,7 +97,7 @@ class BaseModel(): save_path, os.path.join(self.opt['remote_path'], 'models', save_filename)) return save_path - def load_network(self, load_path, network, strict=True): + def load_network(self, load_path, network, strict=True, pretrain_base_path=None): #if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): network = network.module load_net = torch.load(load_path) @@ -105,9 +105,15 @@ class BaseModel(): # Support loading torch.save()s for whole models as well as just state_dicts. if 'state_dict' in load_net: load_net = load_net['state_dict'] - - is_srflow = False load_net_clean = OrderedDict() # remove unnecessary 'module.' + + if pretrain_base_path is not None: + t = load_net + load_net = {} + for k, v in t.items(): + if k.startswith(pretrain_base_path): + load_net[k[len(pretrain_base_path):]] = v + for k, v in load_net.items(): if k.startswith('module.'): load_net_clean[k.replace('module.', '')] = v diff --git a/codes/trainer/eval/categorization_loss_eval.py b/codes/trainer/eval/categorization_loss_eval.py index 57fb33ea..6f3bb6d0 100644 --- a/codes/trainer/eval/categorization_loss_eval.py +++ b/codes/trainer/eval/categorization_loss_eval.py @@ -18,7 +18,7 @@ class CategorizationLossEvaluator(evaluator.Evaluator): normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) self.dataset = torchvision.datasets.ImageFolder( - 'F:\\4k6k\\datasets\\images\\imagenet_2017\\val', + 'E:\\4k6k\\datasets\\images\\imagenet_2017\\val', transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), @@ -27,7 +27,7 @@ class CategorizationLossEvaluator(evaluator.Evaluator): ])) self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4) self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0 - self.masking = opt_get(opt_eval, ['masking'], True) + self.masking = opt_get(opt_eval, ['masking'], False) if self.masking: self.mask_producer = UResnetMaskProducer(pretrained_uresnet_path= '../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth', kmeans_centroid_path='../experiments/k_means_uresnet_imagenet_256.pth',