Fixes to unified chunk datasets to support stereoscopic training
This commit is contained in:
parent
629b968901
commit
ff58c6484a
|
@ -22,7 +22,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
|
|||
pin_memory=True)
|
||||
else:
|
||||
batch_size = dataset_opt['batch_size'] or 1
|
||||
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=max(int(batch_size/2), 1),
|
||||
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
|
||||
pin_memory=True)
|
||||
|
||||
|
||||
|
|
|
@ -87,7 +87,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
|
|||
hqs_conformed, hq_refs_conformed, hq_masks_conformed, hq_centers_conformed = [], [], [], []
|
||||
for hq, hq_ref, hq_mask, hq_center in zip(hqs_adjusted, hq_refs_adjusted, hq_masks_adjusted, hq_centers_adjusted):
|
||||
h, w = (h - h % hq_multiple), (w - w % hq_multiple)
|
||||
hq_centers_conformed.append(self.resize_point(hq_center, hq.shape[:1], (h, w)))
|
||||
hq_centers_conformed.append(self.resize_point(hq_center, hq.shape[:2], (h, w)))
|
||||
hqs_conformed.append(hq[:h, :w, :])
|
||||
hq_refs_conformed.append(hq_ref[:h, :w, :])
|
||||
hq_masks_conformed.append(hq_mask[:h, :w, :])
|
||||
|
|
|
@ -23,19 +23,24 @@ class ChunkWithReference:
|
|||
return img
|
||||
|
||||
def __getitem__(self, item):
|
||||
centers = torch.load(osp.join(self.path, "centers.pt"))
|
||||
ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg"))
|
||||
tile = self.read_image_or_get_zero(self.tiles[item])
|
||||
tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
|
||||
if tile_id in centers.keys():
|
||||
center, tile_width = centers[tile_id]
|
||||
elif self.strict:
|
||||
raise FileNotFoundError(tile_id, self.tiles[item])
|
||||
if osp.exists(osp.join(self.path, "ref.jpg")):
|
||||
tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
|
||||
centers = torch.load(osp.join(self.path, "centers.pt"))
|
||||
ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg"))
|
||||
if tile_id in centers.keys():
|
||||
center, tile_width = centers[tile_id]
|
||||
elif self.strict:
|
||||
raise FileNotFoundError(tile_id, self.tiles[item])
|
||||
else:
|
||||
center = torch.tensor([128, 128], dtype=torch.long)
|
||||
tile_width = 256
|
||||
mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
|
||||
mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
|
||||
else:
|
||||
center = torch.tensor([128, 128], dtype=torch.long)
|
||||
tile_width = 256
|
||||
mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
|
||||
mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
|
||||
ref = np.zeros_like(tile)
|
||||
mask = np.zeros(tile.shape[:2] + (1,))
|
||||
center = (0,0)
|
||||
|
||||
return tile, ref, center, mask, self.tiles[item]
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
|
|||
hq_ref = torch.cat([hq_ref, hq_mask], dim=1)
|
||||
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
|
||||
lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
|
||||
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
|
||||
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).squeeze().unsqueeze(dim=1)
|
||||
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
|
||||
|
||||
return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
|
||||
|
@ -49,9 +49,9 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
|
|||
if __name__ == '__main__':
|
||||
opt = {
|
||||
'name': 'amalgam',
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\vr\\paired_images'],
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\vr\\validation'],
|
||||
'weights': [1],
|
||||
'target_size': 128,
|
||||
#'target_size': 128,
|
||||
'force_multiple': 32,
|
||||
'scale': 2,
|
||||
'eval': False,
|
||||
|
@ -72,7 +72,7 @@ if __name__ == '__main__':
|
|||
element = ds[random.randint(0,len(ds))]
|
||||
base_file = osp.basename(element["GT_path"])
|
||||
o = element[k].unsqueeze(0)
|
||||
if bs < 32:
|
||||
if bs < 2:
|
||||
if batch is None:
|
||||
batch = o
|
||||
else:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
from torch.cuda.amp import autocast
|
||||
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
|
||||
from models.flownet2.utils.flow_utils import flow2img
|
||||
from models.steps.injectors import Injector
|
||||
|
||||
|
||||
|
@ -8,6 +9,8 @@ def create_stereoscopic_injector(opt, env):
|
|||
type = opt['type']
|
||||
if type == 'stereoscopic_resample':
|
||||
return ResampleInjector(opt, env)
|
||||
elif type == 'stereoscopic_flow2image':
|
||||
return Flow2Image(opt, env)
|
||||
return None
|
||||
|
||||
|
||||
|
@ -19,4 +22,26 @@ class ResampleInjector(Injector):
|
|||
|
||||
def forward(self, state):
|
||||
with autocast(enabled=False):
|
||||
return {self.output: self.resample(state[self.input], state[self.flow])}
|
||||
return {self.output: self.resample(state[self.input], state[self.flow])}
|
||||
|
||||
|
||||
# Converts a flowfield to an image representation for viewing purposes.
|
||||
# Uses flownet's implementation to do so. Which really sucks. TODO: just do my own implementation in the future.
|
||||
# Note: this is not differentiable and is only usable for debugging purposes.
|
||||
class Flow2Image(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(Flow2Image, self).__init__(opt, env)
|
||||
|
||||
def forward(self, state):
|
||||
with torch.no_grad():
|
||||
flo = state[self.input].cpu()
|
||||
bs, c, h, w = flo.shape
|
||||
flo = flo.permute(0, 2, 3, 1) # flow2img works in numpy space for some reason..
|
||||
imgs = torch.empty_like(flo)
|
||||
flo = flo.numpy()
|
||||
for b in range(bs):
|
||||
img = flow2img(flo[b]) # Note that this returns the image in an integer format.
|
||||
img = torch.tensor(img, dtype=torch.float) / 255
|
||||
imgs[b] = img
|
||||
imgs = imgs.permute(0, 3, 1, 2)
|
||||
return {self.output: imgs}
|
||||
|
|
|
@ -29,6 +29,8 @@ class Trainer:
|
|||
|
||||
def init(self, opt, launcher, all_networks={}):
|
||||
self._profile = False
|
||||
self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
|
||||
self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
|
||||
|
||||
#### distributed training settings
|
||||
if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1:
|
||||
|
@ -214,8 +216,8 @@ class Trainer:
|
|||
val_tqdm = tqdm(self.val_loader)
|
||||
for val_data in val_tqdm:
|
||||
idx += 1
|
||||
for b in range(len(val_data['LQ_path'])):
|
||||
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
|
||||
for b in range(len(val_data['GT_path'])):
|
||||
img_name = os.path.splitext(os.path.basename(val_data['GT_path'][b]))[0]
|
||||
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
||||
util.mkdir(img_dir)
|
||||
|
||||
|
@ -226,14 +228,16 @@ class Trainer:
|
|||
if visuals is None:
|
||||
continue
|
||||
|
||||
# calculate PSNR
|
||||
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
|
||||
gt_img = util.tensor2img(visuals['GT'][b]) # uint8
|
||||
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
||||
avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
||||
# calculate PSNR
|
||||
if self.val_compute_psnr:
|
||||
gt_img = util.tensor2img(visuals['GT'][b]) # uint8
|
||||
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
||||
avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
||||
|
||||
# calculate fea loss
|
||||
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
|
||||
if self.val_compute_fea:
|
||||
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
|
||||
|
||||
# Save SR images for reference
|
||||
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
|
||||
|
@ -278,7 +282,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_3dflow_vr_flownet.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_multifaceted_chained.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user