forked from mrq/DL-Art-School
More dataset integration work
This commit is contained in:
parent
6d0490a0e6
commit
254cb1e915
|
@ -38,6 +38,8 @@ def create_dataset(dataset_opt):
|
|||
from data.Downsample_dataset import DownsampleDataset as D
|
||||
elif mode == 'fullimage':
|
||||
from data.full_image_dataset import FullImageDataset as D
|
||||
elif mode == 'single_image_extensible':
|
||||
from data.single_image_dataset import SingleImageDataset as D
|
||||
elif mode == 'combined':
|
||||
from data.combined_dataset import CombinedDataset as D
|
||||
else:
|
||||
|
|
|
@ -6,13 +6,16 @@ import numpy as np
|
|||
# Iterable that reads all the images in a directory that contains a reference image, tile images and center coordinates.
|
||||
class ChunkWithReference:
|
||||
def __init__(self, opt, path):
|
||||
self.opt = opt
|
||||
self.reload(opt)
|
||||
self.path = path.path
|
||||
self.ref = None # This is loaded on the fly.
|
||||
self.cache_ref = opt['cache_ref'] if 'cache_ref' in opt.keys() else True
|
||||
self.tiles, _ = util.get_image_paths('img', path)
|
||||
self.centers = None
|
||||
|
||||
def reload(self, opt):
|
||||
self.opt = opt
|
||||
self.ref = None # This is loaded on the fly.
|
||||
self.cache_ref = opt['cache_ref'] if 'cache_ref' in opt.keys() else False
|
||||
|
||||
def __getitem__(self, item):
|
||||
# Load centers on the fly and always cache.
|
||||
if self.centers is None:
|
||||
|
@ -20,10 +23,9 @@ class ChunkWithReference:
|
|||
if self.cache_ref:
|
||||
if self.ref is None:
|
||||
self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
|
||||
self.centers = torch.load(osp.join(self.path, "centers.pt"))
|
||||
ref = self.ref
|
||||
else:
|
||||
self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
|
||||
ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
|
||||
tile = util.read_img(None, self.tiles[item], rgb=True)
|
||||
tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
|
||||
center, tile_width = self.centers[tile_id]
|
||||
|
|
|
@ -10,10 +10,16 @@ from io import BytesIO
|
|||
class ImageCorruptor:
|
||||
def __init__(self, opt):
|
||||
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 2
|
||||
if self.num_corrupts == 0:
|
||||
return
|
||||
self.fixed_corruptions = opt['fixed_corruptions']
|
||||
self.random_corruptions = opt['random_corruptions']
|
||||
self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
|
||||
|
||||
def corrupt_images(self, imgs):
|
||||
if self.num_corrupts == 0:
|
||||
return imgs
|
||||
|
||||
augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
|
||||
# Source of entropy, which should be used across all images.
|
||||
rand_int_f = random.randint(1, 999999)
|
||||
|
@ -38,11 +44,11 @@ class ImageCorruptor:
|
|||
img = img / 255
|
||||
elif 'gaussian_blur' in aug:
|
||||
# Gaussian Blur
|
||||
kernel = 2 * (rand_int % 3) + 1
|
||||
kernel = 2 * self.blur_scale * (rand_int % 3) + 1
|
||||
img = cv2.GaussianBlur(img, (kernel, kernel), 3)
|
||||
elif 'motion_blur' in aug:
|
||||
# Motion blur
|
||||
intensity = 2 * (rand_int % 3) + 1
|
||||
intensity = 2 * self.blur_scale * (rand_int % 3) + 1
|
||||
angle = (rand_int // 3) % 360
|
||||
k = np.zeros((intensity, intensity), dtype=np.float32)
|
||||
k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
|
||||
|
@ -52,7 +58,7 @@ class ImageCorruptor:
|
|||
img = cv2.filter2D(img, -1, k)
|
||||
elif 'smooth_blur' in aug:
|
||||
# Smooth blur
|
||||
kernel = 2 * (rand_int % 3) + 1
|
||||
kernel = 2 * self.blur_scale * (rand_int % 3) + 1
|
||||
img = cv2.blur(img, ksize=(kernel, kernel))
|
||||
elif 'block_noise' in aug:
|
||||
# Large distortion blocks in part of an img, such as is used to mask out a face.
|
||||
|
|
|
@ -14,6 +14,7 @@ import torchvision.transforms.functional as F
|
|||
class SingleImageDataset(data.Dataset):
|
||||
|
||||
def __init__(self, opt):
|
||||
self.opt = opt
|
||||
self.corruptor = ImageCorruptor(opt)
|
||||
self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys() else None
|
||||
self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys() else 1
|
||||
|
@ -33,6 +34,9 @@ class SingleImageDataset(data.Dataset):
|
|||
cache_path = os.path.join(path, 'cache.pth')
|
||||
if os.path.exists(cache_path):
|
||||
chunks = torch.load(cache_path)
|
||||
# Update the options.
|
||||
for c in chunks:
|
||||
c.reload(opt)
|
||||
else:
|
||||
chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()]
|
||||
torch.save(chunks, cache_path)
|
||||
|
@ -101,7 +105,7 @@ class SingleImageDataset(data.Dataset):
|
|||
lq_ref = torch.cat([lq_ref, lq_mask], dim=0)
|
||||
|
||||
return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
|
||||
'lq_center': lq_center, 'gt_center': hq_center,
|
||||
'lq_center': torch.tensor(lq_center, dtype=torch.long), 'gt_center': torch.tensor(hq_center, dtype=torch.long),
|
||||
'LQ_path': path, 'GT_path': path}
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
@ -60,16 +60,20 @@ if __name__ == "__main__":
|
|||
util.mkdir(dataset_dir)
|
||||
|
||||
tq = tqdm(test_loader)
|
||||
removed = 0
|
||||
for data in tq:
|
||||
model.feed_data(data, need_GT=True)
|
||||
model.test()
|
||||
results = model.eval_state['discriminator_out'][0]
|
||||
print(torch.mean(results), torch.max(results), torch.min(results))
|
||||
for i in range(results.shape[0]):
|
||||
imname = osp.basename(data['GT_path'][i])
|
||||
if results[i] < 1:
|
||||
torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
|
||||
else:
|
||||
torchvision.utils.save_image(data['GT'][i], osp.join(good_path, imname))
|
||||
if results[i] < .8:
|
||||
os.remove(data['GT_path'][i])
|
||||
removed += 1
|
||||
#imname = osp.basename(data['GT_path'][i])
|
||||
#if results[i] > .8:
|
||||
# torchvision.utils.save_image(data['GT'][i], osp.join(good_path, imname))
|
||||
#else:
|
||||
# torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
|
||||
|
||||
# log
|
||||
logger.info('# Validation # Fea: {:.4e}'.format(fea_loss / len(test_loader)))
|
||||
print("Removed %i/%i images" % (removed, len(test_set)))
|
|
@ -43,9 +43,12 @@ def parse(opt_path, is_train=True):
|
|||
dataset['mode'] = dataset['mode'].replace('_mc', '')
|
||||
|
||||
# path
|
||||
for key, path in opt['path'].items():
|
||||
if path and key in opt['path'] and key != 'strict_load':
|
||||
opt['path'][key] = osp.expanduser(path)
|
||||
if 'path' in opt.keys():
|
||||
for key, path in opt['path'].items():
|
||||
if path and key in opt['path'] and key != 'strict_load':
|
||||
opt['path'][key] = osp.expanduser(path)
|
||||
else:
|
||||
opt['path'] = {}
|
||||
opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
|
||||
if is_train:
|
||||
experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
|
||||
|
|
|
@ -86,7 +86,6 @@ def forward_pass(model, output_dir, alteration_suffix=''):
|
|||
if __name__ == "__main__":
|
||||
#### options
|
||||
torch.backends.cudnn.benchmark = True
|
||||
want_just_images = True
|
||||
srg_analyze = False
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml')
|
||||
|
|
|
@ -32,9 +32,8 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_ssgr1.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_blur_discriminator.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
@ -83,7 +82,7 @@ def main():
|
|||
if resume_state is None:
|
||||
util.mkdir_and_rename(
|
||||
opt['path']['experiments_root']) # rename experiment folder if exists
|
||||
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
|
||||
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
|
||||
and 'pretrain_model' not in key and 'resume' not in key))
|
||||
|
||||
# config loggers. Before it, the log will not work
|
||||
|
|
Loading…
Reference in New Issue
Block a user