Expose srflow rrdb
This commit is contained in:
parent
f3c1fc1bcd
commit
cb045121b3
|
@ -159,6 +159,10 @@ def define_G(opt, opt_net, scale=None):
|
||||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], scale=opt_net['scale'],
|
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], scale=opt_net['scale'],
|
||||||
headless=True, output_mode=output_mode)
|
headless=True, output_mode=output_mode)
|
||||||
|
elif which_model == 'rrdb_srflow':
|
||||||
|
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
|
||||||
|
netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||||
|
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||||
return netG
|
return netG
|
||||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_imgsetext_rrdb8x.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user