Fixes and mods to support training classifiers on imagenet

This commit is contained in:
James Betker 2021-06-01 17:25:24 -06:00
parent f129eaa39e
commit 45bc76ba92
7 changed files with 44 additions and 28 deletions

View File

@ -4,6 +4,9 @@ import torchvision.transforms as T
from torchvision import datasets from torchvision import datasets
# Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer. # Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer.
from utils.util import opt_get
class TorchDataset(Dataset): class TorchDataset(Dataset):
def __init__(self, opt): def __init__(self, opt):
DATASET_MAP = { DATASET_MAP = {
@ -14,7 +17,7 @@ class TorchDataset(Dataset):
"imagefolder": datasets.ImageFolder "imagefolder": datasets.ImageFolder
} }
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
if opt['train']: if opt_get(opt, ['random_crop'], False):
transforms = [ transforms = [
T.RandomResizedCrop(opt['image_size']), T.RandomResizedCrop(opt['image_size']),
T.RandomHorizontalFlip(), T.RandomHorizontalFlip(),
@ -23,8 +26,9 @@ class TorchDataset(Dataset):
] ]
else: else:
transforms = [ transforms = [
T.Resize(opt['val_resize']), T.Resize(opt['image_size']),
T.CenterCrop(opt['image_size']), T.CenterCrop(opt['image_size']),
T.RandomHorizontalFlip(),
T.ToTensor(), T.ToTensor(),
normalize, normalize,
] ]

View File

@ -35,18 +35,22 @@ class ZipFileDataset(torch.utils.data.Dataset):
return tensor return tensor
def __getitem__(self, i): def __getitem__(self, i):
fname = self.all_files[i] try:
out = { fname = self.all_files[i]
'hq': self.load_image(fname), out = {
'HQ_path': fname, 'hq': self.load_image(fname),
'has_alt': self.paired_mode 'HQ_path': fname,
} 'has_alt': self.paired_mode
if self.paired_mode: }
if fname.endswith('0.jpg'): if self.paired_mode:
aname = fname.replace('0.jpg', '1.jpg') if fname.endswith('0.jpg'):
else: aname = fname.replace('0.jpg', '1.jpg')
aname = fname.replace('1.jpg', '0.jpg') else:
out['alt_hq'] = self.load_image(aname) aname = fname.replace('1.jpg', '0.jpg')
out['alt_hq'] = self.load_image(aname)
except:
print(f"Error loading {fname} from zipfile. Attempting to recover by loading next element.")
return self[i+1]
return out return out
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -14,17 +14,19 @@ def main():
split_img = False split_img = False
opt = {} opt = {}
opt['n_thread'] = 5 opt['n_thread'] = 5
opt['compression_level'] = 95 # JPEG compression quality rating.
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If read raw images during training, use 0 for faster IO speed.
opt['dest'] = 'file' opt['dest'] = 'file'
opt['input_folder'] = ['E:\\4k6k\\datasets\\images\\lsun\\lsun\\cats'] opt['input_folder'] = ['E:\\4k6k\datasets\\ns_images\\imagesets\\imageset_256_masked']
opt['save_folder'] = 'E:\\4k6k\\datasets\\images\\lsun\\lsun\\cats\\256_4_by_3' opt['save_folder'] = 'E:\\4k6k\datasets\\ns_images\\imagesets\\imageset_128_masked'
opt['imgsize'] = (256,192) opt['imgsize'] = (128,128)
opt['bottom_crop'] = 0 opt['bottom_crop'] = 0
opt['keep_folder'] = False opt['keep_folder'] = False
#opt['format'] = 'jpg'
#opt['cv2_write_options'] = [cv2.IMWRITE_JPEG_QUALITY, 95]
opt['format'] = 'png'
opt['cv2_write_options'] = [cv2.IMWRITE_PNG_COMPRESSION, 9]
save_folder = opt['save_folder'] save_folder = opt['save_folder']
if not osp.exists(save_folder): if not osp.exists(save_folder):
os.makedirs(save_folder) os.makedirs(save_folder)
@ -93,7 +95,7 @@ class TiledDataset(data.Dataset):
pts = os.path.split(pts[0]) pts = os.path.split(pts[0])
output_folder = osp.join(self.opt['save_folder'], pts[-1]) output_folder = osp.join(self.opt['save_folder'], pts[-1])
os.makedirs(output_folder, exist_ok=True) os.makedirs(output_folder, exist_ok=True)
cv2.imwrite(osp.join(output_folder, basename.replace('.webp', '.jpg')), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']]) cv2.imwrite(osp.join(output_folder, basename.replace('.webp', self.opt['format'])), img, self.opt['cv2_write_options'])
return None return None
def __len__(self): def __len__(self):

View File

@ -295,7 +295,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_imagenet_yt.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imagenet_resnet50_yt_pretrained.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()

View File

@ -371,7 +371,7 @@ class ExtensibleTrainer(BaseModel):
if load_path is not None: if load_path is not None:
if self.rank <= 0: if self.rank <= 0:
logger.info('Loading model for [%s]' % (load_path,)) logger.info('Loading model for [%s]' % (load_path,))
self.load_network(load_path, net, self.opt['path']['strict_load']) self.load_network(load_path, net, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
if hasattr(net.module, 'network_loaded'): if hasattr(net.module, 'network_loaded'):
net.module.network_loaded() net.module.network_loaded()

View File

@ -97,7 +97,7 @@ class BaseModel():
save_path, os.path.join(self.opt['remote_path'], 'models', save_filename)) save_path, os.path.join(self.opt['remote_path'], 'models', save_filename))
return save_path return save_path
def load_network(self, load_path, network, strict=True): def load_network(self, load_path, network, strict=True, pretrain_base_path=None):
#if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): #if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
network = network.module network = network.module
load_net = torch.load(load_path) load_net = torch.load(load_path)
@ -105,9 +105,15 @@ class BaseModel():
# Support loading torch.save()s for whole models as well as just state_dicts. # Support loading torch.save()s for whole models as well as just state_dicts.
if 'state_dict' in load_net: if 'state_dict' in load_net:
load_net = load_net['state_dict'] load_net = load_net['state_dict']
is_srflow = False
load_net_clean = OrderedDict() # remove unnecessary 'module.' load_net_clean = OrderedDict() # remove unnecessary 'module.'
if pretrain_base_path is not None:
t = load_net
load_net = {}
for k, v in t.items():
if k.startswith(pretrain_base_path):
load_net[k[len(pretrain_base_path):]] = v
for k, v in load_net.items(): for k, v in load_net.items():
if k.startswith('module.'): if k.startswith('module.'):
load_net_clean[k.replace('module.', '')] = v load_net_clean[k.replace('module.', '')] = v

View File

@ -18,7 +18,7 @@ class CategorizationLossEvaluator(evaluator.Evaluator):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]) std=[0.229, 0.224, 0.225])
self.dataset = torchvision.datasets.ImageFolder( self.dataset = torchvision.datasets.ImageFolder(
'F:\\4k6k\\datasets\\images\\imagenet_2017\\val', 'E:\\4k6k\\datasets\\images\\imagenet_2017\\val',
transforms.Compose([ transforms.Compose([
transforms.Resize(256), transforms.Resize(256),
transforms.CenterCrop(224), transforms.CenterCrop(224),
@ -27,7 +27,7 @@ class CategorizationLossEvaluator(evaluator.Evaluator):
])) ]))
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4) self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4)
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0 self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
self.masking = opt_get(opt_eval, ['masking'], True) self.masking = opt_get(opt_eval, ['masking'], False)
if self.masking: if self.masking:
self.mask_producer = UResnetMaskProducer(pretrained_uresnet_path= '../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth', self.mask_producer = UResnetMaskProducer(pretrained_uresnet_path= '../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth',
kmeans_centroid_path='../experiments/k_means_uresnet_imagenet_256.pth', kmeans_centroid_path='../experiments/k_means_uresnet_imagenet_256.pth',