Fix for RRDB scale

This commit is contained in:
James Betker 2020-11-27 21:37:10 -07:00
parent 71fa532356
commit 929cd45c05
2 changed files with 3 additions and 3 deletions

View File

@ -49,7 +49,7 @@ def define_G(opt, opt_net, scale=None):
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only' output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode, mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
output_mode=output_mode, body_block=block) output_mode=output_mode, body_block=block, scale=opt_net['scale'])
elif which_model == 'rcan': elif which_model == 'rcan':
#args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats #args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
opt_net['rgb_range'] = 255 opt_net['rgb_range'] = 255
@ -147,7 +147,7 @@ def define_G(opt, opt_net, scale=None):
hr_img_shape=opt_net['hr_shape'], scale=opt_net['scale']) hr_img_shape=opt_net['hr_shape'], scale=opt_net['scale'])
elif which_model == 'srflow_orig': elif which_model == 'srflow_orig':
from models.archs.srflow_orig import SRFlowNet_arch from models.archs.srflow_orig import SRFlowNet_arch
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'], netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
K=opt_net['K'], opt=opt) K=opt_net['K'], opt=opt)
elif which_model == 'rrdb_latent_wrapper': elif which_model == 'rrdb_latent_wrapper':
from models.archs.srflow_orig.RRDBNet_arch import RRDBLatentWrapper from models.archs.srflow_orig.RRDBNet_arch import RRDBLatentWrapper

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_lambda.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_frompsnr.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()