Fix for RRDB scale
This commit is contained in:
parent
71fa532356
commit
929cd45c05
|
@ -49,7 +49,7 @@ def define_G(opt, opt_net, scale=None):
|
||||||
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
||||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
|
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
|
||||||
output_mode=output_mode, body_block=block)
|
output_mode=output_mode, body_block=block, scale=opt_net['scale'])
|
||||||
elif which_model == 'rcan':
|
elif which_model == 'rcan':
|
||||||
#args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
|
#args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
|
||||||
opt_net['rgb_range'] = 255
|
opt_net['rgb_range'] = 255
|
||||||
|
@ -147,7 +147,7 @@ def define_G(opt, opt_net, scale=None):
|
||||||
hr_img_shape=opt_net['hr_shape'], scale=opt_net['scale'])
|
hr_img_shape=opt_net['hr_shape'], scale=opt_net['scale'])
|
||||||
elif which_model == 'srflow_orig':
|
elif which_model == 'srflow_orig':
|
||||||
from models.archs.srflow_orig import SRFlowNet_arch
|
from models.archs.srflow_orig import SRFlowNet_arch
|
||||||
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'],
|
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
|
||||||
K=opt_net['K'], opt=opt)
|
K=opt_net['K'], opt=opt)
|
||||||
elif which_model == 'rrdb_latent_wrapper':
|
elif which_model == 'rrdb_latent_wrapper':
|
||||||
from models.archs.srflow_orig.RRDBNet_arch import RRDBLatentWrapper
|
from models.archs.srflow_orig.RRDBNet_arch import RRDBLatentWrapper
|
||||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_lambda.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_frompsnr.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user