diff --git a/codes/models/networks.py b/codes/models/networks.py index d7f0f1f5..638bcbd1 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -49,7 +49,7 @@ def define_G(opt, opt_net, scale=None): output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only' netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode, - output_mode=output_mode, body_block=block) + output_mode=output_mode, body_block=block, scale=opt_net['scale']) elif which_model == 'rcan': #args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats opt_net['rgb_range'] = 255 @@ -147,7 +147,7 @@ def define_G(opt, opt_net, scale=None): hr_img_shape=opt_net['hr_shape'], scale=opt_net['scale']) elif which_model == 'srflow_orig': from models.archs.srflow_orig import SRFlowNet_arch - netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'], + netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'], K=opt_net['K'], opt=opt) elif which_model == 'rrdb_latent_wrapper': from models.archs.srflow_orig.RRDBNet_arch import RRDBLatentWrapper diff --git a/codes/train2.py b/codes/train2.py index ff16bdc4..95e8c32d 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -291,7 +291,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_lambda.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_frompsnr.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()