Implement a few changes to support training BYOL networks

James Betker 2020-12-23 10:50:23 -07:00
parent 2437b33e74
commit 1bbcb96ee8
7 changed files with 95 additions and 25 deletions

View File

@@ -35,6 +35,9 @@ class ByolDatasetWrapper(Dataset):
         super().__init__()
         self.wrapped_dataset = create_dataset(opt['dataset'])
         self.cropped_img_size = opt['crop_size']
+        self.key1 = opt_get(opt, ['key1'], 'hq')
+        self.key2 = opt_get(opt, ['key2'], 'lq')
         augmentations = [ \
             RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
             augs.RandomGrayscale(p=0.2),
@@ -49,7 +52,7 @@ class ByolDatasetWrapper(Dataset):
     def __getitem__(self, item):
         item = self.wrapped_dataset[item]
-        item.update({'aug1': self.aug(item['hq']).squeeze(dim=0), 'aug2': self.aug(item['lq']).squeeze(dim=0)})
+        item.update({'aug1': self.aug(item[self.key1]).squeeze(dim=0), 'aug2': self.aug(item[self.key2]).squeeze(dim=0)})
         return item

     def __len__(self):
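With `key1` and `key2` configurable, the wrapper can build its two BYOL views from any pair of keys the wrapped dataset emits, not just the hard-coded 'hq'/'lq' pair. A minimal sketch of the intended usage under this commit's conventions (the exact `dataset` sub-options are assumptions; 'alt_hq' is the neighboring-frame key added to ImageFolderDataset below):

```python
# Minimal sketch, assuming the opt conventions used elsewhere in this commit.
opt = {
    'dataset': {                  # hypothetical wrapped-dataset config
        'mode': 'imagefolder',
        'paths': ['/data/video_frames'],
        'fetch_alt_image': True,  # produces the 'alt_hq' key (see below)
        'skip_lq': True,
    },
    'crop_size': 256,
    'key1': 'hq',        # first BYOL view: the frame itself
    'key2': 'alt_hq',    # second BYOL view: the adjacent video frame
}
wrapper = ByolDatasetWrapper(opt)
sample = wrapper[0]      # sample['aug1'] / sample['aug2'] are the augmented views
```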

View File

@@ -22,6 +22,9 @@ class ImageFolderDataset:
         self.scale = opt['scale']
         self.paths = opt['paths']
         self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys() else False
+        self.fetch_alt_image = opt['fetch_alt_image']  # If specified, this dataset will attempt to find a second image
+                                                       # from the same video source. Search for 'fetch_alt_image' for more info.
+        self.skip_lq = opt['skip_lq']
         assert (self.target_hq_size // self.scale) % self.multiple == 0  # If we don't throw here, we get some really obscure errors.
         if not isinstance(self.paths, list):
             self.paths = [self.paths]
@@ -110,13 +113,56 @@ class ImageFolderDataset:
         dim = hq.shape[0]
         hs = self.resize_hq([hq])
-        ls = self.synthesize_lq(hs)
+        if not self.skip_lq:
+            ls = self.synthesize_lq(hs)
         # Convert to torch tensor
         hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
-        lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
+        if not self.skip_lq:
+            lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
+
+        out_dict = {'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
+        if not self.skip_lq:
+            out_dict['lq'] = lq
+
+        if self.fetch_alt_image:
+            # This works by assuming a specific filename structure, as would be produced by ffmpeg, e.g.:
+            # 'Candied Walnutsxjktqhr_SYc.webm_00000478.jpg'
+            # 'Candied Walnutsxjktqhr_SYc.webm_00000479.jpg'
+            # 'Candied Walnutsxjktqhr_SYc.webm_00000480.jpg'
+            # The basic format is `<anything>%08d.<extension>`. This logic parses off that 8-digit number. If it is
+            # not found, the 'alt_image' returned is just the current image. If it is found, the algorithm searches
+            # for an image one number higher. If that image is found, it is returned in the 'alt_hq' and 'alt_lq'
+            # keys; otherwise the current image is put in those keys.
+            imname_parts = self.image_paths[item]
+            while '.jpg.jpg' in imname_parts:
+                imname_parts = imname_parts.replace(".jpg.jpg", ".jpg")  # Hack workaround to my own bug.
+            imname_parts = imname_parts.split('.')
+            if len(imname_parts) >= 2 and len(imname_parts[-2]) > 8:
+                try:
+                    imnumber = int(imname_parts[-2][-8:])
+                    # When we're dealing with images in the 1M range, it's straight-up faster to attempt to just
+                    # open the file rather than searching the path list. Let the exception handler below do its work.
+                    next_img = self.image_paths[item].replace(str(imnumber), str(imnumber+1))
+                    alt_hq = util.read_img(None, next_img, rgb=True)
+                    alt_hq = self.resize_hq([alt_hq])
+                    alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hq[0], (2, 0, 1)))).float()
+                    if not self.skip_lq:
+                        alt_lq = self.synthesize_lq(alt_hq)
+                        alt_lq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_lq[0], (2, 0, 1)))).float()
+                except:
+                    alt_hq = hq
+                    if not self.skip_lq:
+                        alt_lq = lq
+            else:
+                alt_hq = hq
+                if not self.skip_lq:
+                    alt_lq = lq
+            out_dict['alt_hq'] = alt_hq
+            if not self.skip_lq:
+                out_dict['alt_lq'] = alt_lq
-        out_dict = {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
         if self.labeler:
             base_file = self.image_paths[item].replace(self.paths[0], "")
             while base_file.startswith("\\"):
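For reference, the frame-lookup convention above can be exercised in isolation. The following is a hypothetical helper (not part of the commit) that reproduces the `<anything>%08d.<extension>` logic; note it re-pads the incremented number, whereas the commit's code replaces the unpadded `str(imnumber)` substring and relies on it matching inside the zero-padded run:

```python
import os

def next_frame_path(path):
    """Given '<anything>_00000478.jpg', return the path for frame 00000479 if it exists."""
    stem, ext = os.path.splitext(path)
    digits = stem[-8:]
    if not digits.isdigit():
        return path                          # no %08d suffix; fall back to the same image
    candidate = stem[:-8] + '%08d' % (int(digits) + 1) + ext
    return candidate if os.path.exists(candidate) else path

# 'Candied Walnutsxjktqhr_SYc.webm_00000478.jpg' -> '..._00000479.jpg' when that frame exists
```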
@@ -131,19 +177,20 @@ class ImageFolderDataset:
 if __name__ == '__main__':
     opt = {
         'name': 'amalgam',
-        'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\'],
+        'paths': ['F:\\4k6k\\datasets\\images\\youtube\\4k_quote_unquote\\images'],
         'weights': [1],
-        'target_size': 512,
+        'target_size': 256,
         'force_multiple': 32,
         'scale': 2,
         'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
         'random_corruptions': ['noise-5', 'none'],
         'num_corrupts_per_image': 1,
         'corrupt_before_downsize': True,
-        'labeler': {
-            'type': 'patch_labels',
-            'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json'
-        }
+        'fetch_alt_image': True,
+        #'labeler': {
+        #    'type': 'patch_labels',
+        #    'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json'
+        #}
     }
     ds = ImageFolderDataset(opt)
@@ -152,11 +199,11 @@ if __name__ == '__main__':
     for i in range(0, len(ds)):
         o = ds[random.randint(0, len(ds)-1)]
         hq = o['hq']
-        masked = (o['labels_mask'] * .5 + .5) * hq
+        #masked = (o['labels_mask'] * .5 + .5) * hq
         import torchvision
         torchvision.utils.save_image(hq.unsqueeze(0), "debug/%i_hq.png" % (i,))
-        #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_masked.png" % (i,))
-        if len(o['labels'].unique()) > 1:
-            randlbl = np.random.choice(o['labels'].unique()[1:])
-            moremask = hq * ((1*(o['labels'] == randlbl))*.5+.5)
-            torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s.png" % (i, o['label_strings'][randlbl]))
+        torchvision.utils.save_image(o['alt_hq'].unsqueeze(0), "debug/%i_hq_alt.png" % (i,))
+        #if len(o['labels'].unique()) > 1:
+        #    randlbl = np.random.choice(o['labels'].unique()[1:])
+        #    moremask = hq * ((1*(o['labels'] == randlbl))*.5+.5)
+        #    torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s.png" % (i, o['label_strings'][randlbl]))

View File

@@ -113,7 +113,10 @@ def read_img(env, path, size=None, rgb=False):
     """read image by cv2 or from lmdb or from a buffer (in which case path=buffer)
     return: Numpy float32, HWC, BGR, [0,1]"""
     if env is None:  # img
-        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+        # Indirect open then decode to support unicode filenames.
+        stream = open(path, "rb")
+        bytes = bytearray(stream.read())
+        img = cv2.imdecode(np.asarray(bytes, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
     elif env is 'lmdb':
         img = _read_img_lmdb(env, path, size)
     elif env is 'buffer':
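The old `cv2.imread` call fails on paths cv2 cannot encode (notably non-ASCII filenames on Windows); reading the bytes with Python IO and handing them to `cv2.imdecode` sidesteps that. An equivalent sketch using `np.fromfile`, assumed to behave the same as the `open()`/`bytearray` route above:

```python
import cv2
import numpy as np

def imread_unicode(path):
    # np.fromfile opens the path with Python's own IO, so unicode names are fine;
    # cv2.imdecode then decodes the raw bytes exactly like cv2.imread would.
    data = np.fromfile(path, dtype=np.uint8)
    return cv2.imdecode(data, cv2.IMREAD_UNCHANGED)
```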

View File

@@ -148,7 +148,7 @@ class NetWrapper(nn.Module):
         if self.structural_mlp:
             projector = StructuralMLP(hidden.shape, self.projection_size, self.projection_hidden_size)
         else:
-            _, dim = hidden.shape
+            _, dim = hidden.flatten(1,-1).shape
             projector = MLP(dim, self.projection_size, self.projection_hidden_size)
         return projector.to(hidden)
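The flatten matters because convolutional backbones hand NetWrapper a (B, C, H, W) hidden tensor, which cannot be unpacked as `_, dim = hidden.shape`. Flattening everything after the batch dimension yields the true MLP input width, and is a no-op for an already-flat (B, C) hidden:

```python
import torch

hidden = torch.randn(4, 2048, 7, 7)    # e.g. resnet features before pooling
_, dim = hidden.flatten(1, -1).shape   # dim == 2048 * 7 * 7 == 100352

flat = torch.randn(4, 2048)            # fully-connected backbone output
_, dim = flat.flatten(1, -1).shape     # unchanged: dim == 2048
```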

View File

@@ -13,15 +13,16 @@ import torch

 def main():
     split_img = False
     opt = {}
-    opt['n_thread'] = 10
+    opt['n_thread'] = 8
     opt['compression_level'] = 90  # JPEG compression quality rating.
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If read raw images during training, use 0 for faster IO speed.
     opt['dest'] = 'file'
-    opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\other_ns']
-    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised'
-    opt['imgsize'] = 512
+    opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised']
+    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\256_unsupervised'
+    opt['imgsize'] = 256
+    #opt['bottom_crop'] = 120

     save_folder = opt['save_folder']
     if not osp.exists(save_folder):
@@ -36,6 +37,7 @@ class TiledDataset(data.Dataset):
         self.opt = opt
         input_folder = opt['input_folder']
         self.images = data_util.get_image_paths('img', input_folder)[0]
+        print("Found %i images" % (len(self.images),))

     def __getitem__(self, index):
         return self.get(index)
@@ -43,7 +45,7 @@ class TiledDataset(data.Dataset):
     def get(self, index):
         path = self.images[index]
         basename = osp.basename(path)
-        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+        img = data_util.read_img(None, path)

         # Greyscale not supported.
         if img is None:
@@ -51,6 +53,12 @@ class TiledDataset(data.Dataset):
             return None
         if len(img.shape) == 2:
             return None
+
+        # Perform explicit crops first. These are generally used to get rid of watermarks, so we don't even want
+        # to consider these areas of the image.
+        if 'bottom_crop' in self.opt.keys():
+            img = img[:-self.opt['bottom_crop'], :, :]
+
         h, w, c = img.shape
         # Uncomment to filter any image that doesn't meet a threshold size.
         if min(h,w) < 512:
@@ -61,7 +69,14 @@ class TiledDataset(data.Dataset):
         # Crop the image so that only the center is left, since this is often the most salient part of the image.
         img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
         img = cv2.resize(img, (self.opt['imgsize'], self.opt['imgsize']), interpolation=cv2.INTER_AREA)
-        cv2.imwrite(osp.join(self.opt['save_folder'], basename + ".jpg"), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
+        # I was having some issues with unicode filenames with cv2, hence using PIL.
+        # cv2.imwrite(osp.join(self.opt['save_folder'], basename + ".jpg"), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = (img * 255).astype(np.uint8)
+        img = Image.fromarray(img)
+        img.save(osp.join(self.opt['save_folder'], basename + ".jpg"), "JPEG", quality=self.opt['compression_level'], optimize=True)
         return None

     def __len__(self):
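The PIL save path assumes `read_img`'s output contract (float32, HWC, BGR, [0, 1]), hence the BGR-to-RGB swap and the scale to 8-bit before `Image.fromarray`. A standalone sketch of that conversion with a stand-in array:

```python
import cv2
import numpy as np
from PIL import Image

img = np.random.rand(256, 256, 3).astype(np.float32)  # stand-in for read_img output (BGR, [0, 1])
rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)             # PIL expects RGB channel order...
rgb8 = (rgb * 255).astype(np.uint8)                    # ...and 8-bit pixels
Image.fromarray(rgb8).save('tile.jpg', 'JPEG', quality=90, optimize=True)
```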

View File

@@ -293,7 +293,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_faces_glean.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_sameimage.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()

View File

@@ -125,6 +125,8 @@ def define_G(opt, opt_net, scale=None):
         from models.spinenet_arch import SpinenetWithLogits
         netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
                                   in_channels=3, use_input_norm=opt_net['use_input_norm'])
+    elif which_model == 'resnet52':
+        netG = torchvision.models.resnet50(pretrained=opt_net['pretrained'])
     elif which_model == 'glean':
         from models.glean.glean import GleanGenerator
         netG = GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan'])
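Note the naming quirk: the 'resnet52' key instantiates torchvision's resnet50. A sketch of the minimal network options this branch consumes (key names assumed from define_G's conventions elsewhere, and the dict would normally come from the YAML config):

```python
import torchvision

opt_net = {'which_model_G': 'resnet52', 'pretrained': True}  # hypothetical config fragment
netG = torchvision.models.resnet50(pretrained=opt_net['pretrained'])
```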