diff --git a/codes/data/byol_attachment.py b/codes/data/byol_attachment.py
index f5dda78b..5574f327 100644
--- a/codes/data/byol_attachment.py
+++ b/codes/data/byol_attachment.py
@@ -35,6 +35,9 @@ class ByolDatasetWrapper(Dataset):
         super().__init__()
         self.wrapped_dataset = create_dataset(opt['dataset'])
         self.cropped_img_size = opt['crop_size']
+        self.key1 = opt_get(opt, ['key1'], 'hq')
+        self.key2 = opt_get(opt, ['key2'], 'lq')
+
         augmentations = [ \
             RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
             augs.RandomGrayscale(p=0.2),
@@ -49,7 +52,7 @@ class ByolDatasetWrapper(Dataset):

     def __getitem__(self, item):
         item = self.wrapped_dataset[item]
-        item.update({'aug1': self.aug(item['hq']).squeeze(dim=0), 'aug2': self.aug(item['lq']).squeeze(dim=0)})
+        item.update({'aug1': self.aug(item[self.key1]).squeeze(dim=0), 'aug2': self.aug(item[self.key2]).squeeze(dim=0)})
         return item

     def __len__(self):
diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py
index 00c61832..2bc1a984 100644
--- a/codes/data/image_folder_dataset.py
+++ b/codes/data/image_folder_dataset.py
@@ -22,6 +22,9 @@ class ImageFolderDataset:
         self.scale = opt['scale']
         self.paths = opt['paths']
         self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys() else False
+        # If specified, fetch a second image from the same video source; see 'fetch_alt_image' in __getitem__.
+        self.fetch_alt_image = opt['fetch_alt_image'] if 'fetch_alt_image' in opt.keys() else False
+        self.skip_lq = opt['skip_lq'] if 'skip_lq' in opt.keys() else False
         assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors.
         if not isinstance(self.paths, list):
             self.paths = [self.paths]
@@ -110,13 +113,56 @@ class ImageFolderDataset:
             dim = hq.shape[0]

         hs = self.resize_hq([hq])
-        ls = self.synthesize_lq(hs)
+        if not self.skip_lq:
+            ls = self.synthesize_lq(hs)

         # Convert to torch tensor
         hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
-        lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
+        if not self.skip_lq:
+            lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
+
+        out_dict = {'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
+        if not self.skip_lq:
+            out_dict['lq'] = lq
+
+        if self.fetch_alt_image:
+            # This works by assuming a specific filename structure, as would be produced by ffmpeg, e.g.:
+            # 'Candied Walnutsxjktqhr_SYc.webm_00000478.jpg'
+            # 'Candied Walnutsxjktqhr_SYc.webm_00000479.jpg'
+            # 'Candied Walnutsxjktqhr_SYc.webm_00000480.jpg'
+            # The basic format is `<name>%08d.<ext>`. This logic parses off that 8-digit number. If it is
+            # not found, the 'alt_image' returned is just the current image. If it is found, the algorithm
+            # searches for an image one number higher; if that exists, it is returned in the 'alt_hq' and
+            # 'alt_lq' keys, else the current image is put in those keys.
+
+            imname_parts = self.image_paths[item]
+            while '.jpg.jpg' in imname_parts:
+                imname_parts = imname_parts.replace(".jpg.jpg", ".jpg") # Hack workaround to my own bug.
+            imname_parts = imname_parts.split('.')
+            if len(imname_parts) >= 2 and len(imname_parts[-2]) > 8:
+                try:
+                    imnumber = int(imname_parts[-2][-8:])
+                    # When we're dealing with images in the 1M range, it's straight up faster to attempt to just open
+                    # the file rather than searching the path list. Let the exception handler below do its work.
+                    next_img = self.image_paths[item].replace(str(imnumber), str(imnumber+1))
+                    alt_hq = util.read_img(None, next_img, rgb=True)
+                    alt_hq = self.resize_hq([alt_hq])
+                    if not self.skip_lq:
+                        alt_lq = self.synthesize_lq(alt_hq) # must run on the list form, before the tensor conversion below
+                        alt_lq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_lq[0], (2, 0, 1)))).float()
+                    alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hq[0], (2, 0, 1)))).float()
+                except:
+                    alt_hq = hq
+                    if not self.skip_lq:
+                        alt_lq = lq
+            else:
+                alt_hq = hq
+                if not self.skip_lq:
+                    alt_lq = lq
+            out_dict['alt_hq'] = alt_hq
+            if not self.skip_lq:
+                out_dict['alt_lq'] = alt_lq

-        out_dict = {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
         if self.labeler:
             base_file = self.image_paths[item].replace(self.paths[0], "")
             while base_file.startswith("\\"):
@@ -131,19 +177,20 @@
 if __name__ == '__main__':
     opt = {
         'name': 'amalgam',
-        'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\'],
+        'paths': ['F:\\4k6k\\datasets\\images\\youtube\\4k_quote_unquote\\images'],
         'weights': [1],
-        'target_size': 512,
+        'target_size': 256,
         'force_multiple': 32,
         'scale': 2,
         'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
         'random_corruptions': ['noise-5', 'none'],
         'num_corrupts_per_image': 1,
         'corrupt_before_downsize': True,
-        'labeler': {
-            'type': 'patch_labels',
-            'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json'
-        }
+        'fetch_alt_image': True,
+        #'labeler': {
+        #    'type': 'patch_labels',
+        #    'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json'
+        #}
     }

     ds = ImageFolderDataset(opt)
@@ -152,11 +199,11 @@ if __name__ == '__main__':
     for i in range(0, len(ds)):
         o = ds[random.randint(0, len(ds)-1)]
         hq = o['hq']
-        masked = (o['labels_mask'] * .5 + .5) * hq
+        #masked = (o['labels_mask'] * .5 + .5) * hq
         import torchvision
         torchvision.utils.save_image(hq.unsqueeze(0), "debug/%i_hq.png" % (i,))
-        #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_masked.png" % (i,))
-        if len(o['labels'].unique()) > 1:
-            randlbl = np.random.choice(o['labels'].unique()[1:])
-            moremask = hq * ((1*(o['labels'] == randlbl))*.5+.5)
-            torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s.png" % (i, o['label_strings'][randlbl]))
\ No newline at end of file
+        torchvision.utils.save_image(o['alt_hq'].unsqueeze(0), "debug/%i_hq_alt.png" % (i,))
+        #if len(o['labels'].unique()) > 1:
+        #    randlbl = np.random.choice(o['labels'].unique()[1:])
+        #    moremask = hq * ((1*(o['labels'] == randlbl))*.5+.5)
+        #    torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s.png" % (i, o['label_strings'][randlbl]))
\ No newline at end of file
diff --git a/codes/data/util.py b/codes/data/util.py
index 6e487c79..521881b8 100644
--- a/codes/data/util.py
+++ b/codes/data/util.py
@@ -113,7 +113,10 @@ def read_img(env, path, size=None, rgb=False):
     """read image by cv2 or from lmdb or from a buffer (in which case path=buffer)
     return: Numpy float32, HWC, BGR, [0,1]"""
     if env is None: # img
-        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+        # Open the file ourselves and decode from a buffer so unicode paths work.
+        with open(path, "rb") as f:
+            buf = bytearray(f.read())
+        img = cv2.imdecode(np.asarray(buf, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
     elif env is 'lmdb':
         img = _read_img_lmdb(env, path, size)
     elif env is 'buffer':
diff --git a/codes/models/byol/byol_model_wrapper.py b/codes/models/byol/byol_model_wrapper.py
index 536faa85..3727e296 100644
--- a/codes/models/byol/byol_model_wrapper.py
+++ b/codes/models/byol/byol_model_wrapper.py
@@ -148,7 +148,7 @@ class NetWrapper(nn.Module):
         if self.structural_mlp:
             projector = StructuralMLP(hidden.shape, self.projection_size, self.projection_hidden_size)
         else:
-            _, dim = hidden.shape
+            _, dim = hidden.flatten(1,-1).shape
             projector = MLP(dim, self.projection_size, self.projection_hidden_size)

         return projector.to(hidden)
diff --git a/codes/scripts/extract_square_images.py b/codes/scripts/extract_square_images.py
index 4f60cab8..57d61668 100644
--- a/codes/scripts/extract_square_images.py
+++ b/codes/scripts/extract_square_images.py
@@ -13,15 +13,16 @@ import torch
 def main():
     split_img = False
     opt = {}
-    opt['n_thread'] = 10
+    opt['n_thread'] = 8
     opt['compression_level'] = 90 # JPEG compression quality rating.
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If read raw images during training, use 0 for faster IO speed.

     opt['dest'] = 'file'
-    opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\other_ns']
-    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised'
-    opt['imgsize'] = 512
+    opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised']
+    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\256_unsupervised'
+    opt['imgsize'] = 256
+    #opt['bottom_crop'] = 120

     save_folder = opt['save_folder']
     if not osp.exists(save_folder):
@@ -36,6 +37,7 @@ class TiledDataset(data.Dataset):
         self.opt = opt
         input_folder = opt['input_folder']
         self.images = data_util.get_image_paths('img', input_folder)[0]
+        print("Found %i images" % (len(self.images),))

     def __getitem__(self, index):
         return self.get(index)
@@ -43,7 +45,7 @@ class TiledDataset(data.Dataset):
     def get(self, index):
         path = self.images[index]
         basename = osp.basename(path)
-        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+        img = data_util.read_img(None, path)

         # Greyscale not supported.
         if img is None:
@@ -51,6 +53,12 @@ class TiledDataset(data.Dataset):
             return None
         if len(img.shape) == 2:
             return None
+
+        # Perform explicit crops first. These are generally used to get rid of watermarks, so we don't even want
+        # to consider those areas of the image.
+        if 'bottom_crop' in self.opt.keys():
+            img = img[:-self.opt['bottom_crop'], :, :]
+
         h, w, c = img.shape
         # Uncomment to filter any image that doesnt meet a threshold size.
         if min(h,w) < 512:
@@ -61,7 +69,14 @@ class TiledDataset(data.Dataset):
         # Crop the image so that only the center is left, since this is often the most salient part of the image.
         img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
         img = cv2.resize(img, (self.opt['imgsize'], self.opt['imgsize']), interpolation=cv2.INTER_AREA)
-        cv2.imwrite(osp.join(self.opt['save_folder'], basename + ".jpg"), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
+
+        # cv2.imwrite() was having trouble with unicode filenames, so write out with PIL instead.
+        # cv2.imwrite(osp.join(self.opt['save_folder'], basename + ".jpg"), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
+
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = (img * 255).astype(np.uint8)
+        img = Image.fromarray(img)
+        img.save(osp.join(self.opt['save_folder'], basename + ".jpg"), "JPEG", quality=self.opt['compression_level'], optimize=True)
         return None

     def __len__(self):
diff --git a/codes/train.py b/codes/train.py
index 1871f117..5262d5b7 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -293,7 +293,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_faces_glean.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_sameimage.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/networks.py b/codes/trainer/networks.py
index 82568867..422965fe 100644
--- a/codes/trainer/networks.py
+++ b/codes/trainer/networks.py
@@ -125,6 +125,8 @@ def define_G(opt, opt_net, scale=None):
         from models.spinenet_arch import SpinenetWithLogits
         netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
                                   in_channels=3, use_input_norm=opt_net['use_input_norm'])
+    elif which_model == 'resnet52':
+        netG = torchvision.models.resnet50(pretrained=opt_net['pretrained'])
     elif which_model == 'glean':
         from models.glean.glean import GleanGenerator
         netG = GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan'])
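
Taken together, these changes point BYOL at pairs of views drawn from the same image, or from adjacent frames of the same video via 'fetch_alt_image', instead of hq/lq pairs. A minimal sketch of how the new options compose follows; it is hypothetical and not part of the patch: the nested dataset keys mirror the __main__ block above, the exact keys create_dataset() requires depend on its registry, and pairing 'key2' with 'alt_hq' is an assumed use (train_byol_resnet_sameimage.yml presumably points both keys at 'hq').

    # Hypothetical wiring of the new options; illustration only.
    opt = {
        'dataset': {
            'paths': ['/data/video_frames'],  # assumed: frames dumped by ffmpeg with %08d numbering
            'weights': [1],
            'target_size': 256,
            'force_multiple': 32,
            'scale': 2,
            'fetch_alt_image': True,  # each item gains 'alt_hq', the next frame when one exists
            'skip_lq': True,          # BYOL needs no LQ synthesis
        },
        'crop_size': 256,
        'key1': 'hq',      # first augmented view: the current frame
        'key2': 'alt_hq',  # second view: the adjacent frame (assumed pairing; defaults are 'hq'/'lq')
    }
    ds = ByolDatasetWrapper(opt)
    item = ds[0]  # item['aug1'] and item['aug2'] are the two views fed to the BYOL wrapper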