Disable refs and centers altogether in single_image_dataset

I suspect that this might be a cause of failures on parallel datasets.
Plus it is unnecessary computation.
James Betker 2020-12-31 10:13:24 -07:00
parent 8f0984cacf
commit 1de1fa30ac
4 changed files with 18 additions and 14 deletions

@@ -4,11 +4,15 @@ import torch
 import numpy as np
 # Iterable that reads all the images in a directory that contains a reference image, tile images and center coordinates.
+from utils.util import opt_get
 class ChunkWithReference:
     def __init__(self, opt, path):
         self.path = path.path
         self.tiles, _ = util.get_image_paths('img', self.path)
-        self.strict = opt['strict'] if 'strict' in opt.keys() else True
+        self.need_metadata = opt_get(opt, ['strict'], False) or opt_get(opt, ['needs_metadata'], False)
+        self.need_ref = opt_get(opt, ['need_ref'], False)
         if 'ignore_first' in opt.keys():
             self.tiles = self.tiles[opt['ignore_first']:]
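
The new flags are read through opt_get instead of direct dict indexing. utils.util.opt_get itself is not part of this diff; the sketch below only illustrates the assumed behaviour: walk a nested options dict along a key path and fall back to a default when any level is missing.

# Hypothetical sketch of an opt_get-style helper (the real utils.util.opt_get
# is not shown in this commit); it returns `default` whenever the key path
# cannot be resolved in the options dict.
def opt_get(opt, keys, default=None):
    current = opt
    for key in keys:
        if not isinstance(current, dict) or key not in current:
            return default
        current = current[key]
    return current

# e.g. opt_get({'need_ref': True}, ['need_ref'], False)  -> True
#      opt_get({}, ['need_ref'], False)                   -> False
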
@@ -21,18 +25,19 @@ class ChunkWithReference:
     def __getitem__(self, item):
         tile = self.read_image_or_get_zero(self.tiles[item])
-        if osp.exists(osp.join(self.path, "ref.jpg")):
+        if self.need_ref and osp.exists(osp.join(self.path, "ref.jpg")):
             tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
-            centers = torch.load(osp.join(self.path, "centers.pt"))
             ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg"))
-            if tile_id in centers.keys():
-                center, tile_width = centers[tile_id]
-            elif self.strict:
-                print("Could not find the given tile id in the accompanying centers.pt. This generally means that "
-                      "centers.pt was overwritten at some point e.g. by duplicate data. If you don't care about tile "
-                      "centers, consider passing strict=false to the dataset options. (Note: you must re-build your"
-                      "caches for this setting change to take effect.)")
-                raise FileNotFoundError(tile_id, self.tiles[item])
+            if self.need_metadata:
+                centers = torch.load(osp.join(self.path, "centers.pt"))
+                if tile_id in centers.keys():
+                    center, tile_width = centers[tile_id]
+                else:
+                    print("Could not find the given tile id in the accompanying centers.pt. This generally means that "
+                          "centers.pt was overwritten at some point e.g. by duplicate data. If you don't care about tile "
+                          "centers, consider passing strict=false to the dataset options. (Note: you must re-build your"
+                          "caches for this setting change to take effect.)")
+                    raise FileNotFoundError(tile_id, self.tiles[item])
         else:
             center = torch.tensor([128, 128], dtype=torch.long)
             tile_width = 256
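
After this change, ref.jpg and centers.pt are only touched when the dataset options explicitly ask for them. A rough example of the relevant option keys follows; only the flag names come from the diff above, the rest is a placeholder rather than a real config.

# Illustrative dataset options under the new behaviour (flag names from the
# diff; 'path' and its value are hypothetical).
dataset_opt = {
    'path': '/data/chunked_images',   # hypothetical location of the chunk dirs
    'need_ref': True,                 # load ref.jpg per chunk (now off by default)
    'needs_metadata': True,           # load centers.pt and look up per-tile centers
    # 'strict': True,                 # legacy flag, still mapped onto need_metadata
}
# With these flags absent or false, __getitem__ skips ref.jpg and centers.pt
# entirely and falls back to center=(128, 128), tile_width=256.
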

@@ -119,7 +119,7 @@ class ImageCorruptor:
                 raise NotImplementedError("specified jpeg corruption doesn't exist")
             # JPEG compression
             qf = (rand_int % range + lo)
-            # cv2's jpeg compression is "odd". It introduces artifacts. Use PIL instead.
+            # Use PIL to perform a mock compression to a data buffer, then swap back to cv2.
             img = (img * 255).astype(np.uint8)
             img = Image.fromarray(img)
             buffer = BytesIO()
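
For reference, the PIL round-trip the updated comment describes, pulled out into a standalone sketch; the function name and signature are illustrative, not part of the repo.

# Sketch of the mock compression: encode to an in-memory JPEG at quality qf,
# then decode back to a numpy array so the cv2-based pipeline can keep
# operating on float images.
from io import BytesIO
import numpy as np
from PIL import Image

def mock_jpeg_compress(img_float, qf=20):
    # img_float: HxWx3 float image in [0, 1]
    img = Image.fromarray((img_float * 255).astype(np.uint8))
    buffer = BytesIO()
    img.save(buffer, format="JPEG", quality=qf)
    buffer.seek(0)
    return np.asarray(Image.open(buffer)).astype(np.float32) / 255.0
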

@@ -50,7 +50,6 @@ if __name__ == '__main__':
         'force_multiple': 32,
         'scale': 2,
         'eval': False,
-        'strict': False,
         'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
         'random_corruptions': ['noise-5', 'none'],
         'num_corrupts_per_image': 1,

@@ -293,7 +293,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_rrdb4x_23bl_opt.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_xx_faces_glean.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()