forked from mrq/DL-Art-School
Fixes and mods to support training classifiers on imagenet
This commit is contained in:
parent
f129eaa39e
commit
45bc76ba92
|
@ -4,6 +4,9 @@ import torchvision.transforms as T
|
||||||
from torchvision import datasets
|
from torchvision import datasets
|
||||||
|
|
||||||
# Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer.
|
# Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer.
|
||||||
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
class TorchDataset(Dataset):
|
class TorchDataset(Dataset):
|
||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
DATASET_MAP = {
|
DATASET_MAP = {
|
||||||
|
@ -14,7 +17,7 @@ class TorchDataset(Dataset):
|
||||||
"imagefolder": datasets.ImageFolder
|
"imagefolder": datasets.ImageFolder
|
||||||
}
|
}
|
||||||
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
if opt['train']:
|
if opt_get(opt, ['random_crop'], False):
|
||||||
transforms = [
|
transforms = [
|
||||||
T.RandomResizedCrop(opt['image_size']),
|
T.RandomResizedCrop(opt['image_size']),
|
||||||
T.RandomHorizontalFlip(),
|
T.RandomHorizontalFlip(),
|
||||||
|
@ -23,8 +26,9 @@ class TorchDataset(Dataset):
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
transforms = [
|
transforms = [
|
||||||
T.Resize(opt['val_resize']),
|
T.Resize(opt['image_size']),
|
||||||
T.CenterCrop(opt['image_size']),
|
T.CenterCrop(opt['image_size']),
|
||||||
|
T.RandomHorizontalFlip(),
|
||||||
T.ToTensor(),
|
T.ToTensor(),
|
||||||
normalize,
|
normalize,
|
||||||
]
|
]
|
||||||
|
|
|
@ -35,18 +35,22 @@ class ZipFileDataset(torch.utils.data.Dataset):
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
fname = self.all_files[i]
|
try:
|
||||||
out = {
|
fname = self.all_files[i]
|
||||||
'hq': self.load_image(fname),
|
out = {
|
||||||
'HQ_path': fname,
|
'hq': self.load_image(fname),
|
||||||
'has_alt': self.paired_mode
|
'HQ_path': fname,
|
||||||
}
|
'has_alt': self.paired_mode
|
||||||
if self.paired_mode:
|
}
|
||||||
if fname.endswith('0.jpg'):
|
if self.paired_mode:
|
||||||
aname = fname.replace('0.jpg', '1.jpg')
|
if fname.endswith('0.jpg'):
|
||||||
else:
|
aname = fname.replace('0.jpg', '1.jpg')
|
||||||
aname = fname.replace('1.jpg', '0.jpg')
|
else:
|
||||||
out['alt_hq'] = self.load_image(aname)
|
aname = fname.replace('1.jpg', '0.jpg')
|
||||||
|
out['alt_hq'] = self.load_image(aname)
|
||||||
|
except:
|
||||||
|
print(f"Error loading {fname} from zipfile. Attempting to recover by loading next element.")
|
||||||
|
return self[i+1]
|
||||||
return out
|
return out
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -14,17 +14,19 @@ def main():
|
||||||
split_img = False
|
split_img = False
|
||||||
opt = {}
|
opt = {}
|
||||||
opt['n_thread'] = 5
|
opt['n_thread'] = 5
|
||||||
opt['compression_level'] = 95 # JPEG compression quality rating.
|
|
||||||
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
|
|
||||||
# compression time. If read raw images during training, use 0 for faster IO speed.
|
|
||||||
|
|
||||||
opt['dest'] = 'file'
|
opt['dest'] = 'file'
|
||||||
opt['input_folder'] = ['E:\\4k6k\\datasets\\images\\lsun\\lsun\\cats']
|
opt['input_folder'] = ['E:\\4k6k\datasets\\ns_images\\imagesets\\imageset_256_masked']
|
||||||
opt['save_folder'] = 'E:\\4k6k\\datasets\\images\\lsun\\lsun\\cats\\256_4_by_3'
|
opt['save_folder'] = 'E:\\4k6k\datasets\\ns_images\\imagesets\\imageset_128_masked'
|
||||||
opt['imgsize'] = (256,192)
|
opt['imgsize'] = (128,128)
|
||||||
opt['bottom_crop'] = 0
|
opt['bottom_crop'] = 0
|
||||||
opt['keep_folder'] = False
|
opt['keep_folder'] = False
|
||||||
|
|
||||||
|
#opt['format'] = 'jpg'
|
||||||
|
#opt['cv2_write_options'] = [cv2.IMWRITE_JPEG_QUALITY, 95]
|
||||||
|
opt['format'] = 'png'
|
||||||
|
opt['cv2_write_options'] = [cv2.IMWRITE_PNG_COMPRESSION, 9]
|
||||||
|
|
||||||
save_folder = opt['save_folder']
|
save_folder = opt['save_folder']
|
||||||
if not osp.exists(save_folder):
|
if not osp.exists(save_folder):
|
||||||
os.makedirs(save_folder)
|
os.makedirs(save_folder)
|
||||||
|
@ -93,7 +95,7 @@ class TiledDataset(data.Dataset):
|
||||||
pts = os.path.split(pts[0])
|
pts = os.path.split(pts[0])
|
||||||
output_folder = osp.join(self.opt['save_folder'], pts[-1])
|
output_folder = osp.join(self.opt['save_folder'], pts[-1])
|
||||||
os.makedirs(output_folder, exist_ok=True)
|
os.makedirs(output_folder, exist_ok=True)
|
||||||
cv2.imwrite(osp.join(output_folder, basename.replace('.webp', '.jpg')), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
|
cv2.imwrite(osp.join(output_folder, basename.replace('.webp', self.opt['format'])), img, self.opt['cv2_write_options'])
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
|
|
@ -295,7 +295,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_imagenet_yt.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imagenet_resnet50_yt_pretrained.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -371,7 +371,7 @@ class ExtensibleTrainer(BaseModel):
|
||||||
if load_path is not None:
|
if load_path is not None:
|
||||||
if self.rank <= 0:
|
if self.rank <= 0:
|
||||||
logger.info('Loading model for [%s]' % (load_path,))
|
logger.info('Loading model for [%s]' % (load_path,))
|
||||||
self.load_network(load_path, net, self.opt['path']['strict_load'])
|
self.load_network(load_path, net, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
|
||||||
if hasattr(net.module, 'network_loaded'):
|
if hasattr(net.module, 'network_loaded'):
|
||||||
net.module.network_loaded()
|
net.module.network_loaded()
|
||||||
|
|
||||||
|
|
|
@ -97,7 +97,7 @@ class BaseModel():
|
||||||
save_path, os.path.join(self.opt['remote_path'], 'models', save_filename))
|
save_path, os.path.join(self.opt['remote_path'], 'models', save_filename))
|
||||||
return save_path
|
return save_path
|
||||||
|
|
||||||
def load_network(self, load_path, network, strict=True):
|
def load_network(self, load_path, network, strict=True, pretrain_base_path=None):
|
||||||
#if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
#if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
||||||
network = network.module
|
network = network.module
|
||||||
load_net = torch.load(load_path)
|
load_net = torch.load(load_path)
|
||||||
|
@ -105,9 +105,15 @@ class BaseModel():
|
||||||
# Support loading torch.save()s for whole models as well as just state_dicts.
|
# Support loading torch.save()s for whole models as well as just state_dicts.
|
||||||
if 'state_dict' in load_net:
|
if 'state_dict' in load_net:
|
||||||
load_net = load_net['state_dict']
|
load_net = load_net['state_dict']
|
||||||
|
|
||||||
is_srflow = False
|
|
||||||
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
||||||
|
|
||||||
|
if pretrain_base_path is not None:
|
||||||
|
t = load_net
|
||||||
|
load_net = {}
|
||||||
|
for k, v in t.items():
|
||||||
|
if k.startswith(pretrain_base_path):
|
||||||
|
load_net[k[len(pretrain_base_path):]] = v
|
||||||
|
|
||||||
for k, v in load_net.items():
|
for k, v in load_net.items():
|
||||||
if k.startswith('module.'):
|
if k.startswith('module.'):
|
||||||
load_net_clean[k.replace('module.', '')] = v
|
load_net_clean[k.replace('module.', '')] = v
|
||||||
|
|
|
@ -18,7 +18,7 @@ class CategorizationLossEvaluator(evaluator.Evaluator):
|
||||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||||
std=[0.229, 0.224, 0.225])
|
std=[0.229, 0.224, 0.225])
|
||||||
self.dataset = torchvision.datasets.ImageFolder(
|
self.dataset = torchvision.datasets.ImageFolder(
|
||||||
'F:\\4k6k\\datasets\\images\\imagenet_2017\\val',
|
'E:\\4k6k\\datasets\\images\\imagenet_2017\\val',
|
||||||
transforms.Compose([
|
transforms.Compose([
|
||||||
transforms.Resize(256),
|
transforms.Resize(256),
|
||||||
transforms.CenterCrop(224),
|
transforms.CenterCrop(224),
|
||||||
|
@ -27,7 +27,7 @@ class CategorizationLossEvaluator(evaluator.Evaluator):
|
||||||
]))
|
]))
|
||||||
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4)
|
self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4)
|
||||||
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
||||||
self.masking = opt_get(opt_eval, ['masking'], True)
|
self.masking = opt_get(opt_eval, ['masking'], False)
|
||||||
if self.masking:
|
if self.masking:
|
||||||
self.mask_producer = UResnetMaskProducer(pretrained_uresnet_path= '../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth',
|
self.mask_producer = UResnetMaskProducer(pretrained_uresnet_path= '../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth',
|
||||||
kmeans_centroid_path='../experiments/k_means_uresnet_imagenet_256.pth',
|
kmeans_centroid_path='../experiments/k_means_uresnet_imagenet_256.pth',
|
||||||
|
|
Loading…
Reference in New Issue
Block a user