forked from mrq/DL-Art-School
latent work
This commit is contained in:
parent
34d319585c
commit
0cf52ef52c
|
@ -224,8 +224,9 @@ class RRDBNetWithLatent(nn.Module):
|
||||||
# Based heavily on the same VGG arch used for the discriminator.
|
# Based heavily on the same VGG arch used for the discriminator.
|
||||||
class LatentEstimator(nn.Module):
|
class LatentEstimator(nn.Module):
|
||||||
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
|
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
|
||||||
def __init__(self, in_nc, nf):
|
def __init__(self, in_nc, nf, overwrite_levels=[]):
|
||||||
super(LatentEstimator, self).__init__()
|
super(LatentEstimator, self).__init__()
|
||||||
|
self.overwrite_levels = overwrite_levels
|
||||||
# [64, 128, 128]
|
# [64, 128, 128]
|
||||||
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||||
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
|
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
|
||||||
|
@ -276,6 +277,8 @@ class LatentEstimator(nn.Module):
|
||||||
fea = self.lrelu(self.conv0_0(x))
|
fea = self.lrelu(self.conv0_0(x))
|
||||||
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
|
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
|
||||||
out = list(checkpoint(self.compute_body, fea))
|
out = list(checkpoint(self.compute_body, fea))
|
||||||
|
for lvl in self.overwrite_levels:
|
||||||
|
out[lvl] = torch.zeros_like(out[lvl])
|
||||||
self.latent_mean = torch.mean(out[-1])
|
self.latent_mean = torch.mean(out[-1])
|
||||||
self.latent_std = torch.std(out[-1])
|
self.latent_std = torch.std(out[-1])
|
||||||
self.latent_var = torch.var(out[-1])
|
self.latent_var = torch.var(out[-1])
|
||||||
|
|
|
@ -124,7 +124,8 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
|
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
|
||||||
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'])
|
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'])
|
||||||
elif which_model == "latent_estimator":
|
elif which_model == "latent_estimator":
|
||||||
netG = LatentEstimator(in_nc=3, nf=opt_net['nf'])
|
overwrite = [1,2] if opt_net['only_base_level'] else []
|
||||||
|
netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||||
return netG
|
return netG
|
||||||
|
|
|
@ -265,7 +265,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_mi1_rrdb_6bypass.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl_lower_signal.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -382,3 +382,13 @@ def recursively_detach(v):
|
||||||
for k, t in v.items():
|
for k, t in v.items():
|
||||||
out[k] = recursively_detach(t)
|
out[k] = recursively_detach(t)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def opt_get(opt, keys, default=None):
|
||||||
|
if opt is None:
|
||||||
|
return default
|
||||||
|
ret = opt
|
||||||
|
for k in keys:
|
||||||
|
ret = ret.get(k, None)
|
||||||
|
if ret is None:
|
||||||
|
return default
|
||||||
|
return ret
|
||||||
|
|
Loading…
Reference in New Issue
Block a user