forked from mrq/DL-Art-School

Fixes to unified chunk datasets to support stereoscopic training

parent 629b968901
commit ff58c6484a
@@ -22,7 +22,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
                                            pin_memory=True)
     else:
         batch_size = dataset_opt['batch_size'] or 1
-        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=max(int(batch_size/2), 1),
+        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
                                            pin_memory=True)


@@ -87,7 +87,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
         hqs_conformed, hq_refs_conformed, hq_masks_conformed, hq_centers_conformed = [], [], [], []
         for hq, hq_ref, hq_mask, hq_center in zip(hqs_adjusted, hq_refs_adjusted, hq_masks_adjusted, hq_centers_adjusted):
             h, w = (h - h % hq_multiple), (w - w % hq_multiple)
-            hq_centers_conformed.append(self.resize_point(hq_center, hq.shape[:1], (h, w)))
+            hq_centers_conformed.append(self.resize_point(hq_center, hq.shape[:2], (h, w)))
             hqs_conformed.append(hq[:h, :w, :])
             hq_refs_conformed.append(hq_ref[:h, :w, :])
             hq_masks_conformed.append(hq_mask[:h, :w, :])
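Context for the one-character fix above: `hq.shape[:1]` yields a one-element tuple (H,), while rescaling a 2-D point needs both spatial dimensions, `hq.shape[:2]` = (H, W). A standalone sketch of what a resize_point-style helper presumably does (the real helper lives elsewhere in this repo; this version is an assumption for illustration):

import torch

def resize_point_sketch(point, orig_hw, new_hw):
    # Rescale a (y, x) point proportionally from orig (H, W) to new (h, w).
    oh, ow = orig_hw   # unpacking fails with a 1-tuple, which is why shape[:2] is required
    nh, nw = new_hw
    return torch.tensor([int(point[0] * nh / oh), int(point[1] * nw / ow)], dtype=torch.long)

# e.g. a center at (256, 256) in a 512x512 chunk maps to (224, 224)
# after conforming to 448x448:
print(resize_point_sketch(torch.tensor([256, 256]), (512, 512), (448, 448)))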
@@ -23,19 +23,24 @@ class ChunkWithReference:
         return img

     def __getitem__(self, item):
-        centers = torch.load(osp.join(self.path, "centers.pt"))
-        ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg"))
         tile = self.read_image_or_get_zero(self.tiles[item])
-        tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
-        if tile_id in centers.keys():
-            center, tile_width = centers[tile_id]
-        elif self.strict:
-            raise FileNotFoundError(tile_id, self.tiles[item])
+        if osp.exists(osp.join(self.path, "ref.jpg")):
+            tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
+            centers = torch.load(osp.join(self.path, "centers.pt"))
+            ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg"))
+            if tile_id in centers.keys():
+                center, tile_width = centers[tile_id]
+            elif self.strict:
+                raise FileNotFoundError(tile_id, self.tiles[item])
+            else:
+                center = torch.tensor([128, 128], dtype=torch.long)
+                tile_width = 256
+            mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
+            mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
         else:
-            center = torch.tensor([128, 128], dtype=torch.long)
-            tile_width = 256
-            mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
-            mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
+            ref = np.zeros_like(tile)
+            mask = np.zeros(tile.shape[:2] + (1,))
+            center = (0,0)

         return tile, ref, center, mask, self.tiles[item]
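The rewritten `__getitem__` only touches `centers.pt` and `ref.jpg` when the reference image actually exists, and otherwise falls back to a zero reference, a zero mask, and a (0,0) center. In the reference branch, the mask weights every pixel at 0.1 and the tile_width window around the center at 1. A runnable sketch of just that mask construction (the tile shape and center are example values, not from the commit):

import numpy as np

tile = np.zeros((512, 512, 3), dtype=np.float32)  # stand-in for a loaded tile
center, tile_width = (256, 256), 256

# Everything starts at 0.1 so context pixels still contribute weakly...
mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
# ...and the tile_width x tile_width window around the center is weighted fully.
mask[center[0] - tile_width // 2:center[0] + tile_width // 2,
     center[1] - tile_width // 2:center[1] + tile_width // 2] = 1

print(mask.mean())  # ~0.325 here: 25% of pixels at 1, the remaining 75% at 0.1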
@@ -39,7 +39,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
         hq_ref = torch.cat([hq_ref, hq_mask], dim=1)
         lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
         lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
-        lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
+        lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).squeeze().unsqueeze(dim=1)
         lq_ref = torch.cat([lq_ref, lq_mask], dim=1)

         return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
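Why the added `.squeeze()`: if each mask in `lms` carries a trailing singleton channel (H, W, 1), `np.stack` produces (N, H, W, 1), and a bare `unsqueeze(dim=1)` yields a five-dimensional (N, 1, H, W, 1) that cannot be concatenated with the (N, C, H, W) `lq_ref`. A minimal shape check, assuming (H, W, 1) masks:

import numpy as np
import torch

lms = [np.zeros((64, 64, 1)) for _ in range(4)]  # masks with a trailing channel dim

stacked = torch.from_numpy(np.stack(lms))         # (4, 64, 64, 1)
bad = stacked.unsqueeze(dim=1)                    # (4, 1, 64, 64, 1) - wrong rank
good = stacked.squeeze().unsqueeze(dim=1)         # (4, 1, 64, 64) - matches lq_ref
print(bad.shape, good.shape)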
@@ -49,9 +49,9 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
 if __name__ == '__main__':
     opt = {
         'name': 'amalgam',
-        'paths': ['F:\\4k6k\\datasets\\ns_images\\vr\\paired_images'],
+        'paths': ['F:\\4k6k\\datasets\\ns_images\\vr\\validation'],
         'weights': [1],
-        'target_size': 128,
+        #'target_size': 128,
         'force_multiple': 32,
         'scale': 2,
         'eval': False,
@@ -72,7 +72,7 @@ if __name__ == '__main__':
     element = ds[random.randint(0,len(ds))]
     base_file = osp.basename(element["GT_path"])
     o = element[k].unsqueeze(0)
-    if bs < 32:
+    if bs < 2:
         if batch is None:
             batch = o
         else:
@@ -1,6 +1,7 @@
 import torch
 from torch.cuda.amp import autocast
 from models.flownet2.networks.resample2d_package.resample2d import Resample2d
+from models.flownet2.utils.flow_utils import flow2img
 from models.steps.injectors import Injector


@@ -8,6 +9,8 @@ def create_stereoscopic_injector(opt, env):
     type = opt['type']
     if type == 'stereoscopic_resample':
         return ResampleInjector(opt, env)
+    elif type == 'stereoscopic_flow2image':
+        return Flow2Image(opt, env)
     return None


@@ -19,4 +22,26 @@ class ResampleInjector(Injector):

     def forward(self, state):
         with autocast(enabled=False):
             return {self.output: self.resample(state[self.input], state[self.flow])}
+
+
+# Converts a flowfield to an image representation for viewing purposes.
+# Uses flownet's implementation to do so, which is less than ideal. TODO: write a native implementation.
+# Note: this is not differentiable and is only usable for debugging purposes.
+class Flow2Image(Injector):
+    def __init__(self, opt, env):
+        super(Flow2Image, self).__init__(opt, env)
+
+    def forward(self, state):
+        with torch.no_grad():
+            flo = state[self.input].cpu()
+            bs, c, h, w = flo.shape
+            flo = flo.permute(0, 2, 3, 1)  # flow2img works in numpy space.
+            imgs = torch.empty((bs, h, w, 3), dtype=torch.float)  # flow2img emits 3-channel RGB, not c channels.
+            flo = flo.numpy()
+            for b in range(bs):
+                img = flow2img(flo[b])  # Note that this returns the image in an integer format.
+                img = torch.tensor(img, dtype=torch.float) / 255
+                imgs[b] = img
+            imgs = imgs.permute(0, 3, 1, 2)
+            return {self.output: imgs}
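`Flow2Image` delegates to flownet2's `flow2img` under `torch.no_grad()`, so it is purely a debugging visualization. For readers without the flownet2 package, here is a standalone stand-in that maps flow direction to hue and magnitude to brightness; the palette differs from flownet's and the function is an assumption for illustration, not the repo's implementation:

import numpy as np
import torch
from matplotlib.colors import hsv_to_rgb

def flow_to_rgb_sketch(flo):
    # flo: (H, W, 2) numpy flow field -> (H, W, 3) float RGB in [0, 1].
    u, v = flo[..., 0], flo[..., 1]
    mag = np.sqrt(u ** 2 + v ** 2)
    hue = (np.arctan2(v, u) / (2 * np.pi)) % 1.0  # direction -> hue
    val = mag / max(float(mag.max()), 1e-8)       # magnitude -> brightness
    hsv = np.stack([hue, np.ones_like(hue), val], axis=-1)
    return hsv_to_rgb(hsv)

# Mirrors Flow2Image's plumbing: (B, 2, H, W) flow in, (B, 3, H, W) image out.
flow = torch.randn(1, 2, 64, 64)
img = torch.from_numpy(flow_to_rgb_sketch(flow[0].permute(1, 2, 0).numpy())).float()
print(img.permute(2, 0, 1).unsqueeze(0).shape)  # torch.Size([1, 3, 64, 64])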
@@ -29,6 +29,8 @@ class Trainer:

     def init(self, opt, launcher, all_networks={}):
         self._profile = False
+        self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
+        self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True

         #### distributed training settings
         if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1:
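These two additions let a validation run skip the PSNR and perceptual-feature passes. Both default to True when the keys are absent from the options file; a runnable illustration of the lookup (the dict literal is a hypothetical stand-in for a parsed YAML config):

opt = {'eval': {'compute_psnr': False}}  # hypothetical: disable PSNR only

val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
print(val_compute_psnr, val_compute_fea)  # False True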
@@ -214,8 +216,8 @@ class Trainer:
                 val_tqdm = tqdm(self.val_loader)
                 for val_data in val_tqdm:
                     idx += 1
-                    for b in range(len(val_data['LQ_path'])):
-                        img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
+                    for b in range(len(val_data['GT_path'])):
+                        img_name = os.path.splitext(os.path.basename(val_data['GT_path'][b]))[0]
                         img_dir = os.path.join(opt['path']['val_images'], img_name)
                         util.mkdir(img_dir)

@@ -226,14 +228,16 @@ class Trainer:
                     if visuals is None:
                         continue

-                    # calculate PSNR
                     sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
-                    gt_img = util.tensor2img(visuals['GT'][b])  # uint8
-                    sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
-                    avg_psnr += util.calculate_psnr(sr_img, gt_img)
+                    # calculate PSNR
+                    if self.val_compute_psnr:
+                        gt_img = util.tensor2img(visuals['GT'][b])  # uint8
+                        sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
+                        avg_psnr += util.calculate_psnr(sr_img, gt_img)

                     # calculate fea loss
-                    avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
+                    if self.val_compute_fea:
+                        avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])

                     # Save SR images for reference
                     img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
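For reference, `util.calculate_psnr` on uint8 images presumably computes the textbook peak signal-to-noise ratio; a self-contained sketch of that formula (not the repo's exact code):

import numpy as np

def psnr_sketch(img1, img2):
    # Standard PSNR for uint8 images, in dB: 10 * log10(MAX^2 / MSE) with MAX = 255.
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    return float('inf') if mse == 0 else 10 * np.log10(255.0 ** 2 / mse)

a = np.full((8, 8), 100, dtype=np.uint8)
b = np.full((8, 8), 110, dtype=np.uint8)  # uniform error of 10 -> MSE of 100
print(psnr_sketch(a, b))                  # 10 * log10(65025 / 100) ~= 28.13 dB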
@@ -278,7 +282,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_3dflow_vr_flownet.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_multifaceted_chained.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)