latent work
This commit is contained in:
parent
34d319585c
commit
0cf52ef52c
|
@ -224,8 +224,9 @@ class RRDBNetWithLatent(nn.Module):
|
|||
# Based heavily on the same VGG arch used for the discriminator.
|
||||
class LatentEstimator(nn.Module):
|
||||
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
|
||||
def __init__(self, in_nc, nf):
|
||||
def __init__(self, in_nc, nf, overwrite_levels=[]):
|
||||
super(LatentEstimator, self).__init__()
|
||||
self.overwrite_levels = overwrite_levels
|
||||
# [64, 128, 128]
|
||||
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
|
||||
|
@ -276,6 +277,8 @@ class LatentEstimator(nn.Module):
|
|||
fea = self.lrelu(self.conv0_0(x))
|
||||
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
|
||||
out = list(checkpoint(self.compute_body, fea))
|
||||
for lvl in self.overwrite_levels:
|
||||
out[lvl] = torch.zeros_like(out[lvl])
|
||||
self.latent_mean = torch.mean(out[-1])
|
||||
self.latent_std = torch.std(out[-1])
|
||||
self.latent_var = torch.var(out[-1])
|
||||
|
|
|
@ -124,7 +124,8 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
|
||||
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'])
|
||||
elif which_model == "latent_estimator":
|
||||
netG = LatentEstimator(in_nc=3, nf=opt_net['nf'])
|
||||
overwrite = [1,2] if opt_net['only_base_level'] else []
|
||||
netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite)
|
||||
else:
|
||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||
return netG
|
||||
|
|
|
@ -265,7 +265,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_mi1_rrdb_6bypass.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl_lower_signal.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -382,3 +382,13 @@ def recursively_detach(v):
|
|||
for k, t in v.items():
|
||||
out[k] = recursively_detach(t)
|
||||
return out
|
||||
|
||||
def opt_get(opt, keys, default=None):
|
||||
if opt is None:
|
||||
return default
|
||||
ret = opt
|
||||
for k in keys:
|
||||
ret = ret.get(k, None)
|
||||
if ret is None:
|
||||
return default
|
||||
return ret
|
||||
|
|
Loading…
Reference in New Issue
Block a user