From 5a27187c59bea5725e4f8eb904165974d8eb3adb Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 25 Sep 2020 22:45:57 -0600 Subject: [PATCH] More mods to accomodate new dataset --- codes/data/chunk_with_reference.py | 2 +- codes/data/image_corruptor.py | 6 +++--- codes/data/single_image_dataset.py | 7 +++++++ codes/data/util.py | 3 ++- codes/models/archs/SPSR_arch.py | 4 ++-- codes/models/archs/SwitchedResidualGenerator_arch.py | 2 +- codes/models/networks.py | 3 +++ 7 files changed, 19 insertions(+), 8 deletions(-) diff --git a/codes/data/chunk_with_reference.py b/codes/data/chunk_with_reference.py index 8704b327..4da75f4d 100644 --- a/codes/data/chunk_with_reference.py +++ b/codes/data/chunk_with_reference.py @@ -8,7 +8,7 @@ class ChunkWithReference: def __init__(self, opt, path): self.reload(opt) self.path = path.path - self.tiles, _ = util.get_image_paths('img', path) + self.tiles, _ = util.get_image_paths('img', self.path) self.centers = None def reload(self, opt): diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py index 935d85dc..79d0e47e 100644 --- a/codes/data/image_corruptor.py +++ b/codes/data/image_corruptor.py @@ -27,10 +27,10 @@ class ImageCorruptor: corrupted_imgs = [] for img in imgs: - for aug in self.fixed_corruptions: - img = self.apply_corruption(img, aug, rand_int_f) for aug in augmentations: img = self.apply_corruption(img, aug, rand_int_a) + for aug in self.fixed_corruptions: + img = self.apply_corruption(img, aug, rand_int_f) corrupted_imgs.append(img) return corrupted_imgs @@ -81,7 +81,7 @@ class ImageCorruptor: img += np.random.randn() * noise_intensity elif 'jpeg' in aug: # JPEG compression - qf = (rand_int % 20 + 10) # Between 10-30 + qf = (rand_int % 20 + 5) # Between 5-25 # cv2's jpeg compression is "odd". It introduces artifacts. Use PIL instead. img = (img * 255).astype(np.uint8) img = Image.fromarray(img) diff --git a/codes/data/single_image_dataset.py b/codes/data/single_image_dataset.py index cdea7f4a..ec148d83 100644 --- a/codes/data/single_image_dataset.py +++ b/codes/data/single_image_dataset.py @@ -39,6 +39,13 @@ class SingleImageDataset(data.Dataset): c.reload(opt) else: chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()] + # Prune out chunks that have no images + res = [] + for c in chunks: + if len(c) != 0: + res.append(c) + chunks = res + # Save to a cache. torch.save(chunks, cache_path) for w in range(weight): self.chunks.extend(chunks) diff --git a/codes/data/util.py b/codes/data/util.py index f3ad9005..7e844d1e 100644 --- a/codes/data/util.py +++ b/codes/data/util.py @@ -28,7 +28,8 @@ def _get_paths_from_images(path): if is_image_file(fname) and 'ref.jpg' not in fname: img_path = os.path.join(dirpath, fname) images.append(img_path) - assert images, '{:s} has no valid image file'.format(path) + if not images: + print("Warning: {:s} has no valid image file".format(path)) return images diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 10aab1f1..ba628274 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -677,14 +677,14 @@ class Spsr4(nn.Module): class Spsr5(nn.Module): - def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10): + def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=2, init_temperature=10): super(Spsr5, self).__init__() n_upscale = int(math.log(upscale, 2)) # switch options transformation_filters = nf self.transformation_counts = xforms - multiplx_fn = functools.partial(QueryKeyMultiplexer, transformation_filters) + multiplx_fn = functools.partial(QueryKeyMultiplexer, transformation_filters, reductions=multiplexer_reductions) pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1) transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=3, depth=3, diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index ad841931..01034863 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -413,7 +413,7 @@ class BackboneEncoderNoRef(nn.Module): class BackboneSpinenetNoHead(nn.Module): def __init__(self): super(BackboneSpinenetNoHead, self).__init__() - self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True) + self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True, double_reduce_early=False) def forward(self, x): patch = checkpoint(self.patch_spine, x)[0] diff --git a/codes/models/networks.py b/codes/models/networks.py index 60aae360..858ed938 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -64,6 +64,7 @@ def define_G(opt, net_key='network_G', scale=None): elif which_model == "spsr5": xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 netG = spsr.Spsr5(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], + multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 2, init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) elif which_model == "ssgr1": xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 @@ -81,6 +82,8 @@ def define_G(opt, net_key='network_G', scale=None): netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet']) elif which_model == "backbone_encoder_no_ref": netG = SwitchedGen_arch.BackboneEncoderNoRef(pretrained_backbone=opt_net['pretrained_spinenet']) + elif which_model == "backbone_encoder_no_head": + netG = SwitchedGen_arch.BackboneSpinenetNoHead() elif which_model == "backbone_resnet": netG = SwitchedGen_arch.BackboneResnet() else: