Fix vgg_gn input_img_factor

This commit is contained in:
James Betker 2020-08-31 09:50:30 -06:00
parent 4b4d08bdec
commit 0a9b85f239
3 changed files with 4 additions and 4 deletions

View File

@ -108,7 +108,7 @@ class Discriminator_VGG_128_GN(nn.Module):
self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True) self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
final_nf = nf * 8 final_nf = nf * 8
self.linear1 = nn.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100) self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100)
self.linear2 = nn.Linear(100, 1) self.linear2 = nn.Linear(100, 1)
# activation function # activation function

View File

@ -160,9 +160,9 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
which_model = opt_net['which_model_D'] which_model = opt_net['which_model_D']
if which_model == 'discriminator_vgg_128': if which_model == 'discriminator_vgg_128':
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz // 128, extra_conv=opt_net['extra_conv']) netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, extra_conv=opt_net['extra_conv'])
elif which_model == 'discriminator_vgg_128_gn': elif which_model == 'discriminator_vgg_128_gn':
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz // 128) netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
if wrap: if wrap:
netD = GradDiscWrapper(netD) netD = GradDiscWrapper(netD)
elif which_model == 'discriminator_resnet': elif which_model == 'discriminator_resnet':

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main(): def main():
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/srgan_compute_feature.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_spsr_switched2_psnr.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher') help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)