Enable chunk_with_reference to work without centers
Moving away from this so it doesn't matter too much. Also fixes an issue with the "ignore" flag.
This commit is contained in:
parent
b45e132a9d
commit
c7f3fc4dd9
|
@ -8,6 +8,7 @@ class ChunkWithReference:
|
||||||
def __init__(self, opt, path):
|
def __init__(self, opt, path):
|
||||||
self.path = path.path
|
self.path = path.path
|
||||||
self.tiles, _ = util.get_image_paths('img', self.path)
|
self.tiles, _ = util.get_image_paths('img', self.path)
|
||||||
|
self.strict = opt['strict'] if 'strict' in opt.keys() else True
|
||||||
if 'ignore_first' in opt.keys():
|
if 'ignore_first' in opt.keys():
|
||||||
self.ignore = opt['ignore_first']
|
self.ignore = opt['ignore_first']
|
||||||
self.tiles = self.tiles[self.ignore:]
|
self.tiles = self.tiles[self.ignore:]
|
||||||
|
@ -22,11 +23,17 @@ class ChunkWithReference:
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
centers = torch.load(osp.join(self.path, "centers.pt"))[self.ignore:]
|
centers = torch.load(osp.join(self.path, "centers.pt"))
|
||||||
ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg"))
|
ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg"))
|
||||||
tile = self.read_image_or_get_zero(self.tiles[item])
|
tile = self.read_image_or_get_zero(self.tiles[item])
|
||||||
tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
|
tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
|
||||||
|
if tile_id in centers.keys():
|
||||||
center, tile_width = centers[tile_id]
|
center, tile_width = centers[tile_id]
|
||||||
|
elif self.strict:
|
||||||
|
raise FileNotFoundError(tile_id, self.tiles[item])
|
||||||
|
else:
|
||||||
|
center = torch.tensor([128,128], dtype=torch.long)
|
||||||
|
tile_width = 256
|
||||||
mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
|
mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
|
||||||
mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
|
mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
|
||||||
|
|
||||||
|
|
|
@ -5,8 +5,6 @@ import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import data.util as data_util # noqa: E402
|
import data.util as data_util # noqa: E402
|
||||||
import lmdb
|
|
||||||
import pyarrow
|
|
||||||
import torch.utils.data as data
|
import torch.utils.data as data
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
@ -16,7 +14,7 @@ def main():
|
||||||
mode = 'single' # single (one input folder) | pair (extract corresponding GT and LR pairs)
|
mode = 'single' # single (one input folder) | pair (extract corresponding GT and LR pairs)
|
||||||
split_img = False
|
split_img = False
|
||||||
opt = {}
|
opt = {}
|
||||||
opt['n_thread'] = 0
|
opt['n_thread'] = 2
|
||||||
opt['compression_level'] = 90 # JPEG compression quality rating.
|
opt['compression_level'] = 90 # JPEG compression quality rating.
|
||||||
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
|
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
|
||||||
# compression time. If read raw images during training, use 0 for faster IO speed.
|
# compression time. If read raw images during training, use 0 for faster IO speed.
|
||||||
|
@ -244,7 +242,7 @@ class TiledDataset(data.Dataset):
|
||||||
|
|
||||||
h, w, c = img.shape
|
h, w, c = img.shape
|
||||||
# Uncomment to filter any image that doesnt meet a threshold size.
|
# Uncomment to filter any image that doesnt meet a threshold size.
|
||||||
if min(h,w) < 1024:
|
if min(h,w) < 512:
|
||||||
return None
|
return None
|
||||||
left = 0
|
left = 0
|
||||||
right = w
|
right = w
|
||||||
|
|
Loading…
Reference in New Issue
Block a user