# DL-Art-School/codes/data/multi_frame_dataset.py
# (91 lines, 3.7 KiB, Python; captured from a repository web view.)

from data.base_unsupervised_image_dataset import BaseUnsupervisedImageDataset
import numpy as np
import torch
from bisect import bisect_left
import os.path as osp
class MultiFrameDataset(BaseUnsupervisedImageDataset):
    """Dataset that returns a run of consecutive frames taken from one source video.

    ``self.chunks``, ``self.starting_indices``, ``resize_hq`` and ``synthesize_lq``
    are not defined here and are presumably provided by
    ``BaseUnsupervisedImageDataset``. Each sample stacks ``opt['num_frames']``
    frames along a new leading dimension.
    """

    # Number of trailing characters stripped from a chunk basename to recover the
    # name of the video it was cut from.
    # NOTE(review): assumes every chunk basename ends in the same fixed-width
    # 12-character frame suffix — confirm against the tiling script.
    _FRAME_SUFFIX_LEN = 12

    def __init__(self, opt):
        super(MultiFrameDataset, self).__init__(opt)
        # Number of consecutive frames returned per sample.
        self.num_frames = opt['num_frames']

    def chunk_name(self, i):
        """Return the file basename of ``self.chunks[i]``."""
        return osp.basename(self.chunks[i].path)

    def _source_name(self, chunk_index):
        """Return the source-video stem of a chunk: its basename minus the frame suffix."""
        return self.chunk_name(chunk_index)[:-self._FRAME_SUFFIX_LEN]

    def get_sequential_image_paths_from(self, chunk_index, chunk_offset):
        """Gather ``num_frames`` consecutive frames from the video containing a chunk.

        Walks backwards from ``chunk_index`` over chunks from the same source video,
        collecting up to ``num_frames - 1`` predecessors. The returned run therefore
        normally ends at ``chunk_index``. If the video begins sooner, the run starts at
        the video's first chunk and extends forwards instead. Every video is assumed
        to have at least ``num_frames`` chunks.

        Args:
            chunk_index: index into ``self.chunks`` of the requested frame.
            chunk_offset: index applied to each selected chunk (``chunk[chunk_offset]``).

        Returns:
            ``(hqs, refs, masks, centers, path)``: four per-frame lists in chunk order,
            plus the path reported by the last frame (``None`` if no frames).
        """
        source_name = self._source_name(chunk_index)

        # Step back while the preceding chunk is from the same video. The ``> 0``
        # bound stops index -1 from wrapping around to the last chunk.
        start_idx = chunk_index
        frames_needed = self.num_frames - 1
        while frames_needed > 0 and start_idx > 0 and self._source_name(start_idx - 1) == source_name:
            start_idx -= 1
            frames_needed -= 1

        # Build num_frames frames forward from the first frame found above.
        hqs, refs, masks, centers = [], [], [], []
        path = None
        for idx in range(start_idx, start_idx + self.num_frames):
            # Each chunk entry unpacks as (hq, ref, center, mask, path).
            h, r, c, m, p = self.chunks[idx][chunk_offset]
            hqs.append(h)
            refs.append(r)
            masks.append(m)
            centers.append(c)
            path = p
        return hqs, refs, masks, centers, path

    @staticmethod
    def _frames_to_tensor(frames):
        """Stack per-frame arrays into a float32 tensor with axes reordered (0, 3, 1, 2).

        Assumes each frame is an (H, W, C) array, giving (F, C, H, W). TODO: confirm
        against resize_hq/synthesize_lq.
        """
        return torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(frames), (0, 3, 1, 2)))).float()

    @staticmethod
    def _masks_to_tensor(masks):
        """Stack per-frame masks and insert a channel axis at dim 1, as float32."""
        # Cast to float so concatenating with the float32 image tensors keeps a float32 result.
        return torch.from_numpy(np.ascontiguousarray(np.stack(masks))).unsqueeze(dim=1).float()

    def __getitem__(self, item):
        """Return one multi-frame sample as a dict.

        Keys:
            'LQ', 'GT': stacked low/high-quality frames (float32).
            'lq_fullsize_ref', 'gt_fullsize_ref': stacked reference images with the
                mask appended as an extra channel along dim 1 (float32).
            'lq_center', 'gt_center': per-frame centers as int64 tensors.
            'GT_path': path reported by the last frame of the run.
        """
        # Locate the chunk containing ``item`` (starting_indices assumed sorted ascending).
        chunk_ind = bisect_left(self.starting_indices, item)
        if chunk_ind >= len(self.starting_indices) or self.starting_indices[chunk_ind] != item:
            chunk_ind -= 1
        chunk_offset = item - self.starting_indices[chunk_ind]
        hqs, refs, masks, centers, path = self.get_sequential_image_paths_from(chunk_ind, chunk_offset)

        hs, hrs, hms, hcs = self.resize_hq(hqs, refs, masks, centers)
        ls, lrs, lms, lcs = self.synthesize_lq(hs, hrs, hms, hcs)

        # Convert to torch tensors; references carry their mask as an extra channel.
        hq = self._frames_to_tensor(hs)
        hq_ref = torch.cat([self._frames_to_tensor(hrs), self._masks_to_tensor(hms)], dim=1)
        lq = self._frames_to_tensor(ls)
        lq_ref = torch.cat([self._frames_to_tensor(lrs), self._masks_to_tensor(lms)], dim=1)

        return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
                'lq_center': torch.tensor(lcs, dtype=torch.long),
                'gt_center': torch.tensor(hcs, dtype=torch.long),
                'GT_path': path}
if __name__ == '__main__':
    # Ad-hoc smoke test: load one sample and dump each of its GT frames as a PNG under ./debug.
    import os
    import torchvision

    opt = {
        'name': 'amalgam',
        'paths': ['/content/fullvideo_256_tiled_test'],
        'weights': [1],
        'target_size': 128,
        'force_multiple': 32,
        'scale': 2,
        'eval': False,
        'fixed_corruptions': ['jpeg-medium'],
        'random_corruptions': [],
        'num_corrupts_per_image': 0,
        'num_frames': 10
    }

    ds = MultiFrameDataset(opt)
    os.makedirs("debug", exist_ok=True)
    for i in [3]:
        o = ds[i]
        k = 'GT'
        frames = o[k]
        # __getitem__ may not report a path; use a placeholder rather than raising KeyError.
        base = osp.basename(o.get('GT_path') or 'unknown')
        # Unpacking also checks the tensor is 4-D (frames first).
        num_out_frames, _, _, _ = frames.shape
        for j in range(num_out_frames):
            torchvision.utils.save_image(frames[j].unsqueeze(0), "debug/%i_%s_%i__%s.png" % (i, k, j, base))