Enable vertical splitting on inference images to support very large resolutions.

This commit is contained in:
James Betker 2020-05-27 08:04:35 -06:00
parent 96ac26a8b7
commit e27a49454e

View File

@ -19,6 +19,7 @@ class LQDataset(data.Dataset):
self.start_at = self.opt['start_at'] self.start_at = self.opt['start_at']
else: else:
self.start_at = 0 self.start_at = 0
self.vertical_splits = self.opt['vertical_splits']
self.paths_LQ, self.paths_GT = None, None self.paths_LQ, self.paths_GT = None, None
self.LQ_env = None # environment for lmdb self.LQ_env = None # environment for lmdb
@ -33,21 +34,29 @@ class LQDataset(data.Dataset):
def __getitem__(self, index): def __getitem__(self, index):
if self.data_type == 'lmdb' and self.LQ_env is None: if self.data_type == 'lmdb' and self.LQ_env is None:
self._init_lmdb() self._init_lmdb()
actual_index = index # int(index / 2) if self.vertical_splits > 0:
is_left = (index % 2) == 0 actual_index = int(index / self.vertical_splits)
else:
actual_index = index
# get LQ image # get LQ image
LQ_path = self.paths_LQ[actual_index] LQ_path = self.paths_LQ[actual_index]
img_LQ = Image.open(LQ_path) img_LQ = Image.open(LQ_path)
left = 0 if is_left else 1920 if self.vertical_splits > 0:
# crop input if needed. w, h = img_LQ.size
#img_LQ = F.crop(img_LQ, 5, left + 5, 1900, 1900) split_index = (index % self.vertical_splits)
w_per_split = int(w / self.vertical_splits)
left = w_per_split * split_index
img_LQ = F.crop(img_LQ, 0, left, h, w_per_split)
img_LQ = F.to_tensor(img_LQ) img_LQ = F.to_tensor(img_LQ)
img_name = osp.splitext(osp.basename(LQ_path))[0] img_name = osp.splitext(osp.basename(LQ_path))[0]
LQ_path = LQ_path.replace(img_name, img_name + "_%i" % (index % 2)) LQ_path = LQ_path.replace(img_name, img_name + "_%i" % (index % self.vertical_splits))
return {'LQ': img_LQ, 'LQ_path': LQ_path} return {'LQ': img_LQ, 'LQ_path': LQ_path}
def __len__(self): def __len__(self):
return len(self.paths_LQ) # * 2 if self.vertical_splits > 0:
return len(self.paths_LQ) * self.vertical_splits
else:
return len(self.paths_LQ)