Fixes to unified chunk datasets to support stereoscopic training

This commit is contained in:
James Betker 2020-10-26 11:12:22 -06:00
parent 629b968901
commit ff58c6484a
6 changed files with 60 additions and 26 deletions

View File

@ -22,7 +22,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
pin_memory=True) pin_memory=True)
else: else:
batch_size = dataset_opt['batch_size'] or 1 batch_size = dataset_opt['batch_size'] or 1
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=max(int(batch_size/2), 1), return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
pin_memory=True) pin_memory=True)

View File

@ -87,7 +87,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
hqs_conformed, hq_refs_conformed, hq_masks_conformed, hq_centers_conformed = [], [], [], [] hqs_conformed, hq_refs_conformed, hq_masks_conformed, hq_centers_conformed = [], [], [], []
for hq, hq_ref, hq_mask, hq_center in zip(hqs_adjusted, hq_refs_adjusted, hq_masks_adjusted, hq_centers_adjusted): for hq, hq_ref, hq_mask, hq_center in zip(hqs_adjusted, hq_refs_adjusted, hq_masks_adjusted, hq_centers_adjusted):
h, w = (h - h % hq_multiple), (w - w % hq_multiple) h, w = (h - h % hq_multiple), (w - w % hq_multiple)
hq_centers_conformed.append(self.resize_point(hq_center, hq.shape[:1], (h, w))) hq_centers_conformed.append(self.resize_point(hq_center, hq.shape[:2], (h, w)))
hqs_conformed.append(hq[:h, :w, :]) hqs_conformed.append(hq[:h, :w, :])
hq_refs_conformed.append(hq_ref[:h, :w, :]) hq_refs_conformed.append(hq_ref[:h, :w, :])
hq_masks_conformed.append(hq_mask[:h, :w, :]) hq_masks_conformed.append(hq_mask[:h, :w, :])

View File

@ -23,19 +23,24 @@ class ChunkWithReference:
return img return img
def __getitem__(self, item): def __getitem__(self, item):
centers = torch.load(osp.join(self.path, "centers.pt"))
ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg"))
tile = self.read_image_or_get_zero(self.tiles[item]) tile = self.read_image_or_get_zero(self.tiles[item])
tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0]) if osp.exists(osp.join(self.path, "ref.jpg")):
if tile_id in centers.keys(): tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
center, tile_width = centers[tile_id] centers = torch.load(osp.join(self.path, "centers.pt"))
elif self.strict: ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg"))
raise FileNotFoundError(tile_id, self.tiles[item]) if tile_id in centers.keys():
center, tile_width = centers[tile_id]
elif self.strict:
raise FileNotFoundError(tile_id, self.tiles[item])
else:
center = torch.tensor([128, 128], dtype=torch.long)
tile_width = 256
mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
else: else:
center = torch.tensor([128, 128], dtype=torch.long) ref = np.zeros_like(tile)
tile_width = 256 mask = np.zeros(tile.shape[:2] + (1,))
mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype) center = (0,0)
mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
return tile, ref, center, mask, self.tiles[item] return tile, ref, center, mask, self.tiles[item]

View File

@ -39,7 +39,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
hq_ref = torch.cat([hq_ref, hq_mask], dim=1) hq_ref = torch.cat([hq_ref, hq_mask], dim=1)
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float() lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float() lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1) lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).squeeze().unsqueeze(dim=1)
lq_ref = torch.cat([lq_ref, lq_mask], dim=1) lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref, return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
@ -49,9 +49,9 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
if __name__ == '__main__': if __name__ == '__main__':
opt = { opt = {
'name': 'amalgam', 'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\vr\\paired_images'], 'paths': ['F:\\4k6k\\datasets\\ns_images\\vr\\validation'],
'weights': [1], 'weights': [1],
'target_size': 128, #'target_size': 128,
'force_multiple': 32, 'force_multiple': 32,
'scale': 2, 'scale': 2,
'eval': False, 'eval': False,
@ -72,7 +72,7 @@ if __name__ == '__main__':
element = ds[random.randint(0,len(ds))] element = ds[random.randint(0,len(ds))]
base_file = osp.basename(element["GT_path"]) base_file = osp.basename(element["GT_path"])
o = element[k].unsqueeze(0) o = element[k].unsqueeze(0)
if bs < 32: if bs < 2:
if batch is None: if batch is None:
batch = o batch = o
else: else:

View File

@ -1,6 +1,7 @@
import torch import torch
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from models.flownet2.networks.resample2d_package.resample2d import Resample2d from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from models.flownet2.utils.flow_utils import flow2img
from models.steps.injectors import Injector from models.steps.injectors import Injector
@ -8,6 +9,8 @@ def create_stereoscopic_injector(opt, env):
type = opt['type'] type = opt['type']
if type == 'stereoscopic_resample': if type == 'stereoscopic_resample':
return ResampleInjector(opt, env) return ResampleInjector(opt, env)
elif type == 'stereoscopic_flow2image':
return Flow2Image(opt, env)
return None return None
@ -19,4 +22,26 @@ class ResampleInjector(Injector):
def forward(self, state): def forward(self, state):
with autocast(enabled=False): with autocast(enabled=False):
return {self.output: self.resample(state[self.input], state[self.flow])} return {self.output: self.resample(state[self.input], state[self.flow])}
# Converts a flowfield to an image representation for viewing purposes.
# Uses flownet's implementation to do so. Which really sucks. TODO: just do my own implementation in the future.
# Note: this is not differentiable and is only usable for debugging purposes.
class Flow2Image(Injector):
def __init__(self, opt, env):
super(Flow2Image, self).__init__(opt, env)
def forward(self, state):
with torch.no_grad():
flo = state[self.input].cpu()
bs, c, h, w = flo.shape
flo = flo.permute(0, 2, 3, 1) # flow2img works in numpy space for some reason..
imgs = torch.empty_like(flo)
flo = flo.numpy()
for b in range(bs):
img = flow2img(flo[b]) # Note that this returns the image in an integer format.
img = torch.tensor(img, dtype=torch.float) / 255
imgs[b] = img
imgs = imgs.permute(0, 3, 1, 2)
return {self.output: imgs}

View File

@ -29,6 +29,8 @@ class Trainer:
def init(self, opt, launcher, all_networks={}): def init(self, opt, launcher, all_networks={}):
self._profile = False self._profile = False
self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
#### distributed training settings #### distributed training settings
if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1: if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1:
@ -214,8 +216,8 @@ class Trainer:
val_tqdm = tqdm(self.val_loader) val_tqdm = tqdm(self.val_loader)
for val_data in val_tqdm: for val_data in val_tqdm:
idx += 1 idx += 1
for b in range(len(val_data['LQ_path'])): for b in range(len(val_data['GT_path'])):
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0] img_name = os.path.splitext(os.path.basename(val_data['GT_path'][b]))[0]
img_dir = os.path.join(opt['path']['val_images'], img_name) img_dir = os.path.join(opt['path']['val_images'], img_name)
util.mkdir(img_dir) util.mkdir(img_dir)
@ -226,14 +228,16 @@ class Trainer:
if visuals is None: if visuals is None:
continue continue
# calculate PSNR
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8 sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
gt_img = util.tensor2img(visuals['GT'][b]) # uint8 # calculate PSNR
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) if self.val_compute_psnr:
avg_psnr += util.calculate_psnr(sr_img, gt_img) gt_img = util.tensor2img(visuals['GT'][b]) # uint8
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
avg_psnr += util.calculate_psnr(sr_img, gt_img)
# calculate fea loss # calculate fea loss
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b]) if self.val_compute_fea:
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
# Save SR images for reference # Save SR images for reference
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step) img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
@ -278,7 +282,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_3dflow_vr_flownet.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_multifaceted_chained.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args() args = parser.parse_args()
opt = option.parse(args.opt, is_train=True) opt = option.parse(args.opt, is_train=True)