Fix vgg_gn input_img_factor
This commit is contained in:
parent
4b4d08bdec
commit
0a9b85f239
|
@ -108,7 +108,7 @@ class Discriminator_VGG_128_GN(nn.Module):
|
||||||
self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
|
self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||||
final_nf = nf * 8
|
final_nf = nf * 8
|
||||||
|
|
||||||
self.linear1 = nn.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
|
self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100)
|
||||||
self.linear2 = nn.Linear(100, 1)
|
self.linear2 = nn.Linear(100, 1)
|
||||||
|
|
||||||
# activation function
|
# activation function
|
||||||
|
|
|
@ -160,9 +160,9 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
|
||||||
which_model = opt_net['which_model_D']
|
which_model = opt_net['which_model_D']
|
||||||
|
|
||||||
if which_model == 'discriminator_vgg_128':
|
if which_model == 'discriminator_vgg_128':
|
||||||
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz // 128, extra_conv=opt_net['extra_conv'])
|
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, extra_conv=opt_net['extra_conv'])
|
||||||
elif which_model == 'discriminator_vgg_128_gn':
|
elif which_model == 'discriminator_vgg_128_gn':
|
||||||
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz // 128)
|
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
||||||
if wrap:
|
if wrap:
|
||||||
netD = GradDiscWrapper(netD)
|
netD = GradDiscWrapper(netD)
|
||||||
elif which_model == 'discriminator_resnet':
|
elif which_model == 'discriminator_resnet':
|
||||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/srgan_compute_feature.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_spsr_switched2_psnr.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user