diff --git a/codes/data/__init__.py b/codes/data/__init__.py index d8c2447f..7778c7cb 100644 --- a/codes/data/__init__.py +++ b/codes/data/__init__.py @@ -22,7 +22,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): pin_memory=True) else: batch_size = dataset_opt['batch_size'] or 1 - return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=max(int(batch_size/2), 1), + return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True) diff --git a/codes/data/base_unsupervised_image_dataset.py b/codes/data/base_unsupervised_image_dataset.py index 1b0e021b..ec5ee5a6 100644 --- a/codes/data/base_unsupervised_image_dataset.py +++ b/codes/data/base_unsupervised_image_dataset.py @@ -87,7 +87,7 @@ class BaseUnsupervisedImageDataset(data.Dataset): hqs_conformed, hq_refs_conformed, hq_masks_conformed, hq_centers_conformed = [], [], [], [] for hq, hq_ref, hq_mask, hq_center in zip(hqs_adjusted, hq_refs_adjusted, hq_masks_adjusted, hq_centers_adjusted): h, w = (h - h % hq_multiple), (w - w % hq_multiple) - hq_centers_conformed.append(self.resize_point(hq_center, hq.shape[:1], (h, w))) + hq_centers_conformed.append(self.resize_point(hq_center, hq.shape[:2], (h, w))) hqs_conformed.append(hq[:h, :w, :]) hq_refs_conformed.append(hq_ref[:h, :w, :]) hq_masks_conformed.append(hq_mask[:h, :w, :]) diff --git a/codes/data/chunk_with_reference.py b/codes/data/chunk_with_reference.py index bab35f35..07fd8195 100644 --- a/codes/data/chunk_with_reference.py +++ b/codes/data/chunk_with_reference.py @@ -23,19 +23,24 @@ class ChunkWithReference: return img def __getitem__(self, item): - centers = torch.load(osp.join(self.path, "centers.pt")) - ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg")) tile = self.read_image_or_get_zero(self.tiles[item]) - tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0]) - if tile_id in centers.keys(): - center, tile_width = centers[tile_id] - elif self.strict: - raise FileNotFoundError(tile_id, self.tiles[item]) + if osp.exists(osp.join(self.path, "ref.jpg")): + tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0]) + centers = torch.load(osp.join(self.path, "centers.pt")) + ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg")) + if tile_id in centers.keys(): + center, tile_width = centers[tile_id] + elif self.strict: + raise FileNotFoundError(tile_id, self.tiles[item]) + else: + center = torch.tensor([128, 128], dtype=torch.long) + tile_width = 256 + mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype) + mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1 else: - center = torch.tensor([128, 128], dtype=torch.long) - tile_width = 256 - mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype) - mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1 + ref = np.zeros_like(tile) + mask = np.zeros(tile.shape[:2] + (1,)) + center = (0,0) return tile, ref, center, mask, self.tiles[item] diff --git a/codes/data/paired_frame_dataset.py b/codes/data/paired_frame_dataset.py index 16c09ba2..801d4182 100644 --- a/codes/data/paired_frame_dataset.py +++ b/codes/data/paired_frame_dataset.py @@ -39,7 +39,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset): hq_ref = torch.cat([hq_ref, hq_mask], dim=1) lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float() lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float() - lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1) + lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).squeeze().unsqueeze(dim=1) lq_ref = torch.cat([lq_ref, lq_mask], dim=1) return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref, @@ -49,9 +49,9 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset): if __name__ == '__main__': opt = { 'name': 'amalgam', - 'paths': ['F:\\4k6k\\datasets\\ns_images\\vr\\paired_images'], + 'paths': ['F:\\4k6k\\datasets\\ns_images\\vr\\validation'], 'weights': [1], - 'target_size': 128, + #'target_size': 128, 'force_multiple': 32, 'scale': 2, 'eval': False, @@ -72,7 +72,7 @@ if __name__ == '__main__': element = ds[random.randint(0,len(ds))] base_file = osp.basename(element["GT_path"]) o = element[k].unsqueeze(0) - if bs < 32: + if bs < 2: if batch is None: batch = o else: diff --git a/codes/models/steps/stereoscopic.py b/codes/models/steps/stereoscopic.py index f4f4f32d..ff36be6d 100644 --- a/codes/models/steps/stereoscopic.py +++ b/codes/models/steps/stereoscopic.py @@ -1,6 +1,7 @@ import torch from torch.cuda.amp import autocast from models.flownet2.networks.resample2d_package.resample2d import Resample2d +from models.flownet2.utils.flow_utils import flow2img from models.steps.injectors import Injector @@ -8,6 +9,8 @@ def create_stereoscopic_injector(opt, env): type = opt['type'] if type == 'stereoscopic_resample': return ResampleInjector(opt, env) + elif type == 'stereoscopic_flow2image': + return Flow2Image(opt, env) return None @@ -19,4 +22,26 @@ class ResampleInjector(Injector): def forward(self, state): with autocast(enabled=False): - return {self.output: self.resample(state[self.input], state[self.flow])} \ No newline at end of file + return {self.output: self.resample(state[self.input], state[self.flow])} + + +# Converts a flowfield to an image representation for viewing purposes. +# Uses flownet's implementation to do so. Which really sucks. TODO: just do my own implementation in the future. +# Note: this is not differentiable and is only usable for debugging purposes. +class Flow2Image(Injector): + def __init__(self, opt, env): + super(Flow2Image, self).__init__(opt, env) + + def forward(self, state): + with torch.no_grad(): + flo = state[self.input].cpu() + bs, c, h, w = flo.shape + flo = flo.permute(0, 2, 3, 1) # flow2img works in numpy space for some reason.. + imgs = torch.empty_like(flo) + flo = flo.numpy() + for b in range(bs): + img = flow2img(flo[b]) # Note that this returns the image in an integer format. + img = torch.tensor(img, dtype=torch.float) / 255 + imgs[b] = img + imgs = imgs.permute(0, 3, 1, 2) + return {self.output: imgs} diff --git a/codes/train.py b/codes/train.py index 536e1575..8a4fc14b 100644 --- a/codes/train.py +++ b/codes/train.py @@ -29,6 +29,8 @@ class Trainer: def init(self, opt, launcher, all_networks={}): self._profile = False + self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True + self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True #### distributed training settings if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1: @@ -214,8 +216,8 @@ class Trainer: val_tqdm = tqdm(self.val_loader) for val_data in val_tqdm: idx += 1 - for b in range(len(val_data['LQ_path'])): - img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0] + for b in range(len(val_data['GT_path'])): + img_name = os.path.splitext(os.path.basename(val_data['GT_path'][b]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir) @@ -226,14 +228,16 @@ class Trainer: if visuals is None: continue - # calculate PSNR sr_img = util.tensor2img(visuals['rlt'][b]) # uint8 - gt_img = util.tensor2img(visuals['GT'][b]) # uint8 - sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) - avg_psnr += util.calculate_psnr(sr_img, gt_img) + # calculate PSNR + if self.val_compute_psnr: + gt_img = util.tensor2img(visuals['GT'][b]) # uint8 + sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) + avg_psnr += util.calculate_psnr(sr_img, gt_img) # calculate fea loss - avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b]) + if self.val_compute_fea: + avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b]) # Save SR images for reference img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step) @@ -278,7 +282,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_3dflow_vr_flownet.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_multifaceted_chained.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True)