forked from mrq/DL-Art-School
Add PairedFrameDataset
This commit is contained in:
parent
7a75d10784
commit
8e5b6682bf
|
@ -39,6 +39,8 @@ def create_dataset(dataset_opt):
|
|||
from data.combined_dataset import CombinedDataset as D
|
||||
elif mode == 'multiscale':
|
||||
from data.multiscale_dataset import MultiScaleDataset as D
|
||||
elif mode == 'paired_frame':
|
||||
from data.paired_frame_dataset import PairedFrameDataset as D
|
||||
else:
|
||||
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
||||
dataset = D(dataset_opt)
|
||||
|
|
91
codes/data/paired_frame_dataset.py
Normal file
91
codes/data/paired_frame_dataset.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
from data.base_unsupervised_image_dataset import BaseUnsupervisedImageDataset
|
||||
import numpy as np
|
||||
import torch
|
||||
from bisect import bisect_left
|
||||
import os.path as osp
|
||||
|
||||
class PairedFrameDataset(BaseUnsupervisedImageDataset):
|
||||
def __init__(self, opt):
|
||||
super(PairedFrameDataset, self).__init__(opt)
|
||||
|
||||
def get_pair(self, chunk_index, chunk_offset):
|
||||
imname = osp.basename(self.chunks[chunk_index].path)
|
||||
if '_left' in imname:
|
||||
chunks = [chunk_index, chunk_index+1]
|
||||
else:
|
||||
chunks = [chunk_index-1, chunk_index]
|
||||
hqs, refs, masks, centers = [], [], [], []
|
||||
for i in chunks:
|
||||
h, r, c, m, p = self.chunks[i][chunk_offset]
|
||||
hqs.append(h)
|
||||
refs.append(r)
|
||||
masks.append(m)
|
||||
centers.append(c)
|
||||
path = p
|
||||
return hqs, refs, masks, centers, path
|
||||
|
||||
def __getitem__(self, item):
|
||||
chunk_ind = bisect_left(self.starting_indices, item)
|
||||
chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
|
||||
hqs, refs, masks, centers, path = self.get_pair(chunk_ind, item-self.starting_indices[chunk_ind])
|
||||
|
||||
hs, hrs, hms, hcs = self.resize_hq(hqs, refs, masks, centers)
|
||||
ls, lrs, lms, lcs = self.synthesize_lq(hs, hrs, hms, hcs)
|
||||
|
||||
# Convert to torch tensor
|
||||
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hs), (0, 3, 1, 2)))).float()
|
||||
hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float()
|
||||
hq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(hms))).unsqueeze(dim=1)
|
||||
hq_ref = torch.cat([hq_ref, hq_mask], dim=1)
|
||||
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
|
||||
lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
|
||||
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
|
||||
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
|
||||
|
||||
return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
|
||||
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
opt = {
|
||||
'name': 'amalgam',
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\vr\\paired_images'],
|
||||
'weights': [1],
|
||||
'target_size': 128,
|
||||
'force_multiple': 32,
|
||||
'scale': 2,
|
||||
'eval': False,
|
||||
'fixed_corruptions': ['jpeg-medium'],
|
||||
'random_corruptions': [],
|
||||
'num_corrupts_per_image': 0,
|
||||
'num_frames': 10
|
||||
}
|
||||
|
||||
ds = PairedFrameDataset(opt)
|
||||
import os
|
||||
os.makedirs("debug", exist_ok=True)
|
||||
bs = 0
|
||||
batch = None
|
||||
for i in range(len(ds)):
|
||||
import random
|
||||
k = 'LQ'
|
||||
element = ds[random.randint(0,len(ds))]
|
||||
base_file = osp.basename(element["GT_path"])
|
||||
o = element[k].unsqueeze(0)
|
||||
if bs < 32:
|
||||
if batch is None:
|
||||
batch = o
|
||||
else:
|
||||
batch = torch.cat([batch, o], dim=0)
|
||||
bs += 1
|
||||
continue
|
||||
|
||||
if 'path' not in k and 'center' not in k:
|
||||
b, fr, f, h, w = batch.shape
|
||||
for j in range(fr):
|
||||
import torchvision
|
||||
base=osp.basename(base_file)
|
||||
torchvision.utils.save_image(batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base))
|
||||
|
||||
bs = 0
|
||||
batch = None
|
Loading…
Reference in New Issue
Block a user