diff --git a/codes/models/archs/rrdb_with_latent.py b/codes/models/archs/rrdb_with_latent.py index 8951dc10..5b0239a9 100644 --- a/codes/models/archs/rrdb_with_latent.py +++ b/codes/models/archs/rrdb_with_latent.py @@ -224,8 +224,9 @@ class RRDBNetWithLatent(nn.Module): # Based heavily on the same VGG arch used for the discriminator. class LatentEstimator(nn.Module): # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported. - def __init__(self, in_nc, nf): + def __init__(self, in_nc, nf, overwrite_levels=[]): super(LatentEstimator, self).__init__() + self.overwrite_levels = overwrite_levels # [64, 128, 128] self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) @@ -276,6 +277,8 @@ class LatentEstimator(nn.Module): fea = self.lrelu(self.conv0_0(x)) fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) out = list(checkpoint(self.compute_body, fea)) + for lvl in self.overwrite_levels: + out[lvl] = torch.zeros_like(out[lvl]) self.latent_mean = torch.mean(out[-1]) self.latent_std = torch.std(out[-1]) self.latent_var = torch.var(out[-1]) diff --git a/codes/models/networks.py b/codes/models/networks.py index 64f84d8e..12bbf919 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -124,7 +124,8 @@ def define_G(opt, net_key='network_G', scale=None): mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale']) elif which_model == "latent_estimator": - netG = LatentEstimator(in_nc=3, nf=opt_net['nf']) + overwrite = [1,2] if opt_net['only_base_level'] else [] + netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite) else: raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) return netG diff --git a/codes/train.py b/codes/train.py index 120356fe..edae8fdb 100644 --- a/codes/train.py +++ b/codes/train.py @@ -265,7 +265,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_mi1_rrdb_6bypass.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl_lower_signal.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/utils/util.py b/codes/utils/util.py index 581e6307..0f01326f 100644 --- a/codes/utils/util.py +++ b/codes/utils/util.py @@ -382,3 +382,13 @@ def recursively_detach(v): for k, t in v.items(): out[k] = recursively_detach(t) return out + +def opt_get(opt, keys, default=None): + if opt is None: + return default + ret = opt + for k in keys: + ret = ret.get(k, None) + if ret is None: + return default + return ret