From 11155aead4933c6046403e04e3a6a75fa6bbb151 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 4 Dec 2020 20:14:53 -0700
Subject: [PATCH] Directly use dataset keys

This has been a long time coming. Cleans up messy "GT" nomenclature and
simplifies ExtensibleTrainer.feed_data
---
 codes/data/README.md | 7 +++++--
 codes/data/full_image_dataset.py | 2 +-
 codes/data/image_folder_dataset.py | 4 ++--
 codes/data/multi_frame_dataset.py | 4 ++--
 codes/data/multiscale_dataset.py | 2 +-
 codes/data/paired_frame_dataset.py | 4 ++--
 codes/data/single_image_dataset.py | 4 ++--
 codes/data/stylegan2_dataset.py | 2 +-
 codes/data/torch_dataset.py | 2 +-
 codes/models/ExtensibleTrainer.py | 20 ++++++-------------
 codes/models/eval/flow_gaussian_nll.py | 4 ++--
 codes/models/eval/sr_style.py | 2 +-
 codes/models/feature_model.py | 4 ++--
 codes/process_video.py | 6 +++---
 .../compute_fdpl_perceptual_weights.py | 4 ++--
 codes/scripts/create_lmdb.py | 10 +++++-----
 .../scripts/srflow_latent_space_playground.py | 15 +++++++++-----
 codes/scripts/test_dataloader.py | 4 ++--
 codes/scripts/use_discriminator_as_filter.py | 2 +-
 codes/scripts/use_generator_as_filter.py | 4 ++--
 codes/test.py | 4 ++--
 codes/train.py | 4 ++--
 codes/train2.py | 10 ++++++----
 23 files changed, 63 insertions(+), 61 deletions(-)

diff --git a/codes/data/README.md b/codes/data/README.md
index 78a65825..be9af646 100644
--- a/codes/data/README.md
+++ b/codes/data/README.md
@@ -20,12 +20,15 @@ This directory contains several reference datasets which I have used in building
 1. MultiframeDataset - Similar to SingleImageDataset, but infers a temporal relationship between images based
    on their filenames: the last 12 characters before the file extension are assumed to be a frame counter. Images
    from this dataset are grouped together with a temporal dimension for working with video data.
+1. ImageFolderDataset - Reads raw images from a folder and feeds them into the model. Capable of performing corruptions
+   on those images like the above.
 1. MultiscaleDataset - Reads full images from a directory and builds a tree of images constructed by cropping squares
    from the source image and resizing them to the target size recursively until the native resolution is hit. Each
    recursive step decreases the crop size by a factor of 2.
+1. TorchDataset - A wrapper for miscellaneous pytorch datasets (e.g. MNIST, CIFAR, etc) which extracts the images
+   and reformats them in a way that the DLAS trainer understands.
 1. FullImageDataset - An image patch dataset where the patches are dynamically extracted from full-size images. I have
-   generally stopped using this for performance reasons in favor of SingleImageDataset but it is useful for validation
-   and test so I keep it around.
+   generally stopped using this for performance reasons and it should be considered deprecated.
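All of the datasets listed above now return the lowercase 'lq'/'hq' tensor keys that this patch standardizes on, while the '*_path' keys keep their old uppercase names. A minimal sketch of that contract (the class name and random tensors are hypothetical stand-ins, not code from this patch; the dict mirrors TorchDataset's return value):

```python
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

class ToyPairedDataset(Dataset):
    """Hypothetical stand-in that emits the dict the DLAS trainer expects."""
    def __init__(self, count=8, size=64, scale=2):
        self.count, self.size, self.scale = count, size, scale

    def __len__(self):
        return self.count

    def __getitem__(self, item):
        hq = torch.rand(3, self.size, self.size)  # stand-in for a loaded image
        lq = F.interpolate(hq.unsqueeze(0), scale_factor=1 / self.scale,
                           mode='area').squeeze(0)
        # Lowercase 'lq'/'hq' carry the tensors; path keys stay uppercase in this commit.
        return {'lq': lq, 'hq': hq, 'LQ_path': str(item), 'GT_path': str(item)}
```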
 
 ## Information about the "chunked" format
 
diff --git a/codes/data/full_image_dataset.py b/codes/data/full_image_dataset.py
index 6fe2d444..9e121f27 100644
--- a/codes/data/full_image_dataset.py
+++ b/codes/data/full_image_dataset.py
@@ -323,7 +323,7 @@ class FullImageDataset(data.Dataset):
             gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0)
             lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)
 
-        d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
+        d = {'lq': img_LQ, 'hq': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
             'lq_center': lq_center, 'gt_center': gt_center, 'LQ_path': LQ_path, 'GT_path': full_path}
         return d
 
diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py
index c2ae96ac..31309edc 100644
--- a/codes/data/image_folder_dataset.py
+++ b/codes/data/image_folder_dataset.py
@@ -95,7 +95,7 @@ class ImageFolderDataset:
         hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
         lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
 
-        return {'LQ': lq, 'GT': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]}
+        return {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]}
 
 
 if __name__ == '__main__':
@@ -118,7 +118,7 @@ if __name__ == '__main__':
     for i in range(0, len(ds)):
         o = ds[random.randint(0, len(ds))]
         #for k, v in o.items():
-        k = 'LQ'
+        k = 'lq'
         v = o[k]
         #if 'LQ' in k and 'path' not in k and 'center' not in k:
         #if 'full' in k:
diff --git a/codes/data/multi_frame_dataset.py b/codes/data/multi_frame_dataset.py
index c82355c9..ce99b348 100644
--- a/codes/data/multi_frame_dataset.py
+++ b/codes/data/multi_frame_dataset.py
@@ -57,7 +57,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
             lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
             lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
 
-        return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+        return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
                 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
 
@@ -83,7 +83,7 @@ if __name__ == '__main__':
     batch = None
     for i in range(len(ds)):
         import random
-        k = 'LQ'
+        k = 'lq'
        element = ds[random.randint(0,len(ds))]
         base_file = osp.basename(element["GT_path"])
         o = element[k].unsqueeze(0)
diff --git a/codes/data/multiscale_dataset.py b/codes/data/multiscale_dataset.py
index f93bcb06..6673f6e0 100644
--- a/codes/data/multiscale_dataset.py
+++ b/codes/data/multiscale_dataset.py
@@ -87,7 +87,7 @@ class MultiScaleDataset(data.Dataset):
         patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='area').squeeze() for p in patches_hq_corrupted]
         patches_lq = torch.stack(patches_lq, dim=0)
 
-        d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path}
+        d = {'lq': patches_lq, 'hq': patches_hq, 'GT_path': full_path}
         return d
 
     def __len__(self):
diff --git a/codes/data/paired_frame_dataset.py b/codes/data/paired_frame_dataset.py
index 801d4182..49328abc 100644
--- a/codes/data/paired_frame_dataset.py
+++ b/codes/data/paired_frame_dataset.py
@@ -42,7 +42,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
             lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).squeeze().unsqueeze(dim=1)
             lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
 
-        return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+        return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
                 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
 
@@ -68,7 +68,7 @@ if __name__ == '__main__':
     batch = None
     for i in range(len(ds)):
         import random
-        k = 'LQ'
+        k = 'lq'
         element = ds[random.randint(0,len(ds))]
         base_file = osp.basename(element["GT_path"])
         o = element[k].unsqueeze(0)
diff --git a/codes/data/single_image_dataset.py b/codes/data/single_image_dataset.py
index 2521f792..4048f197 100644
--- a/codes/data/single_image_dataset.py
+++ b/codes/data/single_image_dataset.py
@@ -36,7 +36,7 @@ class SingleImageDataset(BaseUnsupervisedImageDataset):
             lq_mask = torch.from_numpy(np.ascontiguousarray(lms[0])).unsqueeze(dim=0)
             lq_ref = torch.cat([lq_ref, lq_mask], dim=0)
 
-        return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+        return {'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
                 'lq_center': torch.tensor(lcs[0], dtype=torch.long), 'gt_center': torch.tensor(hcs[0], dtype=torch.long),
                 'LQ_path': path, 'GT_path': path}
 
@@ -62,7 +62,7 @@ if __name__ == '__main__':
     for i in range(0, len(ds)):
         o = ds[random.randint(0, len(ds))]
         #for k, v in o.items():
-        k = 'LQ'
+        k = 'lq'
         v = o[k]
         #if 'LQ' in k and 'path' not in k and 'center' not in k:
         #if 'full' in k:
diff --git a/codes/data/stylegan2_dataset.py b/codes/data/stylegan2_dataset.py
index 2424f686..b00f591d 100644
--- a/codes/data/stylegan2_dataset.py
+++ b/codes/data/stylegan2_dataset.py
@@ -98,4 +98,4 @@ class Stylegan2Dataset(data.Dataset):
         path = self.paths[index]
         img = Image.open(path)
         img = self.transform(img)
-        return {'LQ': img, 'GT': img, 'GT_path': str(path)}
+        return {'lq': img, 'hq': img, 'GT_path': str(path)}
diff --git a/codes/data/torch_dataset.py b/codes/data/torch_dataset.py
index 58ed34de..01875bfb 100644
--- a/codes/data/torch_dataset.py
+++ b/codes/data/torch_dataset.py
@@ -24,7 +24,7 @@ class TorchDataset(Dataset):
 
     def __getitem__(self, item):
         underlying_item = self.dataset[item][0]
-        return {'LQ': underlying_item, 'GT': underlying_item,
+        return {'lq': underlying_item, 'hq': underlying_item,
                 'LQ_path': str(item), 'GT_path': str(item)}
 
     def __len__(self):
diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index 451b5e1d..2f9c147b 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -4,7 +4,6 @@ import os
 import torch
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
-from apex.parallel import DistributedDataParallel
 
 import models.lr_scheduler as lr_scheduler
 import models.networks as networks
@@ -106,6 +105,8 @@ class ExtensibleTrainer(BaseModel):
         all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
         for anet in all_networks:
             if opt['dist']:
+                # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
+                from apex.parallel import DistributedDataParallel
                 dnet = DistributedDataParallel(anet, delay_allreduce=True)
             else:
                 dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
@@ -160,18 +161,9 @@ class ExtensibleTrainer(BaseModel):
                 o.zero_grad()
         torch.cuda.empty_cache()
 
-        self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.batch_factor, dim=0)]
-        if need_GT:
-            self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.batch_factor, dim=0)]
-            input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
-            self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.batch_factor, dim=0)]
-        else:
-            self.hq = self.lq
-            self.ref = self.lq
-
-        self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
+        self.dstate = {}
         for k, v in data.items():
-            if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor):
+            if isinstance(v, torch.Tensor):
                 self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.batch_factor, dim=0)]
 
     def optimize_parameters(self, step):
@@ -328,8 +320,8 @@ class ExtensibleTrainer(BaseModel):
 
     def get_current_visuals(self, need_GT=True):
         # Conforms to an archaic format from MMSR.
-        return {'LQ': self.eval_state['lq'][0].float().cpu(),
-                'GT': self.eval_state['hq'][0].float().cpu(),
+        return {'lq': self.eval_state['lq'][0].float().cpu(),
+                'hq': self.eval_state['hq'][0].float().cpu(),
                 'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
 
     def print_network(self):
diff --git a/codes/models/eval/flow_gaussian_nll.py b/codes/models/eval/flow_gaussian_nll.py
index eed85622..31e7297e 100644
--- a/codes/models/eval/flow_gaussian_nll.py
+++ b/codes/models/eval/flow_gaussian_nll.py
@@ -30,8 +30,8 @@ class FlowGaussianNll(evaluator.Evaluator):
         print("Evaluating FlowGaussianNll..")
         for batch in tqdm(self.dataloader):
             dev = self.env['device']
-            z, _, _ = self.model(gt=batch['GT'].to(dev),
-                                 lr=batch['LQ'].to(dev),
+            z, _, _ = self.model(gt=batch['hq'].to(dev),
+                                 lr=batch['lq'].to(dev),
                                  epses=[],
                                  reverse=False,
                                  add_gt_noise=False)
diff --git a/codes/models/eval/sr_style.py b/codes/models/eval/sr_style.py
index 45ca70ee..b44d5dcf 100644
--- a/codes/models/eval/sr_style.py
+++ b/codes/models/eval/sr_style.py
@@ -39,7 +39,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
         counter = 0
         for batch in self.sampler:
             noise = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
-            batch_hq = [e['GT'] for e in batch]
+            batch_hq = [e['hq'] for e in batch]
             batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device'])
             resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area")
             embedding = embedding_generator(resized_batch)
diff --git a/codes/models/feature_model.py b/codes/models/feature_model.py
index dc9d9dd5..74d82e23 100644
--- a/codes/models/feature_model.py
+++ b/codes/models/feature_model.py
@@ -66,9 +66,9 @@ class FeatureModel(BaseModel):
         self.log_dict = OrderedDict()
 
     def feed_data(self, data, need_GT=True):
-        self.var_L = data['LQ'].to(self.device)  # LQ
+        self.var_L = data['lq'].to(self.device)  # LQ
         if need_GT:
-            self.real_H = data['GT'].to(self.device)  # GT
+            self.real_H = data['hq'].to(self.device)  # GT
 
     def optimize_parameters(self, step):
         self.optimizer_G.zero_grad()
diff --git a/codes/process_video.py b/codes/process_video.py
index fd376966..7a53c6e1 100644
--- a/codes/process_video.py
+++ b/codes/process_video.py
@@ -88,7 +88,7 @@ class FfmpegBackedVideoDataset(data.Dataset):
             img_LQ = lq_template
             ref = ref_template
-        return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
+        return {'lq': img_LQ, 'lq_fullsize_ref': ref,
                 'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }
 
     def __len__(self):
@@ -159,8 +159,8 @@ if __name__ == "__main__":
         need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
 
         if recurrent_mode and first_frame:
-            b, c, h, w = data['LQ'].shape
-            recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['LQ'].device)
+            b, c, h, w = data['lq'].shape
+            recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['lq'].device)
             # Optionally swap out the 'generator' for the first frame to create a better image that the recurrent generator works off of.
             if 'recurrent_hr_generator' in opt.keys():
                 recurrent_gen = model.env['generators']['generator']
diff --git a/codes/scripts/compute_fdpl_perceptual_weights.py b/codes/scripts/compute_fdpl_perceptual_weights.py
index 1a6506b2..3b411721 100644
--- a/codes/scripts/compute_fdpl_perceptual_weights.py
+++ b/codes/scripts/compute_fdpl_perceptual_weights.py
@@ -40,8 +40,8 @@ if __name__ == '__main__':
                 break
             sampled += 1
 
-            im = rgb2ycbcr(train_data['GT'].double())
-            im_LR = rgb2ycbcr(F.interpolate(train_data['LQ'].double(),
+            im = rgb2ycbcr(train_data['hq'].double())
+            im_LR = rgb2ycbcr(F.interpolate(train_data['lq'].double(),
                                             size=im.shape[2:],
                                             mode="bicubic", align_corners=False))
             patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True)
diff --git a/codes/scripts/create_lmdb.py b/codes/scripts/create_lmdb.py
index c44b5c78..7b6d5de4 100644
--- a/codes/scripts/create_lmdb.py
+++ b/codes/scripts/create_lmdb.py
@@ -16,7 +16,7 @@ import utils.util as util  # noqa: E402
 
 def main():
     dataset = 'DIV2K_demo'  # vimeo90K | REDS | general (e.g., DIV2K, 291) | DIV2K_demo |test
-    mode = 'GT'  # used for vimeo90k and REDS datasets
+    mode = 'hq'  # used for vimeo90k and REDS datasets
     # vimeo90k: GT | LR | flow
     # REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp
     #       train_sharp_flowx4
@@ -159,7 +159,7 @@ def vimeo90k(mode):
     read_all_imgs = False  # whether real all images to memory with multiprocessing
     # Set False for use limited memory
     BATCH = 5000  # After BATCH images, lmdb commits, if read_all_imgs = False
-    if mode == 'GT':
+    if mode == 'hq':
         img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences'
         lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
         txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
@@ -204,7 +204,7 @@ def vimeo90k(mode):
                 keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
     all_img_list = sorted(all_img_list)
     keys = sorted(keys)
-    if mode == 'GT':  # only read the 4th frame for the GT mode
+    if mode == 'hq':  # only read the 4th frame for the GT mode
         print('Only keep the 4th frame.')
         all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
         keys = [v for v in keys if v.endswith('_4')]
@@ -255,9 +255,9 @@ def vimeo90k(mode):
 
     #### create meta information
     meta_info = {}
-    if mode == 'GT':
+    if mode == 'hq':
         meta_info['name'] = 'Vimeo90K_train_GT'
-    elif mode == 'LR':
+    elif mode == 'lq':
         meta_info['name'] = 'Vimeo90K_train_LR'
     elif mode == 'flow':
         meta_info['name'] = 'Vimeo90K_train_flowx4'
diff --git a/codes/scripts/srflow_latent_space_playground.py b/codes/scripts/srflow_latent_space_playground.py
index 3fe89577..7dd31902 100644
--- a/codes/scripts/srflow_latent_space_playground.py
+++ b/codes/scripts/srflow_latent_space_playground.py
@@ -163,7 +163,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     srg_analyze = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../experiments/train_exd_imgset_srflow/train_exd_imgset_srflow.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow_frompsnr.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
@@ -180,10 +180,10 @@ if __name__ == "__main__":
     gen = model.networks['generator']
     gen.eval()
 
-    mode = "temperature" # temperature | restore | latent_transfer | feed_through
+    mode = "restore" # temperature | restore | latent_transfer | feed_through
     #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
-    #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
-    imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
+    imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
+    #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
     scale = 2
     resample_factor = 1 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
     temperature = 1
@@ -224,6 +224,11 @@ if __name__ == "__main__":
             t = image_2_tensor(img_file).to(model.env['device'])
             if resample_factor != 1:
                 t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic")
+            # Ensure the input image dimensions are a multiple of 16.
+            _, _, h, w = t.shape
+            h = 16 * (h // 16)
+            w = 16 * (w // 16)
+            t = t[:, :, :h, :w]
             resample_img = t
 
             # Fetch the latent metrics & latents for each image we are resampling.
@@ -255,6 +260,6 @@ if __name__ == "__main__":
             for j in range(len(lats)):
                 path = os.path.join(output_path, "%i_%i" % (im_it, j))
                 os.makedirs(path, exist_ok=True)
-                torchvision.utils.save_image(resample_img, os.path.join(path, "%i_orig.jpg" %(im_it)))
+                torchvision.utils.save_image(resample_img, os.path.join(path, "orig.jpg"))
                 create_interpolation_video(gen, F.interpolate(resample_img, scale_factor=1/scale, mode="area"),
                                            path, [torch.zeros_like(l) for l in lats[j]], lats[j])
diff --git a/codes/scripts/test_dataloader.py b/codes/scripts/test_dataloader.py
index 73df5242..642c6ef7 100644
--- a/codes/scripts/test_dataloader.py
+++ b/codes/scripts/test_dataloader.py
@@ -85,8 +85,8 @@ def main():
         if dataset == 'REDS' or dataset == 'Vimeo90K':
             LQs = data['LQs']
         else:
-            LQ = data['LQ']
-            GT = data['GT']
+            LQ = data['lq']
+            GT = data['hq']
 
         if dataset == 'REDS' or dataset == 'Vimeo90K':
             for j in range(LQs.size(1)):
diff --git a/codes/scripts/use_discriminator_as_filter.py b/codes/scripts/use_discriminator_as_filter.py
index 94a69eb3..14cb6887 100644
--- a/codes/scripts/use_discriminator_as_filter.py
+++ b/codes/scripts/use_discriminator_as_filter.py
@@ -68,6 +68,6 @@ if __name__ == "__main__":
             # removed += 1
             imname = osp.basename(data['GT_path'][i])
             if results[i]-dataset_mean > 1:
-                torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
+                torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
     print("Removed %i/%i images" % (removed, len(test_set)))
\ No newline at end of file
diff --git a/codes/scripts/use_generator_as_filter.py b/codes/scripts/use_generator_as_filter.py
index 5f1a4a94..26718702 100644
--- a/codes/scripts/use_generator_as_filter.py
+++ b/codes/scripts/use_generator_as_filter.py
@@ -66,7 +66,7 @@ if __name__ == "__main__":
         model.test()
         gen = model.eval_state['gen'][0].to(model.env['device'])
         feagen = netF(gen)
-        feareal = netF(data['GT'].to(model.env['device']))
+        feareal = netF(data['hq'].to(model.env['device']))
         losses = torch.sum(torch.abs(feareal - feagen), dim=(1,2,3))
         means.append(torch.mean(losses).item())
         #print(sum(means)/len(means), torch.mean(losses), torch.max(losses), torch.min(losses))
@@ -76,6 +76,6 @@ if __name__ == "__main__":
                 removed += 1
             #imname = osp.basename(data['GT_path'][i])
             #if losses[i] < 25000:
-            #    torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
+            #    torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
     print("Removed %i/%i images" % (removed, len(test_set)))
\ No newline at end of file
diff --git a/codes/test.py b/codes/test.py
index 71f58c85..584d0dd7 100644
--- a/codes/test.py
+++ b/codes/test.py
@@ -41,9 +41,9 @@ def forward_pass(model, output_dir, alteration_suffix=''):
         save_img_path = osp.join(output_dir, img_name + '.png')
 
         if need_GT:
-            fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i])
+            fea_loss += model.compute_fea_loss(visuals[i], data['hq'][i])
             psnr_sr = util.tensor2img(visuals[i])
-            psnr_gt = util.tensor2img(data['GT'][i])
+            psnr_gt = util.tensor2img(data['hq'][i])
             psnr_loss += util.calculate_psnr(psnr_sr, psnr_gt)
 
         util.save_img(sr_img, save_img_path)
diff --git a/codes/train.py b/codes/train.py
index aa124dd3..08e97474 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -231,13 +231,13 @@ class Trainer:
                     sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
                     # calculate PSNR
                     if self.val_compute_psnr:
-                        gt_img = util.tensor2img(visuals['GT'][b])  # uint8
+                        gt_img = util.tensor2img(visuals['hq'][b])  # uint8
                         sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
                         avg_psnr += util.calculate_psnr(sr_img, gt_img)
 
                     # calculate fea loss
                     if self.val_compute_fea:
-                        avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
+                        avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])
 
                     # Save SR images for reference
                     img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
diff --git a/codes/train2.py b/codes/train2.py
index 95e8c32d..06e3f0eb 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -3,6 +3,8 @@ import math
 import argparse
 import random
 import logging
+
+import torchvision
 from tqdm import tqdm
 
 import torch
@@ -231,18 +233,18 @@ class Trainer:
                     sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
                     # calculate PSNR
                     if self.val_compute_psnr:
-                        gt_img = util.tensor2img(visuals['GT'][b])  # uint8
+                        gt_img = util.tensor2img(visuals['hq'][b])  # uint8
                         sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
                         avg_psnr += util.calculate_psnr(sr_img, gt_img)
 
                     # calculate fea loss
                     if self.val_compute_fea:
-                        avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
+                        avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])
 
                     # Save SR images for reference
                     img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
                     save_img_path = os.path.join(img_dir, img_base_name)
-                    util.save_img(sr_img, save_img_path)
+                    torchvision.utils.save_image(visuals['rlt'], save_img_path)
 
             avg_psnr = avg_psnr / idx
             avg_fea_loss = avg_fea_loss / idx
@@ -291,7 +293,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_frompsnr.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_bigboi_frompsnr.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
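Taken together with the dataset changes, the ExtensibleTrainer.py hunk above reduces feed_data to a single uniform rule: chunk every tensor in the incoming batch and move it to the device, with no special-casing of 'LQ'/'GT'/'ref'. A minimal standalone sketch of that resulting logic (the free-function form and parameter names are simplifications of the method in the patch):

```python
import torch

def feed_data(data: dict, device: torch.device, batch_factor: int) -> dict:
    # Split each tensor along the batch dimension to support mega-batching
    # (gradient accumulation), then move every chunk to the target device.
    # Non-tensor entries such as the '*_path' strings are simply skipped.
    dstate = {}
    for k, v in data.items():
        if isinstance(v, torch.Tensor):
            dstate[k] = [t.to(device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
    return dstate
```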