Expose srflow rrdb

This commit is contained in:
James Betker 2020-11-24 13:20:20 -07:00
parent f3c1fc1bcd
commit cb045121b3
2 changed files with 5 additions and 1 deletions

View File

@ -159,6 +159,10 @@ def define_G(opt, opt_net, scale=None):
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], scale=opt_net['scale'], mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], scale=opt_net['scale'],
headless=True, output_mode=output_mode) headless=True, output_mode=output_mode)
elif which_model == 'rrdb_srflow':
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'])
else: else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG return netG

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_imgsetext_rrdb8x.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()