forked from mrq/DL-Art-School
Fix vgg_gn input_img_factor
This commit is contained in:
parent
4b4d08bdec
commit
0a9b85f239
|
@ -108,7 +108,7 @@ class Discriminator_VGG_128_GN(nn.Module):
|
|||
self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||
final_nf = nf * 8
|
||||
|
||||
self.linear1 = nn.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
|
||||
self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100)
|
||||
self.linear2 = nn.Linear(100, 1)
|
||||
|
||||
# activation function
|
||||
|
|
|
@ -160,9 +160,9 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
|
|||
which_model = opt_net['which_model_D']
|
||||
|
||||
if which_model == 'discriminator_vgg_128':
|
||||
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz // 128, extra_conv=opt_net['extra_conv'])
|
||||
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, extra_conv=opt_net['extra_conv'])
|
||||
elif which_model == 'discriminator_vgg_128_gn':
|
||||
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz // 128)
|
||||
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
||||
if wrap:
|
||||
netD = GradDiscWrapper(netD)
|
||||
elif which_model == 'discriminator_resnet':
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/srgan_compute_feature.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_spsr_switched2_psnr.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user