Implement a few changes to support training BYOL networks

This commit is contained in:
James Betker 2020-12-23 10:50:23 -07:00
parent 2437b33e74
commit 1bbcb96ee8
7 changed files with 95 additions and 25 deletions

View File

@ -35,6 +35,9 @@ class ByolDatasetWrapper(Dataset):
super().__init__()
self.wrapped_dataset = create_dataset(opt['dataset'])
self.cropped_img_size = opt['crop_size']
self.key1 = opt_get(opt, ['key1'], 'hq')
self.key2 = opt_get(opt, ['key2'], 'lq')
augmentations = [ \
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2),
@ -49,7 +52,7 @@ class ByolDatasetWrapper(Dataset):
# Fetch one sample from the wrapped dataset and attach two augmented views of it under 'aug1' and 'aug2'.
def __getitem__(self, item):
item = self.wrapped_dataset[item]
# NOTE(review): this listing looks like a rendered diff with its +/- markers stripped; the update below, with
# hard-coded 'hq'/'lq' keys, appears to be the pre-change line that the following update replaces - confirm
# before treating both statements as live code.
item.update({'aug1': self.aug(item['hq']).squeeze(dim=0), 'aug2': self.aug(item['lq']).squeeze(dim=0)})
# The source entries for the two views are selected by self.key1 / self.key2. Each one is passed through
# self.aug and then has dimension 0 squeezed out of the result.
item.update({'aug1': self.aug(item[self.key1]).squeeze(dim=0), 'aug2': self.aug(item[self.key2]).squeeze(dim=0)})
return item
def __len__(self):

View File

@ -22,6 +22,9 @@ class ImageFolderDataset:
self.scale = opt['scale']
self.paths = opt['paths']
self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys() else False
self.fetch_alt_image = opt['fetch_alt_image'] # If specified, this dataset will attempt to find a second image
# from the same video source. Search for 'fetch_alt_image' for more info.
self.skip_lq = opt['skip_lq']
assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors.
if not isinstance(self.paths, list):
self.paths = [self.paths]
@ -110,13 +113,56 @@ class ImageFolderDataset:
dim = hq.shape[0]
hs = self.resize_hq([hq])
ls = self.synthesize_lq(hs)
if not self.skip_lq:
ls = self.synthesize_lq(hs)
# Convert to torch tensor
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
if not self.skip_lq:
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
out_dict = {'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
if not self.skip_lq:
out_dict['lq'] = lq
if self.fetch_alt_image:
# This works by assuming a specific filename structure as would be produced by ffmpeg. ex:
# 'Candied Walnutsxjktqhr_SYc.webm_00000478.jpg` and
# 'Candied Walnutsxjktqhr_SYc.webm_00000479.jpg` and
# 'Candied Walnutsxjktqhr_SYc.webm_00000480.jpg`
# The basic format is `<anything>%08d.<extension>`. This logic parses off that 8 digit number. If the number is
# not found, the 'alt_image' returned is just the current image. If the number is found, the algorithm searches
# for an image one number higher. If that image is found, it is returned in the 'alt_hq' and 'alt_lq' keys;
# otherwise the current image is put in those keys.
imname_parts = self.image_paths[item]
while '.jpg.jpg' in imname_parts:
imname_parts = imname_parts.replace(".jpg.jpg", ".jpg") # Hack workaround to my own bug.
imname_parts = imname_parts.split('.')
if len(imname_parts) >= 2 and len(imname_parts[-2]) > 8:
try:
imnumber = int(imname_parts[-2][-8:])
# When we're dealing with images in the 1M range, it's straight up faster to attempt to just open
# the file rather than searching the path list. Let the exception handler below do its work.
next_img = self.image_paths[item].replace(str(imnumber), str(imnumber+1))
alt_hq = util.read_img(None, next_img, rgb=True)
alt_hq = self.resize_hq([alt_hq])
alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hq[0], (2, 0, 1)))).float()
if not self.skip_lq:
alt_lq = self.synthesize_lq(alt_hq)
alt_lq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_lq[0], (2, 0, 1)))).float()
except:
alt_hq = hq
if not self.skip_lq:
alt_lq = lq
else:
alt_hq = hq
if not self.skip_lq:
alt_lq = lq
out_dict['alt_hq'] = alt_hq
if not self.skip_lq:
out_dict['alt_lq'] = alt_lq
out_dict = {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
if self.labeler:
base_file = self.image_paths[item].replace(self.paths[0], "")
while base_file.startswith("\\"):
@ -131,19 +177,20 @@ class ImageFolderDataset:
if __name__ == '__main__':
opt = {
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\'],
'paths': ['F:\\4k6k\\datasets\\images\\youtube\\4k_quote_unquote\\images'],
'weights': [1],
'target_size': 512,
'target_size': 256,
'force_multiple': 32,
'scale': 2,
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
'random_corruptions': ['noise-5', 'none'],
'num_corrupts_per_image': 1,
'corrupt_before_downsize': True,
'labeler': {
'type': 'patch_labels',
'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json'
}
'fetch_alt_image': True,
#'labeler': {
# 'type': 'patch_labels',
# 'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json'
#}
}
ds = ImageFolderDataset(opt)
@ -152,11 +199,11 @@ if __name__ == '__main__':
for i in range(0, len(ds)):
o = ds[random.randint(0, len(ds)-1)]
hq = o['hq']
masked = (o['labels_mask'] * .5 + .5) * hq
#masked = (o['labels_mask'] * .5 + .5) * hq
import torchvision
torchvision.utils.save_image(hq.unsqueeze(0), "debug/%i_hq.png" % (i,))
#torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_masked.png" % (i,))
if len(o['labels'].unique()) > 1:
randlbl = np.random.choice(o['labels'].unique()[1:])
moremask = hq * ((1*(o['labels'] == randlbl))*.5+.5)
torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s.png" % (i, o['label_strings'][randlbl]))
torchvision.utils.save_image(o['alt_hq'].unsqueeze(0), "debug/%i_hq_alt.png" % (i,))
#if len(o['labels'].unique()) > 1:
# randlbl = np.random.choice(o['labels'].unique()[1:])
# moremask = hq * ((1*(o['labels'] == randlbl))*.5+.5)
# torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s.png" % (i, o['label_strings'][randlbl]))

View File

@ -113,7 +113,10 @@ def read_img(env, path, size=None, rgb=False):
"""read image by cv2 or from lmdb or from a buffer (in which case path=buffer)
return: Numpy float32, HWC, BGR, [0,1]"""
if env is None: # img
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
# Open the file ourselves and decode it from memory (rather than handing the path to cv2) to support unicode filenames.
stream = open(path, "rb")
bytes = bytearray(stream.read())
img = cv2.imdecode(np.asarray(bytes, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
elif env is 'lmdb':
img = _read_img_lmdb(env, path, size)
elif env is 'buffer':

View File

@ -148,7 +148,7 @@ class NetWrapper(nn.Module):
if self.structural_mlp:
projector = StructuralMLP(hidden.shape, self.projection_size, self.projection_hidden_size)
else:
_, dim = hidden.shape
_, dim = hidden.flatten(1,-1).shape
projector = MLP(dim, self.projection_size, self.projection_hidden_size)
return projector.to(hidden)

View File

@ -13,15 +13,16 @@ import torch
def main():
split_img = False
opt = {}
opt['n_thread'] = 10
opt['n_thread'] = 8
opt['compression_level'] = 90 # JPEG compression quality rating.
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If read raw images during training, use 0 for faster IO speed.
opt['dest'] = 'file'
opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\other_ns']
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised'
opt['imgsize'] = 512
opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised']
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\256_unsupervised'
opt['imgsize'] = 256
#opt['bottom_crop'] = 120
save_folder = opt['save_folder']
if not osp.exists(save_folder):
@ -36,6 +37,7 @@ class TiledDataset(data.Dataset):
self.opt = opt
input_folder = opt['input_folder']
self.images = data_util.get_image_paths('img', input_folder)[0]
print("Found %i images" % (len(self.images),))
# Indexing entry point: forwards the index unchanged to self.get() and returns whatever it returns.
def __getitem__(self, index):
return self.get(index)
@ -43,7 +45,7 @@ class TiledDataset(data.Dataset):
def get(self, index):
path = self.images[index]
basename = osp.basename(path)
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
img = data_util.read_img(None, path)
# Greyscale not supported.
if img is None:
@ -51,6 +53,12 @@ class TiledDataset(data.Dataset):
return None
if len(img.shape) == 2:
return None
# Perform explicit crops first. These are generally used to get rid of watermarks so we don't even want to
# consider these areas of the image.
if 'bottom_crop' in self.opt.keys():
img = img[:-self.opt['bottom_crop'], :, :]
h, w, c = img.shape
# Uncomment to filter any image that doesn't meet a threshold size.
if min(h,w) < 512:
@ -61,7 +69,14 @@ class TiledDataset(data.Dataset):
# Crop the image so that only the center is left, since this is often the most salient part of the image.
img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
img = cv2.resize(img, (self.opt['imgsize'], self.opt['imgsize']), interpolation=cv2.INTER_AREA)
cv2.imwrite(osp.join(self.opt['save_folder'], basename + ".jpg"), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
# I was having some issues with unicode filenames with cv2. Hence using PIL.
# cv2.imwrite(osp.join(self.opt['save_folder'], basename + ".jpg"), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = (img * 255).astype(np.uint8)
img = Image.fromarray(img)
img.save(osp.join(self.opt['save_folder'], basename + ".jpg"), "JPEG", quality=self.opt['compression_level'], optimize=True)
return None
def __len__(self):

View File

@ -293,7 +293,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_faces_glean.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_sameimage.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -125,6 +125,8 @@ def define_G(opt, opt_net, scale=None):
from models.spinenet_arch import SpinenetWithLogits
netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
in_channels=3, use_input_norm=opt_net['use_input_norm'])
elif which_model == 'resnet52':
netG = torchvision.models.resnet50(pretrained=opt_net['pretrained'])
elif which_model == 'glean':
from models.glean.glean import GleanGenerator
netG = GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan'])