diff --git a/README.md b/README.md
index ad013b5f..74925f7d 100644
--- a/README.md
+++ b/README.md
@@ -42,6 +42,11 @@ TBC..
 
 ## User Guide
 TBC
+### Development Environment
+If you aren't already using [PyCharm](https://www.jetbrains.com/pycharm/), now is the time to try it out. This project was built in PyCharm and comes with
+an IDEA project for you to get started with. I've done all of my development on this repo in this IDE and lean heavily
+on its incredible debugger. It's free. Try it out. You won't be sorry.
+
 ### Dataset Preparation
 DLAS comes with some Dataset instances that I have created for my own use. Unless you want to use one of the recipes
 above, you'll need to provide your own. Here is how to add your own Dataset:
diff --git a/codes/data/README.md b/codes/data/README.md
index 78a65825..be9af646 100644
--- a/codes/data/README.md
+++ b/codes/data/README.md
@@ -20,12 +20,15 @@ This directory contains several reference datasets which I have used in building
 1. MultiframeDataset - Similar to SingleImageDataset, but infers a temporal relationship between images based on their
    filenames: the last 12 characters before the file extension are assumed to be a frame counter. Images from this
    dataset are grouped together with a temporal dimension for working with video data.
+1. ImageFolderDataset - Reads raw images from a folder and feeds them into the model. Capable of performing corruptions
+   on those images, like the datasets above.
 1. MultiscaleDataset - Reads full images from a directory and builds a tree of images constructed by cropping squares
    from the source image and resizing them to the target size recursively until the native resolution is hit. Each
    recursive step decreases the crop size by a factor of 2.
+1. TorchDataset - A wrapper for miscellaneous PyTorch datasets (e.g. MNIST, CIFAR, etc.) which extracts the images
+   and reformats them in a way that the DLAS trainer understands.
 1. FullImageDataset - An image patch dataset where the patches are dynamically extracted from full-size images. I have
-   generally stopped using this for performance reasons in favor of SingleImageDataset but it is useful for validation
-   and test so I keep it around.
+   generally stopped using this for performance reasons and it should be considered deprecated.
 
 ## Information about the "chunked" format
 
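For reference, a minimal sketch (not part of the patch) of how the new ImageFolderDataset might be consumed. The option keys mirror those referenced in image_folder_dataset.py and in the playground script added below; the path and sizes are placeholders.

```python
# Hypothetical usage sketch: option keys follow image_folder_dataset.py, and the
# returned sample dict uses the lowercase 'lq'/'hq' keys this patch standardizes on.
from torch.utils.data import DataLoader
from data.image_folder_dataset import ImageFolderDataset

opt = {
    'name': 'example',
    'paths': ['/path/to/your/images'],  # placeholder path
    'weights': [1],
    'target_size': 256,
    'scale': 1,
}
loader = DataLoader(ImageFolderDataset(opt), batch_size=8, shuffle=True)
batch = next(iter(loader))
print(batch['lq'].shape, batch['hq'].shape)
```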
diff --git a/codes/data/full_image_dataset.py b/codes/data/full_image_dataset.py
index 6fe2d444..9e121f27 100644
--- a/codes/data/full_image_dataset.py
+++ b/codes/data/full_image_dataset.py
@@ -323,7 +323,7 @@ class FullImageDataset(data.Dataset):
             gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0)
             lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)
 
-        d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
+        d = {'lq': img_LQ, 'hq': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
              'lq_center': lq_center, 'gt_center': gt_center,
              'LQ_path': LQ_path, 'GT_path': full_path}
         return d
diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py
index a613175f..323ddd09 100644
--- a/codes/data/image_corruptor.py
+++ b/codes/data/image_corruptor.py
@@ -9,7 +9,7 @@ from io import BytesIO
 # options.
 class ImageCorruptor:
     def __init__(self, opt):
-        self.fixed_corruptions = opt['fixed_corruptions']
+        self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
         self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
         if self.num_corrupts == 0:
             return
diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py
index c2ae96ac..1ee85dab 100644
--- a/codes/data/image_folder_dataset.py
+++ b/codes/data/image_folder_dataset.py
@@ -29,7 +29,7 @@ class ImageFolderDataset:
         self.weights = opt['weights']
 
         # Just scan the given directory for images of standard types.
-        supported_types = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG', 'gif', 'GIF']
+        supported_types = ['jpg', 'jpeg', 'png', 'gif']
         self.image_paths = []
         for path, weight in zip(self.paths, self.weights):
             cache_path = os.path.join(path, 'cache.pth')
@@ -95,7 +95,7 @@ class ImageFolderDataset:
         hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
         lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
 
-        return {'LQ': lq, 'GT': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]}
+        return {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
 
 
 if __name__ == '__main__':
@@ -118,7 +118,7 @@ if __name__ == '__main__':
     for i in range(0, len(ds)):
         o = ds[random.randint(0, len(ds))]
         #for k, v in o.items():
-        k = 'LQ'
+        k = 'lq'
         v = o[k]
         #if 'LQ' in k and 'path' not in k and 'center' not in k:
         #if 'full' in k:
diff --git a/codes/data/multi_frame_dataset.py b/codes/data/multi_frame_dataset.py
index c82355c9..ce99b348 100644
--- a/codes/data/multi_frame_dataset.py
+++ b/codes/data/multi_frame_dataset.py
@@ -57,7 +57,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
             lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
             lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
 
-        return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+        return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
                 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
 
 
@@ -83,7 +83,7 @@ if __name__ == '__main__':
     batch = None
     for i in range(len(ds)):
         import random
-        k = 'LQ'
+        k = 'lq'
         element = ds[random.randint(0,len(ds))]
         base_file = osp.basename(element["GT_path"])
         o = element[k].unsqueeze(0)
diff --git a/codes/data/multiscale_dataset.py b/codes/data/multiscale_dataset.py
index f93bcb06..6673f6e0 100644
--- a/codes/data/multiscale_dataset.py
+++ b/codes/data/multiscale_dataset.py
@@ -87,7 +87,7 @@ class MultiScaleDataset(data.Dataset):
         patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='area').squeeze() for p in patches_hq_corrupted]
         patches_lq = torch.stack(patches_lq, dim=0)
 
-        d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path}
+        d = {'lq': patches_lq, 'hq': patches_hq, 'GT_path': full_path}
         return d
 
     def __len__(self):
diff --git a/codes/data/paired_frame_dataset.py b/codes/data/paired_frame_dataset.py
index 801d4182..49328abc 100644
--- a/codes/data/paired_frame_dataset.py
+++ b/codes/data/paired_frame_dataset.py
@@ -42,7 +42,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
         lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).squeeze().unsqueeze(dim=1)
         lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
 
-        return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+        return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
                 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
 
 
@@ -68,7 +68,7 @@ if __name__ == '__main__':
     batch = None
     for i in range(len(ds)):
         import random
-        k = 'LQ'
+        k = 'lq'
         element = ds[random.randint(0,len(ds))]
         base_file = osp.basename(element["GT_path"])
         o = element[k].unsqueeze(0)
diff --git a/codes/data/single_image_dataset.py b/codes/data/single_image_dataset.py
index 2521f792..4048f197 100644
--- a/codes/data/single_image_dataset.py
+++ b/codes/data/single_image_dataset.py
@@ -36,7 +36,7 @@ class SingleImageDataset(BaseUnsupervisedImageDataset):
         lq_mask = torch.from_numpy(np.ascontiguousarray(lms[0])).unsqueeze(dim=0)
         lq_ref = torch.cat([lq_ref, lq_mask], dim=0)
 
-        return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+        return {'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
                 'lq_center': torch.tensor(lcs[0], dtype=torch.long), 'gt_center': torch.tensor(hcs[0], dtype=torch.long),
                 'LQ_path': path, 'GT_path': path}
 
@@ -62,7 +62,7 @@ if __name__ == '__main__':
     for i in range(0, len(ds)):
         o = ds[random.randint(0, len(ds))]
         #for k, v in o.items():
-        k = 'LQ'
+        k = 'lq'
         v = o[k]
         #if 'LQ' in k and 'path' not in k and 'center' not in k:
         #if 'full' in k:
diff --git a/codes/data/stylegan2_dataset.py b/codes/data/stylegan2_dataset.py
index 2424f686..b00f591d 100644
--- a/codes/data/stylegan2_dataset.py
+++ b/codes/data/stylegan2_dataset.py
@@ -98,4 +98,4 @@ class Stylegan2Dataset(data.Dataset):
         path = self.paths[index]
         img = Image.open(path)
         img = self.transform(img)
-        return {'LQ': img, 'GT': img, 'GT_path': str(path)}
+        return {'lq': img, 'hq': img, 'GT_path': str(path)}
diff --git a/codes/data/torch_dataset.py b/codes/data/torch_dataset.py
index 58ed34de..01875bfb 100644
--- a/codes/data/torch_dataset.py
+++ b/codes/data/torch_dataset.py
@@ -24,7 +24,7 @@ class TorchDataset(Dataset):
 
     def __getitem__(self, item):
         underlying_item = self.dataset[item][0]
-        return {'LQ': underlying_item, 'GT': underlying_item,
+        return {'lq': underlying_item, 'hq': underlying_item,
                 'LQ_path': str(item), 'GT_path': str(item)}
 
     def __len__(self):
diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index 451b5e1d..2f9c147b 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -4,7 +4,6 @@ import os
 import torch
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
-from apex.parallel import DistributedDataParallel
 
 import models.lr_scheduler as lr_scheduler
 import models.networks as networks
@@ -106,6 +105,8 @@ class ExtensibleTrainer(BaseModel):
         all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
         for anet in all_networks:
             if opt['dist']:
+                # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
+                from apex.parallel import DistributedDataParallel
                 dnet = DistributedDataParallel(anet, delay_allreduce=True)
             else:
                 dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
@@ -160,18 +161,9 @@ class ExtensibleTrainer(BaseModel):
             o.zero_grad()
         torch.cuda.empty_cache()
 
-        self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.batch_factor, dim=0)]
-        if need_GT:
-            self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.batch_factor, dim=0)]
-            input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
-            self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.batch_factor, dim=0)]
-        else:
-            self.hq = self.lq
-            self.ref = self.lq
-
-        self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
+        self.dstate = {}
         for k, v in data.items():
-            if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor):
+            if isinstance(v, torch.Tensor):
                 self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.batch_factor, dim=0)]
 
     def optimize_parameters(self, step):
@@ -328,8 +320,8 @@ class ExtensibleTrainer(BaseModel):
 
     def get_current_visuals(self, need_GT=True):
         # Conforms to an archaic format from MMSR.
-        return {'LQ': self.eval_state['lq'][0].float().cpu(),
-                'GT': self.eval_state['hq'][0].float().cpu(),
+        return {'lq': self.eval_state['lq'][0].float().cpu(),
+                'hq': self.eval_state['hq'][0].float().cpu(),
                 'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
 
     def print_network(self):
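The feed_data() rewrite above drops the special-cased handling of 'LQ', 'GT' and 'ref' in favor of splitting every tensor found in the batch dict. A standalone sketch of that pattern (illustrative only, not repo code; plain dict input and CPU device assumed):

```python
import torch

def split_state(data, batch_factor, device='cpu'):
    # Every tensor entry in the batch dict is split into `batch_factor` micro-batches
    # and moved to the target device; non-tensor entries (e.g. path strings) are skipped.
    state = {}
    for k, v in data.items():
        if isinstance(v, torch.Tensor):
            state[k] = [t.to(device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
    return state

batch = {'lq': torch.randn(8, 3, 32, 32), 'hq': torch.randn(8, 3, 128, 128), 'GT_path': ['a.png'] * 8}
print([t.shape for t in split_state(batch, batch_factor=2)['lq']])  # two chunks of 4
```

Because nothing is key-specific anymore, datasets can attach arbitrary tensors under any key and they will be chunked and moved to the device automatically.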
diff --git a/codes/models/archs/spinenet_arch.py b/codes/models/archs/spinenet_arch.py
index 85fb71dd..3e4924f5 100644
--- a/codes/models/archs/spinenet_arch.py
+++ b/codes/models/archs/spinenet_arch.py
@@ -209,7 +209,7 @@ class SpineNet(nn.Module):
     def __init__(self,
                  arch,
                  in_channels=3,
-                 output_level=[3, 4, 5, 6, 7],
+                 output_level=[3, 4],
                  conv_cfg=None,
                  norm_cfg=dict(type='BN', requires_grad=True),
                  zero_init_residual=True,
diff --git a/codes/models/eval/flow_gaussian_nll.py b/codes/models/eval/flow_gaussian_nll.py
index eed85622..31e7297e 100644
--- a/codes/models/eval/flow_gaussian_nll.py
+++ b/codes/models/eval/flow_gaussian_nll.py
@@ -30,8 +30,8 @@ class FlowGaussianNll(evaluator.Evaluator):
         print("Evaluating FlowGaussianNll..")
         for batch in tqdm(self.dataloader):
             dev = self.env['device']
-            z, _, _ = self.model(gt=batch['GT'].to(dev),
-                                 lr=batch['LQ'].to(dev),
+            z, _, _ = self.model(gt=batch['hq'].to(dev),
+                                 lr=batch['lq'].to(dev),
                                  epses=[],
                                  reverse=False,
                                  add_gt_noise=False)
diff --git a/codes/models/eval/sr_style.py b/codes/models/eval/sr_style.py
index 45ca70ee..b44d5dcf 100644
--- a/codes/models/eval/sr_style.py
+++ b/codes/models/eval/sr_style.py
@@ -39,7 +39,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
         counter = 0
         for batch in self.sampler:
             noise = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
-            batch_hq = [e['GT'] for e in batch]
+            batch_hq = [e['hq'] for e in batch]
             batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device'])
             resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area")
             embedding = embedding_generator(resized_batch)
diff --git a/codes/models/feature_model.py b/codes/models/feature_model.py
index dc9d9dd5..74d82e23 100644
--- a/codes/models/feature_model.py
+++ b/codes/models/feature_model.py
@@ -66,9 +66,9 @@ class FeatureModel(BaseModel):
         self.log_dict = OrderedDict()
 
     def feed_data(self, data, need_GT=True):
-        self.var_L = data['LQ'].to(self.device)  # LQ
+        self.var_L = data['lq'].to(self.device)  # LQ
         if need_GT:
-            self.real_H = data['GT'].to(self.device)  # GT
+            self.real_H = data['hq'].to(self.device)  # GT
 
     def optimize_parameters(self, step):
         self.optimizer_G.zero_grad()
diff --git a/codes/process_video.py b/codes/process_video.py
index fd376966..7a53c6e1 100644
--- a/codes/process_video.py
+++ b/codes/process_video.py
@@ -88,7 +88,7 @@ class FfmpegBackedVideoDataset(data.Dataset):
             img_LQ = lq_template
             ref = ref_template
 
-        return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
+        return {'lq': img_LQ, 'lq_fullsize_ref': ref,
                 'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }
 
     def __len__(self):
@@ -159,8 +159,8 @@ if __name__ == "__main__":
         need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
 
         if recurrent_mode and first_frame:
-            b, c, h, w = data['LQ'].shape
-            recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['LQ'].device)
+            b, c, h, w = data['lq'].shape
+            recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['lq'].device)
             # Optionally swap out the 'generator' for the first frame to create a better image that the recurrent generator works off of.
             if 'recurrent_hr_generator' in opt.keys():
                 recurrent_gen = model.env['generators']['generator']
diff --git a/codes/scripts/byol_spinenet_playground.py b/codes/scripts/byol_spinenet_playground.py
new file mode 100644
index 00000000..bb74a6bb
--- /dev/null
+++ b/codes/scripts/byol_spinenet_playground.py
@@ -0,0 +1,206 @@
+import os
+import shutil
+
+import torch
+import torch.nn as nn
+import torchvision
+from PIL import Image
+from torch.utils.data import DataLoader
+from torchvision.transforms import ToTensor, Resize
+from tqdm import tqdm
+import numpy as np
+
+from data.image_folder_dataset import ImageFolderDataset
+from models.archs.spinenet_arch import SpineNet
+
+
+# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
+# and the distance is computed across the channel dimension.
+def structural_euc_dist(x, y):
+    diff = torch.square(x - y)
+    sum = torch.sum(diff, dim=-1)
+    return torch.sqrt(sum)
+
+
+def cosine_similarity(x, y):
+    x = norm(x)
+    y = norm(y)
+    return -nn.CosineSimilarity()(x, y)   # probably better to just use this class to perform the calc. Just left this here to remind myself.
+
+
+def norm(x):
+    sh = x.shape
+    sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))])
+    return (x - torch.mean(x, dim=-1).reshape(sh_r)) / torch.std(x, dim=-1).reshape(sh_r)
+
+
+def im_norm(x):
+    return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
+
+
+def get_image_folder_dataloader(batch_size, num_workers):
+    dataset_opt = {
+        'name': 'amalgam',
+        #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
+        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
+        'weights': [1],
+        'target_size': 512,
+        'force_multiple': 32,
+        'scale': 1
+    }
+    dataset = ImageFolderDataset(dataset_opt)
+    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
+
+
+def create_latent_database(model):
+    batch_size = 8
+    num_workers = 1
+    output_path = '../../results/byol_spinenet_latents/'
+
+    os.makedirs(output_path, exist_ok=True)
+    dataloader = get_image_folder_dataloader(batch_size, num_workers)
+    id = 0
+    dict_count = 1
+    latent_dict = {}
+    all_paths = []
+    for batch in tqdm(dataloader):
+        hq = batch['hq'].to('cuda:1')
+        latent = model(hq)[1]  # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
+        for b in range(latent.shape[0]):
+            im_path = batch['HQ_path'][b]
+            all_paths.append(im_path)
+            latent_dict[id] = latent[b].detach().cpu()
+            if (id+1) % 1000 == 0:
+                print("Saving checkpoint..")
+                torch.save(latent_dict, os.path.join(output_path, "latent_dict_%i.pth" % (dict_count,)))
+                latent_dict = {}
+                torch.save(all_paths, os.path.join(output_path, "all_paths.pth"))
+                dict_count += 1
+            id += 1
+
+
+def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_size):
+    _, c, h, w = latent.shape
+    lat_dict = torch.load(os.path.join(hq_img_repo, ld_file_name))
+    comparables = torch.stack(list(lat_dict.values()), dim=0).permute(0,2,3,1)
+    cbl_shape = comparables.shape[:3]
+    assert cbl_shape[1] == 32
+    comparables = comparables.reshape(-1, c)
+
+    clat = latent.reshape(1,-1,h*w).permute(2,0,1)
+    cpbl_chunked = torch.chunk(comparables, len(comparables) // batch_size)
+    assert len(comparables) % batch_size == 0  # The reconstruction logic doesn't work if this is not the case.
+    mins = []
+    min_offsets = []
+    for cpbl_chunk in tqdm(cpbl_chunked):
+        cpbl_chunk = cpbl_chunk.to('cuda:1')
+        dist = structural_euc_dist(clat, cpbl_chunk.unsqueeze(0))
+        _min = torch.min(dist, dim=-1)
+        mins.append(_min[0])
+        min_offsets.append(_min[1])
+    mins = torch.min(torch.stack(mins, dim=-1), dim=-1)
+    # There's some way to do this in torch, I just can't figure it out..
+    for i in range(len(mins[1])):
+        mins[1][i] = mins[1][i] * batch_size + min_offsets[mins[1][i]][i]
+
+    return mins[0].cpu(), mins[1].cpu(), len(comparables)
+
+
+def find_similar_latents(model):
+    img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\adrianna_xx.jpg'
+    #img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
+    hq_img_repo = '../../results/byol_spinenet_latents'
+    output_path = '../../results/byol_spinenet_similars'
+    batch_size = 1024
+    num_maps = 8
+
+    os.makedirs(output_path, exist_ok=True)
+    img_bank_paths = torch.load(os.path.join(hq_img_repo, "all_paths.pth"))
+    img_t = ToTensor()(Image.open(img)).to('cuda:1').unsqueeze(0)
+    _, _, h, w = img_t.shape
+    img_t = img_t[:, :, :128*(h//128), :128*(w//128)]
+
+    latent = model(img_t)[1]
+    _, c, h, w = latent.shape
+    mins, min_offsets = [], []
+    total_latents = -1
+    for d_id in range(1,num_maps+1):
+        mn, of, tl = _get_mins_from_latent_dictionary(latent, hq_img_repo, "latent_dict_%i.pth" % (d_id), batch_size)
+        if total_latents != -1:
+            assert total_latents == tl
+        else:
+            total_latents = tl
+        mins.append(mn)
+        min_offsets.append(of)
+    mins = torch.min(torch.stack(mins, dim=-1), dim=-1)
+    # There's some way to do this in torch, I just can't figure it out..
+    for i in range(len(mins[1])):
+        mins[1][i] = mins[1][i] * total_latents + min_offsets[mins[1][i]][i]
+    min_ids = mins[1]
+
+    print("Constructing image map..")
+    doc_out = '''
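Regarding the "There's some way to do this in torch" comment in the new playground script: the chunk-wise minimum bookkeeping can be vectorized with gather. A sketch under the same shape assumptions as _get_mins_from_latent_dictionary (a hypothetical helper, not part of the patch):

```python
import torch

def combine_chunk_mins(min_vals, min_offsets, batch_size):
    # min_vals / min_offsets: lists of [N] tensors, one pair per chunk of `batch_size` comparables.
    vals = torch.stack(min_vals, dim=-1)      # [N, num_chunks] best distance within each chunk
    offs = torch.stack(min_offsets, dim=-1)   # [N, num_chunks] argmin within each chunk
    best_vals, best_chunk = torch.min(vals, dim=-1)
    best_offs = offs.gather(-1, best_chunk.unsqueeze(-1)).squeeze(-1)
    return best_vals, best_chunk * batch_size + best_offs  # global indices into the comparables
```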