From 11155aead4933c6046403e04e3a6a75fa6bbb151 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 4 Dec 2020 20:14:53 -0700 Subject: [PATCH 1/4] Directly use dataset keys This has been a long time coming. Cleans up messy "GT" nomenclature and simplifies ExtensibleTraner.feed_data --- codes/data/README.md | 7 +++++-- codes/data/full_image_dataset.py | 2 +- codes/data/image_folder_dataset.py | 4 ++-- codes/data/multi_frame_dataset.py | 4 ++-- codes/data/multiscale_dataset.py | 2 +- codes/data/paired_frame_dataset.py | 4 ++-- codes/data/single_image_dataset.py | 4 ++-- codes/data/stylegan2_dataset.py | 2 +- codes/data/torch_dataset.py | 2 +- codes/models/ExtensibleTrainer.py | 20 ++++++------------- codes/models/eval/flow_gaussian_nll.py | 4 ++-- codes/models/eval/sr_style.py | 2 +- codes/models/feature_model.py | 4 ++-- codes/process_video.py | 6 +++--- .../compute_fdpl_perceptual_weights.py | 4 ++-- codes/scripts/create_lmdb.py | 10 +++++----- .../scripts/srflow_latent_space_playground.py | 15 +++++++++----- codes/scripts/test_dataloader.py | 4 ++-- codes/scripts/use_discriminator_as_filter.py | 2 +- codes/scripts/use_generator_as_filter.py | 4 ++-- codes/test.py | 4 ++-- codes/train.py | 4 ++-- codes/train2.py | 10 ++++++---- 23 files changed, 63 insertions(+), 61 deletions(-) diff --git a/codes/data/README.md b/codes/data/README.md index 78a65825..be9af646 100644 --- a/codes/data/README.md +++ b/codes/data/README.md @@ -20,12 +20,15 @@ This directory contains several reference datasets which I have used in building 1. MultiframeDataset - Similar to SingleImageDataset, but infers a temporal relationship between images based on their filenames: the last 12 characters before the file extension are assumed to be a frame counter. Images from this dataset are grouped together with a temporal dimension for working with video data. +1. ImageFolderDataset - Reads raw images from a folder and feeds them into the model. Capable of performing corruptions + on those images like the above. 1. MultiscaleDataset - Reads full images from a directory and builds a tree of images constructed by cropping squares from the source image and resizing them to the target size recursively until the native resolution is hit. Each recursive step decreases the crop size by a factor of 2. +1. TorchDataset - A wrapper for miscellaneous pytorch datasets (e.g. MNIST, CIFAR, etc) which extracts the images + and reformats them in a way that the DLAS trainer understands. 1. FullImageDataset - An image patch dataset where the patches are dynamically extracted from full-size images. I have - generally stopped using this for performance reasons in favor of SingleImageDataset but it is useful for validation - and test so I keep it around. + generally stopped using this for performance reasons and it should be considered deprecated. 
## Information about the "chunked" format diff --git a/codes/data/full_image_dataset.py b/codes/data/full_image_dataset.py index 6fe2d444..9e121f27 100644 --- a/codes/data/full_image_dataset.py +++ b/codes/data/full_image_dataset.py @@ -323,7 +323,7 @@ class FullImageDataset(data.Dataset): gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0) lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0) - d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref, + d = {'lq': img_LQ, 'hq': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref, 'lq_center': lq_center, 'gt_center': gt_center, 'LQ_path': LQ_path, 'GT_path': full_path} return d diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py index c2ae96ac..31309edc 100644 --- a/codes/data/image_folder_dataset.py +++ b/codes/data/image_folder_dataset.py @@ -95,7 +95,7 @@ class ImageFolderDataset: hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float() lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float() - return {'LQ': lq, 'GT': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]} + return {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]} if __name__ == '__main__': @@ -118,7 +118,7 @@ if __name__ == '__main__': for i in range(0, len(ds)): o = ds[random.randint(0, len(ds))] #for k, v in o.items(): - k = 'LQ' + k = 'lq' v = o[k] #if 'LQ' in k and 'path' not in k and 'center' not in k: #if 'full' in k: diff --git a/codes/data/multi_frame_dataset.py b/codes/data/multi_frame_dataset.py index c82355c9..ce99b348 100644 --- a/codes/data/multi_frame_dataset.py +++ b/codes/data/multi_frame_dataset.py @@ -57,7 +57,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset): lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1) lq_ref = torch.cat([lq_ref, lq_mask], dim=1) - return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref, + return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref, 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)} @@ -83,7 +83,7 @@ if __name__ == '__main__': batch = None for i in range(len(ds)): import random - k = 'LQ' + k = 'lq' element = ds[random.randint(0,len(ds))] base_file = osp.basename(element["GT_path"]) o = element[k].unsqueeze(0) diff --git a/codes/data/multiscale_dataset.py b/codes/data/multiscale_dataset.py index f93bcb06..6673f6e0 100644 --- a/codes/data/multiscale_dataset.py +++ b/codes/data/multiscale_dataset.py @@ -87,7 +87,7 @@ class MultiScaleDataset(data.Dataset): patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='area').squeeze() for p in patches_hq_corrupted] patches_lq = torch.stack(patches_lq, dim=0) - d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path} + d = {'lq': patches_lq, 'hq': patches_hq, 'GT_path': full_path} return d def __len__(self): diff --git a/codes/data/paired_frame_dataset.py b/codes/data/paired_frame_dataset.py index 801d4182..49328abc 100644 --- a/codes/data/paired_frame_dataset.py +++ b/codes/data/paired_frame_dataset.py @@ -42,7 +42,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset): lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).squeeze().unsqueeze(dim=1) lq_ref = torch.cat([lq_ref, lq_mask], dim=1) - return 
{'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref, + return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref, 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)} @@ -68,7 +68,7 @@ if __name__ == '__main__': batch = None for i in range(len(ds)): import random - k = 'LQ' + k = 'lq' element = ds[random.randint(0,len(ds))] base_file = osp.basename(element["GT_path"]) o = element[k].unsqueeze(0) diff --git a/codes/data/single_image_dataset.py b/codes/data/single_image_dataset.py index 2521f792..4048f197 100644 --- a/codes/data/single_image_dataset.py +++ b/codes/data/single_image_dataset.py @@ -36,7 +36,7 @@ class SingleImageDataset(BaseUnsupervisedImageDataset): lq_mask = torch.from_numpy(np.ascontiguousarray(lms[0])).unsqueeze(dim=0) lq_ref = torch.cat([lq_ref, lq_mask], dim=0) - return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref, + return {'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref, 'lq_center': torch.tensor(lcs[0], dtype=torch.long), 'gt_center': torch.tensor(hcs[0], dtype=torch.long), 'LQ_path': path, 'GT_path': path} @@ -62,7 +62,7 @@ if __name__ == '__main__': for i in range(0, len(ds)): o = ds[random.randint(0, len(ds))] #for k, v in o.items(): - k = 'LQ' + k = 'lq' v = o[k] #if 'LQ' in k and 'path' not in k and 'center' not in k: #if 'full' in k: diff --git a/codes/data/stylegan2_dataset.py b/codes/data/stylegan2_dataset.py index 2424f686..b00f591d 100644 --- a/codes/data/stylegan2_dataset.py +++ b/codes/data/stylegan2_dataset.py @@ -98,4 +98,4 @@ class Stylegan2Dataset(data.Dataset): path = self.paths[index] img = Image.open(path) img = self.transform(img) - return {'LQ': img, 'GT': img, 'GT_path': str(path)} + return {'lq': img, 'hq': img, 'GT_path': str(path)} diff --git a/codes/data/torch_dataset.py b/codes/data/torch_dataset.py index 58ed34de..01875bfb 100644 --- a/codes/data/torch_dataset.py +++ b/codes/data/torch_dataset.py @@ -24,7 +24,7 @@ class TorchDataset(Dataset): def __getitem__(self, item): underlying_item = self.dataset[item][0] - return {'LQ': underlying_item, 'GT': underlying_item, + return {'lq': underlying_item, 'hq': underlying_item, 'LQ_path': str(item), 'GT_path': str(item)} def __len__(self): diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 451b5e1d..2f9c147b 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -4,7 +4,6 @@ import os import torch from torch.nn.parallel import DataParallel import torch.nn as nn -from apex.parallel import DistributedDataParallel import models.lr_scheduler as lr_scheduler import models.networks as networks @@ -106,6 +105,8 @@ class ExtensibleTrainer(BaseModel): all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()] for anet in all_networks: if opt['dist']: + # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing. 
+ from apex.parallel import DistributedDataParallel dnet = DistributedDataParallel(anet, delay_allreduce=True) else: dnet = DataParallel(anet, device_ids=opt['gpu_ids']) @@ -160,18 +161,9 @@ class ExtensibleTrainer(BaseModel): o.zero_grad() torch.cuda.empty_cache() - self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.batch_factor, dim=0)] - if need_GT: - self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.batch_factor, dim=0)] - input_ref = data['ref'] if 'ref' in data.keys() else data['GT'] - self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.batch_factor, dim=0)] - else: - self.hq = self.lq - self.ref = self.lq - - self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref} + self.dstate = {} for k, v in data.items(): - if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor): + if isinstance(v, torch.Tensor): self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.batch_factor, dim=0)] def optimize_parameters(self, step): @@ -328,8 +320,8 @@ class ExtensibleTrainer(BaseModel): def get_current_visuals(self, need_GT=True): # Conforms to an archaic format from MMSR. - return {'LQ': self.eval_state['lq'][0].float().cpu(), - 'GT': self.eval_state['hq'][0].float().cpu(), + return {'lq': self.eval_state['lq'][0].float().cpu(), + 'hq': self.eval_state['hq'][0].float().cpu(), 'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()} def print_network(self): diff --git a/codes/models/eval/flow_gaussian_nll.py b/codes/models/eval/flow_gaussian_nll.py index eed85622..31e7297e 100644 --- a/codes/models/eval/flow_gaussian_nll.py +++ b/codes/models/eval/flow_gaussian_nll.py @@ -30,8 +30,8 @@ class FlowGaussianNll(evaluator.Evaluator): print("Evaluating FlowGaussianNll..") for batch in tqdm(self.dataloader): dev = self.env['device'] - z, _, _ = self.model(gt=batch['GT'].to(dev), - lr=batch['LQ'].to(dev), + z, _, _ = self.model(gt=batch['hq'].to(dev), + lr=batch['lq'].to(dev), epses=[], reverse=False, add_gt_noise=False) diff --git a/codes/models/eval/sr_style.py b/codes/models/eval/sr_style.py index 45ca70ee..b44d5dcf 100644 --- a/codes/models/eval/sr_style.py +++ b/codes/models/eval/sr_style.py @@ -39,7 +39,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator): counter = 0 for batch in self.sampler: noise = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device']) - batch_hq = [e['GT'] for e in batch] + batch_hq = [e['hq'] for e in batch] batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device']) resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area") embedding = embedding_generator(resized_batch) diff --git a/codes/models/feature_model.py b/codes/models/feature_model.py index dc9d9dd5..74d82e23 100644 --- a/codes/models/feature_model.py +++ b/codes/models/feature_model.py @@ -66,9 +66,9 @@ class FeatureModel(BaseModel): self.log_dict = OrderedDict() def feed_data(self, data, need_GT=True): - self.var_L = data['LQ'].to(self.device) # LQ + self.var_L = data['lq'].to(self.device) # LQ if need_GT: - self.real_H = data['GT'].to(self.device) # GT + self.real_H = data['hq'].to(self.device) # GT def optimize_parameters(self, step): self.optimizer_G.zero_grad() diff --git a/codes/process_video.py b/codes/process_video.py index fd376966..7a53c6e1 100644 --- a/codes/process_video.py +++ b/codes/process_video.py @@ -88,7 +88,7 @@ class FfmpegBackedVideoDataset(data.Dataset): img_LQ = lq_template ref = ref_template 
- return {'LQ': img_LQ, 'lq_fullsize_ref': ref, + return {'lq': img_LQ, 'lq_fullsize_ref': ref, 'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) } def __len__(self): @@ -159,8 +159,8 @@ if __name__ == "__main__": need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True if recurrent_mode and first_frame: - b, c, h, w = data['LQ'].shape - recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['LQ'].device) + b, c, h, w = data['lq'].shape + recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['lq'].device) # Optionally swap out the 'generator' for the first frame to create a better image that the recurrent generator works off of. if 'recurrent_hr_generator' in opt.keys(): recurrent_gen = model.env['generators']['generator'] diff --git a/codes/scripts/compute_fdpl_perceptual_weights.py b/codes/scripts/compute_fdpl_perceptual_weights.py index 1a6506b2..3b411721 100644 --- a/codes/scripts/compute_fdpl_perceptual_weights.py +++ b/codes/scripts/compute_fdpl_perceptual_weights.py @@ -40,8 +40,8 @@ if __name__ == '__main__': break sampled += 1 - im = rgb2ycbcr(train_data['GT'].double()) - im_LR = rgb2ycbcr(F.interpolate(train_data['LQ'].double(), + im = rgb2ycbcr(train_data['hq'].double()) + im_LR = rgb2ycbcr(F.interpolate(train_data['lq'].double(), size=im.shape[2:], mode="bicubic", align_corners=False)) patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True) diff --git a/codes/scripts/create_lmdb.py b/codes/scripts/create_lmdb.py index c44b5c78..7b6d5de4 100644 --- a/codes/scripts/create_lmdb.py +++ b/codes/scripts/create_lmdb.py @@ -16,7 +16,7 @@ import utils.util as util # noqa: E402 def main(): dataset = 'DIV2K_demo' # vimeo90K | REDS | general (e.g., DIV2K, 291) | DIV2K_demo |test - mode = 'GT' # used for vimeo90k and REDS datasets + mode = 'hq' # used for vimeo90k and REDS datasets # vimeo90k: GT | LR | flow # REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp # train_sharp_flowx4 @@ -159,7 +159,7 @@ def vimeo90k(mode): read_all_imgs = False # whether real all images to memory with multiprocessing # Set False for use limited memory BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False - if mode == 'GT': + if mode == 'hq': img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences' lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb' txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' @@ -204,7 +204,7 @@ def vimeo90k(mode): keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1)) all_img_list = sorted(all_img_list) keys = sorted(keys) - if mode == 'GT': # only read the 4th frame for the GT mode + if mode == 'hq': # only read the 4th frame for the GT mode print('Only keep the 4th frame.') all_img_list = [v for v in all_img_list if v.endswith('im4.png')] keys = [v for v in keys if v.endswith('_4')] @@ -255,9 +255,9 @@ def vimeo90k(mode): #### create meta information meta_info = {} - if mode == 'GT': + if mode == 'hq': meta_info['name'] = 'Vimeo90K_train_GT' - elif mode == 'LR': + elif mode == 'lq': meta_info['name'] = 'Vimeo90K_train_LR' elif mode == 'flow': meta_info['name'] = 'Vimeo90K_train_flowx4' diff --git a/codes/scripts/srflow_latent_space_playground.py b/codes/scripts/srflow_latent_space_playground.py index 3fe89577..7dd31902 100644 --- a/codes/scripts/srflow_latent_space_playground.py +++ b/codes/scripts/srflow_latent_space_playground.py @@ -163,7 +163,7 @@ if __name__ == 
"__main__": torch.backends.cudnn.benchmark = True srg_analyze = False parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../experiments/train_exd_imgset_srflow/train_exd_imgset_srflow.yml') + parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow_frompsnr.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) utils.util.loaded_options = opt @@ -180,10 +180,10 @@ if __name__ == "__main__": gen = model.networks['generator'] gen.eval() - mode = "temperature" # temperature | restore | latent_transfer | feed_through + mode = "restore" # temperature | restore | latent_transfer | feed_through #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*" - #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*" - imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*" + imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*" + #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*" scale = 2 resample_factor = 1 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents. temperature = 1 @@ -224,6 +224,11 @@ if __name__ == "__main__": t = image_2_tensor(img_file).to(model.env['device']) if resample_factor != 1: t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic") + # Ensure the input image is a factor of 16. + _, _, h, w = t.shape + h = 16 * (h // 16) + w = 16 * (w // 16) + t = t[:, :, :h, :w] resample_img = t # Fetch the latent metrics & latents for each image we are resampling. 
@@ -255,6 +260,6 @@ if __name__ == "__main__": for j in range(len(lats)): path = os.path.join(output_path, "%i_%i" % (im_it, j)) os.makedirs(path, exist_ok=True) - torchvision.utils.save_image(resample_img, os.path.join(path, "%i_orig.jpg" %(im_it))) + torchvision.utils.save_image(resample_img, os.path.join(path, "orig.jpg")) create_interpolation_video(gen, F.interpolate(resample_img, scale_factor=1/scale, mode="area"), path, [torch.zeros_like(l) for l in lats[j]], lats[j]) diff --git a/codes/scripts/test_dataloader.py b/codes/scripts/test_dataloader.py index 73df5242..642c6ef7 100644 --- a/codes/scripts/test_dataloader.py +++ b/codes/scripts/test_dataloader.py @@ -85,8 +85,8 @@ def main(): if dataset == 'REDS' or dataset == 'Vimeo90K': LQs = data['LQs'] else: - LQ = data['LQ'] - GT = data['GT'] + LQ = data['lq'] + GT = data['hq'] if dataset == 'REDS' or dataset == 'Vimeo90K': for j in range(LQs.size(1)): diff --git a/codes/scripts/use_discriminator_as_filter.py b/codes/scripts/use_discriminator_as_filter.py index 94a69eb3..14cb6887 100644 --- a/codes/scripts/use_discriminator_as_filter.py +++ b/codes/scripts/use_discriminator_as_filter.py @@ -68,6 +68,6 @@ if __name__ == "__main__": # removed += 1 imname = osp.basename(data['GT_path'][i]) if results[i]-dataset_mean > 1: - torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname)) + torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname)) print("Removed %i/%i images" % (removed, len(test_set))) \ No newline at end of file diff --git a/codes/scripts/use_generator_as_filter.py b/codes/scripts/use_generator_as_filter.py index 5f1a4a94..26718702 100644 --- a/codes/scripts/use_generator_as_filter.py +++ b/codes/scripts/use_generator_as_filter.py @@ -66,7 +66,7 @@ if __name__ == "__main__": model.test() gen = model.eval_state['gen'][0].to(model.env['device']) feagen = netF(gen) - feareal = netF(data['GT'].to(model.env['device'])) + feareal = netF(data['hq'].to(model.env['device'])) losses = torch.sum(torch.abs(feareal - feagen), dim=(1,2,3)) means.append(torch.mean(losses).item()) #print(sum(means)/len(means), torch.mean(losses), torch.max(losses), torch.min(losses)) @@ -76,6 +76,6 @@ if __name__ == "__main__": removed += 1 #imname = osp.basename(data['GT_path'][i]) #if losses[i] < 25000: - # torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname)) + # torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname)) print("Removed %i/%i images" % (removed, len(test_set))) \ No newline at end of file diff --git a/codes/test.py b/codes/test.py index 71f58c85..584d0dd7 100644 --- a/codes/test.py +++ b/codes/test.py @@ -41,9 +41,9 @@ def forward_pass(model, output_dir, alteration_suffix=''): save_img_path = osp.join(output_dir, img_name + '.png') if need_GT: - fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i]) + fea_loss += model.compute_fea_loss(visuals[i], data['hq'][i]) psnr_sr = util.tensor2img(visuals[i]) - psnr_gt = util.tensor2img(data['GT'][i]) + psnr_gt = util.tensor2img(data['hq'][i]) psnr_loss += util.calculate_psnr(psnr_sr, psnr_gt) util.save_img(sr_img, save_img_path) diff --git a/codes/train.py b/codes/train.py index aa124dd3..08e97474 100644 --- a/codes/train.py +++ b/codes/train.py @@ -231,13 +231,13 @@ class Trainer: sr_img = util.tensor2img(visuals['rlt'][b]) # uint8 # calculate PSNR if self.val_compute_psnr: - gt_img = util.tensor2img(visuals['GT'][b]) # uint8 + gt_img = util.tensor2img(visuals['hq'][b]) # uint8 sr_img, gt_img = util.crop_border([sr_img,
gt_img], opt['scale']) avg_psnr += util.calculate_psnr(sr_img, gt_img) # calculate fea loss if self.val_compute_fea: - avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b]) + avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b]) # Save SR images for reference img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step) diff --git a/codes/train2.py b/codes/train2.py index 95e8c32d..06e3f0eb 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -3,6 +3,8 @@ import math import argparse import random import logging + +import torchvision from tqdm import tqdm import torch @@ -231,18 +233,18 @@ class Trainer: sr_img = util.tensor2img(visuals['rlt'][b]) # uint8 # calculate PSNR if self.val_compute_psnr: - gt_img = util.tensor2img(visuals['GT'][b]) # uint8 + gt_img = util.tensor2img(visuals['hq'][b]) # uint8 sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) avg_psnr += util.calculate_psnr(sr_img, gt_img) # calculate fea loss if self.val_compute_fea: - avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b]) + avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b]) # Save SR images for reference img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step) save_img_path = os.path.join(img_dir, img_base_name) - util.save_img(sr_img, save_img_path) + torchvision.utils.save_image(visuals['rlt'], save_img_path) avg_psnr = avg_psnr / idx avg_fea_loss = avg_fea_loss / idx @@ -291,7 +293,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_frompsnr.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_bigboi_frompsnr.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() From 20a09cb31bd99d73410c3f887c1b535744029da9 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 4 Dec 2020 20:17:37 -0700 Subject: [PATCH 2/4] #pycharm ad I swear they aren't paying me --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index ad013b5f..74925f7d 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,11 @@ TBC.. ## User Guide TBC +### Development Environment +If you aren't already using [Pycharm](https://www.jetbrains.com/pycharm/) - now is the time to try it out. This project was built in Pycharm and comes with +an IDEA project for you to get started with. I've done all of my development on this repo in this IDE and lean heavily +on its incredible debugger. It's free. Try it out. You won't be sorry. + ### Dataset Preparation DLAS comes with some Dataset instances that I have created for my own use. Unless you want to use one of the recipes above, you'll need to provide your own. Here is how to add your own Dataset: From 88fc049c8d9f65be1954a04a7317ed7000d2d668 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 5 Dec 2020 20:30:36 -0700 Subject: [PATCH 3/4] spinenet latent playground! 
--- codes/data/image_corruptor.py | 2 +- codes/data/image_folder_dataset.py | 2 +- codes/models/archs/spinenet_arch.py | 2 +- codes/scripts/byol_spinenet_playground.py | 94 +++++++++++++++++++++++ 4 files changed, 97 insertions(+), 3 deletions(-) create mode 100644 codes/scripts/byol_spinenet_playground.py diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py index a613175f..323ddd09 100644 --- a/codes/data/image_corruptor.py +++ b/codes/data/image_corruptor.py @@ -9,7 +9,7 @@ from io import BytesIO # options. class ImageCorruptor: def __init__(self, opt): - self.fixed_corruptions = opt['fixed_corruptions'] + self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else [] self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0 if self.num_corrupts == 0: return diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py index 31309edc..04aff28c 100644 --- a/codes/data/image_folder_dataset.py +++ b/codes/data/image_folder_dataset.py @@ -95,7 +95,7 @@ class ImageFolderDataset: hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float() lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float() - return {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]} + return {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]} if __name__ == '__main__': diff --git a/codes/models/archs/spinenet_arch.py b/codes/models/archs/spinenet_arch.py index 85fb71dd..3e4924f5 100644 --- a/codes/models/archs/spinenet_arch.py +++ b/codes/models/archs/spinenet_arch.py @@ -209,7 +209,7 @@ class SpineNet(nn.Module): def __init__(self, arch, in_channels=3, - output_level=[3, 4, 5, 6, 7], + output_level=[3, 4], conv_cfg=None, norm_cfg=dict(type='BN', requires_grad=True), zero_init_residual=True, diff --git a/codes/scripts/byol_spinenet_playground.py b/codes/scripts/byol_spinenet_playground.py new file mode 100644 index 00000000..498fa1c2 --- /dev/null +++ b/codes/scripts/byol_spinenet_playground.py @@ -0,0 +1,94 @@ +import os +import shutil + +import torch +import torch.nn as nn +import torchvision +from torch.utils.data import DataLoader +from tqdm import tqdm + +from data.image_folder_dataset import ImageFolderDataset +from models.archs.spinenet_arch import SpineNet + + +# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved +# and the distance is computed across the channel dimension. +def structural_euc_dist(x, y): + diff = torch.square(x - y) + sum = torch.sum(diff, dim=1) + return torch.sqrt(sum) + + +def cosine_similarity(x, y): + return nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself. 
+ + +def im_norm(x): + return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5 + + +def get_image_folder_dataloader(batch_size, num_workers): + dataset_opt = { + 'name': 'amalgam', + 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], + 'weights': [1], + 'target_size': 512, + 'force_multiple': 32, + 'scale': 1 + } + dataset = ImageFolderDataset(dataset_opt) + return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) + + +def create_latent_database(model): + batch_size = 8 + num_workers = 1 + output_path = '../../results/byol_spinenet_latents/' + + os.makedirs(output_path, exist_ok=True) + dataloader = get_image_folder_dataloader(batch_size, num_workers) + id = 0 + latent_dict = {} + for batch in tqdm(dataloader): + hq = batch['hq'].to('cuda:1') + latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing. + for b in range(latent.shape[0]): + shutil.copy(batch[b]['HQ_path'], os.path.join(output_path, "%i.jpg" % (id,))) + latent_dict[id] = latent[b].detach().cpu() + if id % 100 == 0: + print("Saving checkpoint..") + torch.save(latent_dict, "latent_dict.pth") + id += 1 + + +def explore_latent_results(model): + batch_size = 8 + num_workers = 1 + output_path = '../../results/byol_spinenet_explore_latents/' + + os.makedirs(output_path, exist_ok=True) + dataloader = get_image_folder_dataloader(batch_size, num_workers) + id = 0 + for batch in tqdm(dataloader): + hq = batch['hq'].to('cuda:1') + latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing. + # This operation works by computing the distance of every structural index from the center and using that + # as a "heatmap". + b, c, h, w = latent.shape + center = latent[:, :, h//2, w//2].unsqueeze(-1).unsqueeze(-1) + centers = center.repeat(1, 1, h, w) + dist = structural_euc_dist(latent, centers).unsqueeze(1) + dist = im_norm(dist) + torchvision.utils.save_image(dist, os.path.join(output_path, "%i.png" % id)) + id += 1 + + +if __name__ == '__main__': + pretrained_path = '../../experiments/spinenet49_imgset_byol.pth' + + model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda:1') + model.load_state_dict(torch.load(pretrained_path), strict=True) + model.eval() + + with torch.no_grad(): + explore_latent_results(model) \ No newline at end of file From c0aeaabc31a9d0dd8b55288f9d6dee98a815489b Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 7 Dec 2020 12:49:32 -0700 Subject: [PATCH 4/4] Spinenet playground --- codes/data/image_folder_dataset.py | 2 +- codes/scripts/byol_spinenet_playground.py | 132 ++++++++++++++++++++-- 2 files changed, 123 insertions(+), 11 deletions(-) diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py index 04aff28c..1ee85dab 100644 --- a/codes/data/image_folder_dataset.py +++ b/codes/data/image_folder_dataset.py @@ -29,7 +29,7 @@ class ImageFolderDataset: self.weights = opt['weights'] # Just scan the given directory for images of standard types. 
- supported_types = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG', 'gif', 'GIF'] + supported_types = ['jpg', 'jpeg', 'png', 'gif'] self.image_paths = [] for path, weight in zip(self.paths, self.weights): cache_path = os.path.join(path, 'cache.pth') diff --git a/codes/scripts/byol_spinenet_playground.py b/codes/scripts/byol_spinenet_playground.py index 498fa1c2..bb74a6bb 100644 --- a/codes/scripts/byol_spinenet_playground.py +++ b/codes/scripts/byol_spinenet_playground.py @@ -4,8 +4,11 @@ import shutil import torch import torch.nn as nn import torchvision +from PIL import Image from torch.utils.data import DataLoader +from torchvision.transforms import ToTensor, Resize from tqdm import tqdm +import numpy as np from data.image_folder_dataset import ImageFolderDataset from models.archs.spinenet_arch import SpineNet @@ -15,12 +18,20 @@ from models.archs.spinenet_arch import SpineNet # and the distance is computed across the channel dimension. def structural_euc_dist(x, y): diff = torch.square(x - y) - sum = torch.sum(diff, dim=1) + sum = torch.sum(diff, dim=-1) return torch.sqrt(sum) def cosine_similarity(x, y): - return nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself. + x = norm(x) + y = norm(y) + return -nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself. + + +def norm(x): + sh = x.shape + sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))]) + return (x - torch.mean(x, dim=-1).reshape(sh_r)) / torch.std(x, dim=-1).reshape(sh_r) def im_norm(x): @@ -30,14 +41,15 @@ def im_norm(x): def get_image_folder_dataloader(batch_size, num_workers): dataset_opt = { 'name': 'amalgam', - 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], + #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], + 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'], 'weights': [1], 'target_size': 512, 'force_multiple': 32, 'scale': 1 } dataset = ImageFolderDataset(dataset_opt) - return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) + return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True) def create_latent_database(model): @@ -48,21 +60,121 @@ def create_latent_database(model): os.makedirs(output_path, exist_ok=True) dataloader = get_image_folder_dataloader(batch_size, num_workers) id = 0 + dict_count = 1 latent_dict = {} + all_paths = [] for batch in tqdm(dataloader): hq = batch['hq'].to('cuda:1') latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing. 
for b in range(latent.shape[0]): - shutil.copy(batch[b]['HQ_path'], os.path.join(output_path, "%i.jpg" % (id,))) + im_path = batch['HQ_path'][b] + all_paths.append(im_path) latent_dict[id] = latent[b].detach().cpu() - if id % 100 == 0: + if (id+1) % 1000 == 0: print("Saving checkpoint..") - torch.save(latent_dict, "latent_dict.pth") + torch.save(latent_dict, os.path.join(output_path, "latent_dict_%i.pth" % (dict_count,))) + latent_dict = {} + torch.save(all_paths, os.path.join(output_path, "all_paths.pth")) + dict_count += 1 id += 1 + +def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_size): + _, c, h, w = latent.shape + lat_dict = torch.load(os.path.join(hq_img_repo, ld_file_name)) + comparables = torch.stack(list(lat_dict.values()), dim=0).permute(0,2,3,1) + cbl_shape = comparables.shape[:3] + assert cbl_shape[1] == 32 + comparables = comparables.reshape(-1, c) + + clat = latent.reshape(1,-1,h*w).permute(2,0,1) + cpbl_chunked = torch.chunk(comparables, len(comparables) // batch_size) + assert len(comparables) % batch_size == 0 # The reconstruction logic doesn't work if this is not the case. + mins = [] + min_offsets = [] + for cpbl_chunk in tqdm(cpbl_chunked): + cpbl_chunk = cpbl_chunk.to('cuda:1') + dist = structural_euc_dist(clat, cpbl_chunk.unsqueeze(0)) + _min = torch.min(dist, dim=-1) + mins.append(_min[0]) + min_offsets.append(_min[1]) + mins = torch.min(torch.stack(mins, dim=-1), dim=-1) + # There's some way to do this in torch, I just can't figure it out.. + for i in range(len(mins[1])): + mins[1][i] = mins[1][i] * batch_size + min_offsets[mins[1][i]][i] + + return mins[0].cpu(), mins[1].cpu(), len(comparables) + + +def find_similar_latents(model): + img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\adrianna_xx.jpg' + #img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg' + hq_img_repo = '../../results/byol_spinenet_latents' + output_path = '../../results/byol_spinenet_similars' + batch_size = 1024 + num_maps = 8 + + os.makedirs(output_path, exist_ok=True) + img_bank_paths = torch.load(os.path.join(hq_img_repo, "all_paths.pth")) + img_t = ToTensor()(Image.open(img)).to('cuda:1').unsqueeze(0) + _, _, h, w = img_t.shape + img_t = img_t[:, :, :128*(h//128), :128*(w//128)] + + latent = model(img_t)[1] + _, c, h, w = latent.shape + mins, min_offsets = [], [] + total_latents = -1 + for d_id in range(1,num_maps+1): + mn, of, tl = _get_mins_from_latent_dictionary(latent, hq_img_repo, "latent_dict_%i.pth" % (d_id), batch_size) + if total_latents != -1: + assert total_latents == tl + else: + total_latents = tl + mins.append(mn) + min_offsets.append(of) + mins = torch.min(torch.stack(mins, dim=-1), dim=-1) + # There's some way to do this in torch, I just can't figure it out.. + for i in range(len(mins[1])): + mins[1][i] = mins[1][i] * total_latents + min_offsets[mins[1][i]][i] + min_ids = mins[1] + + print("Constructing image map..") + doc_out = ''' + <html><body><img src="output.png" usemap="#imgmap"> + <map name="imgmap"> + %s
+ </map> + </body></html> + ''' + img_map_areas = [] + img_out = torch.zeros((1,3,h*16,w*16)) + for i, ind in enumerate(tqdm(min_ids)): + u = np.unravel_index(ind.item(), (num_maps*total_latents//(32*32),32,32)) + h_, w_ = np.unravel_index(i, (h, w)) + + img = ToTensor()(Resize((512, 512))(Image.open(img_bank_paths[u[0]]))) + t = 16 * u[1] + l = 16 * u[2] + patch = img[:, t:t+16, l:l+16] + img_out[:,:,h_*16:h_*16+16,w_*16:w_*16+16] = patch + + # Also save the image with a masked map + mask = torch.full_like(img, fill_value=.3) + mask[:, t:t+16, l:l+16] = 1 + masked_img = img * mask + masked_src_img_output_file = os.path.join(output_path, "%i_%i__%i.png" % (t, l, u[0])) + torchvision.utils.save_image(masked_img, masked_src_img_output_file) + + # Update the image map areas. + img_map_areas.append('<area shape="rect" coords="%i,%i,%i,%i" href="%s">' % (w_*16,h_*16,w_*16+16,h_*16+16,masked_src_img_output_file)) + torchvision.utils.save_image(img_out, os.path.join(output_path, "output.png")) + torchvision.utils.save_image(img_t, os.path.join(output_path, "source.png")) + doc_out = doc_out % ('\n'.join(img_map_areas)) + with open(os.path.join(output_path, 'map.html'), 'w') as f: + print(doc_out, file=f) + + def explore_latent_results(model): - batch_size = 8 + batch_size = 16 num_workers = 1 output_path = '../../results/byol_spinenet_explore_latents/' @@ -77,7 +189,7 @@ def explore_latent_results(model): b, c, h, w = latent.shape center = latent[:, :, h//2, w//2].unsqueeze(-1).unsqueeze(-1) centers = center.repeat(1, 1, h, w) - dist = structural_euc_dist(latent, centers).unsqueeze(1) + dist = cosine_similarity(latent, centers).unsqueeze(1) dist = im_norm(dist) torchvision.utils.save_image(dist, os.path.join(output_path, "%i.png" % id)) id += 1 @@ -91,4 +203,4 @@ if __name__ == '__main__': pretrained_path = '../../experiments/spinenet49_imgset_byol.pth' model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda:1') model.load_state_dict(torch.load(pretrained_path), strict=True) model.eval() with torch.no_grad(): - explore_latent_results(model) \ No newline at end of file + find_similar_latents(model) \ No newline at end of file
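
For reference, a minimal illustrative sketch (not part of the patches above) of the data convention that PATCH 1/4 standardizes: every dataset returns a plain dict keyed 'lq'/'hq' (plus the path entries), and ExtensibleTrainer.feed_data simply chunks every tensor in that dict and moves it to the device. The helper names below are assumptions made for this example; the chunking logic mirrors the feed_data hunk in the first patch.

import torch

def example_getitem(lq_img: torch.Tensor, hq_img: torch.Tensor, path: str) -> dict:
    # Datasets now hand the trainer the keys it consumes directly.
    return {'lq': lq_img, 'hq': hq_img, 'LQ_path': path, 'GT_path': path}

def example_feed_data(data: dict, device: torch.device, batch_factor: int) -> dict:
    # feed_data no longer special-cases 'LQ'/'GT'/'ref': every tensor in the batch is
    # split into batch_factor chunks (for gradient accumulation) and moved to the device.
    dstate = {}
    for k, v in data.items():
        if isinstance(v, torch.Tensor):
            dstate[k] = [t.to(device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
    return dstate

Under this convention, the trainer's state dict is keyed by the dataset keys themselves rather than by the old fixed 'LQ'/'GT' attributes.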