latent work

This commit is contained in:
James Betker 2020-11-06 20:38:23 -07:00
parent 34d319585c
commit 0cf52ef52c
4 changed files with 17 additions and 3 deletions

View File

@ -224,8 +224,9 @@ class RRDBNetWithLatent(nn.Module):
# Based heavily on the same VGG arch used for the discriminator.
class LatentEstimator(nn.Module):
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
def __init__(self, in_nc, nf):
def __init__(self, in_nc, nf, overwrite_levels=[]):
super(LatentEstimator, self).__init__()
self.overwrite_levels = overwrite_levels
# [64, 128, 128]
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
@ -276,6 +277,8 @@ class LatentEstimator(nn.Module):
fea = self.lrelu(self.conv0_0(x))
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
out = list(checkpoint(self.compute_body, fea))
for lvl in self.overwrite_levels:
out[lvl] = torch.zeros_like(out[lvl])
self.latent_mean = torch.mean(out[-1])
self.latent_std = torch.std(out[-1])
self.latent_var = torch.var(out[-1])

View File

@ -124,7 +124,8 @@ def define_G(opt, net_key='network_G', scale=None):
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'])
elif which_model == "latent_estimator":
netG = LatentEstimator(in_nc=3, nf=opt_net['nf'])
overwrite = [1,2] if opt_net['only_base_level'] else []
netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite)
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG

View File

@ -265,7 +265,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_mi1_rrdb_6bypass.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl_lower_signal.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -382,3 +382,13 @@ def recursively_detach(v):
for k, t in v.items():
out[k] = recursively_detach(t)
return out
def opt_get(opt, keys, default=None):
if opt is None:
return default
ret = opt
for k in keys:
ret = ret.get(k, None)
if ret is None:
return default
return ret