More dataset integration work

This commit is contained in:
James Betker 2020-09-25 22:19:38 -06:00
parent 6d0490a0e6
commit 254cb1e915
8 changed files with 43 additions and 24 deletions

View File

@ -38,6 +38,8 @@ def create_dataset(dataset_opt):
from data.Downsample_dataset import DownsampleDataset as D
elif mode == 'fullimage':
from data.full_image_dataset import FullImageDataset as D
elif mode == 'single_image_extensible':
from data.single_image_dataset import SingleImageDataset as D
elif mode == 'combined':
from data.combined_dataset import CombinedDataset as D
else:

View File

@ -6,13 +6,16 @@ import numpy as np
# Iterable that reads all the images in a directory that contains a reference image, tile images and center coordinates.
class ChunkWithReference:
def __init__(self, opt, path):
self.opt = opt
self.reload(opt)
self.path = path.path
self.ref = None # This is loaded on the fly.
self.cache_ref = opt['cache_ref'] if 'cache_ref' in opt.keys() else True
self.tiles, _ = util.get_image_paths('img', path)
self.centers = None
def reload(self, opt):
self.opt = opt
self.ref = None # This is loaded on the fly.
self.cache_ref = opt['cache_ref'] if 'cache_ref' in opt.keys() else False
def __getitem__(self, item):
# Load centers on the fly and always cache.
if self.centers is None:
@ -20,10 +23,9 @@ class ChunkWithReference:
if self.cache_ref:
if self.ref is None:
self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
self.centers = torch.load(osp.join(self.path, "centers.pt"))
ref = self.ref
else:
self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
tile = util.read_img(None, self.tiles[item], rgb=True)
tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
center, tile_width = self.centers[tile_id]

View File

@ -10,10 +10,16 @@ from io import BytesIO
class ImageCorruptor:
def __init__(self, opt):
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 2
if self.num_corrupts == 0:
return
self.fixed_corruptions = opt['fixed_corruptions']
self.random_corruptions = opt['random_corruptions']
self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
def corrupt_images(self, imgs):
if self.num_corrupts == 0:
return imgs
augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
# Source of entropy, which should be used across all images.
rand_int_f = random.randint(1, 999999)
@ -38,11 +44,11 @@ class ImageCorruptor:
img = img / 255
elif 'gaussian_blur' in aug:
# Gaussian Blur
kernel = 2 * (rand_int % 3) + 1
kernel = 2 * self.blur_scale * (rand_int % 3) + 1
img = cv2.GaussianBlur(img, (kernel, kernel), 3)
elif 'motion_blur' in aug:
# Motion blur
intensity = 2 * (rand_int % 3) + 1
intensity = 2 * self.blur_scale * (rand_int % 3) + 1
angle = (rand_int // 3) % 360
k = np.zeros((intensity, intensity), dtype=np.float32)
k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
@ -52,7 +58,7 @@ class ImageCorruptor:
img = cv2.filter2D(img, -1, k)
elif 'smooth_blur' in aug:
# Smooth blur
kernel = 2 * (rand_int % 3) + 1
kernel = 2 * self.blur_scale * (rand_int % 3) + 1
img = cv2.blur(img, ksize=(kernel, kernel))
elif 'block_noise' in aug:
# Large distortion blocks in part of an img, such as is used to mask out a face.

View File

@ -14,6 +14,7 @@ import torchvision.transforms.functional as F
class SingleImageDataset(data.Dataset):
def __init__(self, opt):
self.opt = opt
self.corruptor = ImageCorruptor(opt)
self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys() else None
self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys() else 1
@ -33,6 +34,9 @@ class SingleImageDataset(data.Dataset):
cache_path = os.path.join(path, 'cache.pth')
if os.path.exists(cache_path):
chunks = torch.load(cache_path)
# Update the options.
for c in chunks:
c.reload(opt)
else:
chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()]
torch.save(chunks, cache_path)
@ -101,7 +105,7 @@ class SingleImageDataset(data.Dataset):
lq_ref = torch.cat([lq_ref, lq_mask], dim=0)
return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': lq_center, 'gt_center': hq_center,
'lq_center': torch.tensor(lq_center, dtype=torch.long), 'gt_center': torch.tensor(hq_center, dtype=torch.long),
'LQ_path': path, 'GT_path': path}
def __len__(self):

View File

@ -60,16 +60,20 @@ if __name__ == "__main__":
util.mkdir(dataset_dir)
tq = tqdm(test_loader)
removed = 0
for data in tq:
model.feed_data(data, need_GT=True)
model.test()
results = model.eval_state['discriminator_out'][0]
print(torch.mean(results), torch.max(results), torch.min(results))
for i in range(results.shape[0]):
imname = osp.basename(data['GT_path'][i])
if results[i] < 1:
torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
else:
torchvision.utils.save_image(data['GT'][i], osp.join(good_path, imname))
if results[i] < .8:
os.remove(data['GT_path'][i])
removed += 1
#imname = osp.basename(data['GT_path'][i])
#if results[i] > .8:
# torchvision.utils.save_image(data['GT'][i], osp.join(good_path, imname))
#else:
# torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
# log
logger.info('# Validation # Fea: {:.4e}'.format(fea_loss / len(test_loader)))
print("Removed %i/%i images" % (removed, len(test_set)))

View File

@ -43,9 +43,12 @@ def parse(opt_path, is_train=True):
dataset['mode'] = dataset['mode'].replace('_mc', '')
# path
for key, path in opt['path'].items():
if path and key in opt['path'] and key != 'strict_load':
opt['path'][key] = osp.expanduser(path)
if 'path' in opt.keys():
for key, path in opt['path'].items():
if path and key in opt['path'] and key != 'strict_load':
opt['path'][key] = osp.expanduser(path)
else:
opt['path'] = {}
opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
if is_train:
experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])

View File

@ -86,7 +86,6 @@ def forward_pass(model, output_dir, alteration_suffix=''):
if __name__ == "__main__":
#### options
torch.backends.cudnn.benchmark = True
want_just_images = True
srg_analyze = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml')

View File

@ -32,9 +32,8 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_ssgr1.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_blur_discriminator.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)
@ -83,7 +82,7 @@ def main():
if resume_state is None:
util.mkdir_and_rename(
opt['path']['experiments_root']) # rename experiment folder if exists
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
and 'pretrain_model' not in key and 'resume' not in key))
# config loggers. Before it, the log will not work