Disable refs and centers altogether in single_image_dataset
I suspect that this might be a cause of failures on parallel datasets. Plus it is unnecessary computation.
This commit is contained in:
parent
8f0984cacf
commit
1de1fa30ac
|
@ -4,11 +4,15 @@ import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# Iterable that reads all the images in a directory that contains a reference image, tile images and center coordinates.
|
# Iterable that reads all the images in a directory that contains a reference image, tile images and center coordinates.
|
||||||
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
class ChunkWithReference:
|
class ChunkWithReference:
|
||||||
def __init__(self, opt, path):
|
def __init__(self, opt, path):
|
||||||
self.path = path.path
|
self.path = path.path
|
||||||
self.tiles, _ = util.get_image_paths('img', self.path)
|
self.tiles, _ = util.get_image_paths('img', self.path)
|
||||||
self.strict = opt['strict'] if 'strict' in opt.keys() else True
|
self.need_metadata = opt_get(opt, ['strict'], False) or opt_get(opt, ['needs_metadata'], False)
|
||||||
|
self.need_ref = opt_get(opt, ['need_ref'], False)
|
||||||
if 'ignore_first' in opt.keys():
|
if 'ignore_first' in opt.keys():
|
||||||
self.tiles = self.tiles[opt['ignore_first']:]
|
self.tiles = self.tiles[opt['ignore_first']:]
|
||||||
|
|
||||||
|
@ -21,18 +25,19 @@ class ChunkWithReference:
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
tile = self.read_image_or_get_zero(self.tiles[item])
|
tile = self.read_image_or_get_zero(self.tiles[item])
|
||||||
if osp.exists(osp.join(self.path, "ref.jpg")):
|
if self.need_ref and osp.exists(osp.join(self.path, "ref.jpg")):
|
||||||
tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
|
tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
|
||||||
centers = torch.load(osp.join(self.path, "centers.pt"))
|
|
||||||
ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg"))
|
ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg"))
|
||||||
if tile_id in centers.keys():
|
if self.need_metadata:
|
||||||
center, tile_width = centers[tile_id]
|
centers = torch.load(osp.join(self.path, "centers.pt"))
|
||||||
elif self.strict:
|
if tile_id in centers.keys():
|
||||||
print("Could not find the given tile id in the accompanying centers.pt. This generally means that "
|
center, tile_width = centers[tile_id]
|
||||||
"centers.pt was overwritten at some point e.g. by duplicate data. If you don't care about tile "
|
else:
|
||||||
"centers, consider passing strict=false to the dataset options. (Note: you must re-build your"
|
print("Could not find the given tile id in the accompanying centers.pt. This generally means that "
|
||||||
"caches for this setting change to take effect.)")
|
"centers.pt was overwritten at some point e.g. by duplicate data. If you don't care about tile "
|
||||||
raise FileNotFoundError(tile_id, self.tiles[item])
|
"centers, consider passing strict=false to the dataset options. (Note: you must re-build your"
|
||||||
|
"caches for this setting change to take effect.)")
|
||||||
|
raise FileNotFoundError(tile_id, self.tiles[item])
|
||||||
else:
|
else:
|
||||||
center = torch.tensor([128, 128], dtype=torch.long)
|
center = torch.tensor([128, 128], dtype=torch.long)
|
||||||
tile_width = 256
|
tile_width = 256
|
||||||
|
|
|
@ -119,7 +119,7 @@ class ImageCorruptor:
|
||||||
raise NotImplementedError("specified jpeg corruption doesn't exist")
|
raise NotImplementedError("specified jpeg corruption doesn't exist")
|
||||||
# JPEG compression
|
# JPEG compression
|
||||||
qf = (rand_int % range + lo)
|
qf = (rand_int % range + lo)
|
||||||
# cv2's jpeg compression is "odd". It introduces artifacts. Use PIL instead.
|
# Use PIL to perform a mock compression to a data buffer, then swap back to cv2.
|
||||||
img = (img * 255).astype(np.uint8)
|
img = (img * 255).astype(np.uint8)
|
||||||
img = Image.fromarray(img)
|
img = Image.fromarray(img)
|
||||||
buffer = BytesIO()
|
buffer = BytesIO()
|
||||||
|
|
|
@ -50,7 +50,6 @@ if __name__ == '__main__':
|
||||||
'force_multiple': 32,
|
'force_multiple': 32,
|
||||||
'scale': 2,
|
'scale': 2,
|
||||||
'eval': False,
|
'eval': False,
|
||||||
'strict': False,
|
|
||||||
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
|
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
|
||||||
'random_corruptions': ['noise-5', 'none'],
|
'random_corruptions': ['noise-5', 'none'],
|
||||||
'num_corrupts_per_image': 1,
|
'num_corrupts_per_image': 1,
|
||||||
|
|
|
@ -293,7 +293,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_rrdb4x_23bl_opt.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_xx_faces_glean.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user