From 9cde58be808216e538af42daaa441f6e90670fb8 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 16 May 2020 18:36:30 -0600
Subject: [PATCH] Make RRDB usable in the current iteration

---
 .../archs/DiscriminatorResnet_arch_passthrough.py |  8 ++++++--
 codes/models/archs/RRDBNet_arch.py                | 15 +++++++++------
 codes/models/networks.py                          |  3 +--
 codes/train.py                                    |  4 ++--
 4 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/codes/models/archs/DiscriminatorResnet_arch_passthrough.py b/codes/models/archs/DiscriminatorResnet_arch_passthrough.py
index 521f9c8b..462261f7 100644
--- a/codes/models/archs/DiscriminatorResnet_arch_passthrough.py
+++ b/codes/models/archs/DiscriminatorResnet_arch_passthrough.py
@@ -156,8 +156,12 @@ class FixupResNet(nn.Module):
         return nn.Sequential(*layers)
 
     def forward(self, x):
-        # This class expects a medium skip (half-res) and low skip (quarter-res) provided as a tuple in the input.
-        x, med_skip, lo_skip = x
+        if len(x) == 3:
+            # This class can take a medium skip (half-res) and a low skip (quarter-res) provided as a tuple in the input.
+            x, med_skip, lo_skip = x
+        else:
+            # Or just a tuple containing only the high-res input (this assumes number_skips was set correctly).
+            x = x[0]
 
         x = self.layer0(x)
         if self.number_skips > 0:

diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py
index 2ec07ff1..d11170d9 100644
--- a/codes/models/archs/RRDBNet_arch.py
+++ b/codes/models/archs/RRDBNet_arch.py
@@ -46,10 +46,11 @@ class RRDB(nn.Module):
 
 
 class RRDBNet(nn.Module):
-    def __init__(self, in_nc, out_nc, nf, nb, gc=32, interpolation_scale_factor=2):
+    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2):
         super(RRDBNet, self).__init__()
         RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+        self.scale = scale
 
         self.conv_first = nn.Conv2d(in_nc, nf, 7, 1, padding=3, bias=True)
         self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb)
         self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
@@ -61,15 +62,17 @@ class RRDBNet(nn.Module):
 
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
 
-        self.interpolation_scale_factor = interpolation_scale_factor
-
     def forward(self, x):
         fea = self.conv_first(x)
         trunk = self.trunk_conv(self.RRDB_trunk(fea))
         fea = fea + trunk
 
-        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=self.interpolation_scale_factor, mode='nearest')))
-        fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=self.interpolation_scale_factor, mode='nearest')))
+        if self.scale >= 2:
+            fea = F.interpolate(fea, scale_factor=2, mode='nearest')
+            fea = self.lrelu(self.upconv1(fea))
+        if self.scale >= 4:
+            fea = F.interpolate(fea, scale_factor=2, mode='nearest')
+            fea = self.lrelu(self.upconv2(fea))
         out = self.conv_last(self.lrelu(self.HRconv(fea)))
 
-        return out
+        return (out,)
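For context, a minimal sketch of how the reworked RRDBNet is used after this patch: `scale` gates how many nearest-neighbor 2x stages run, and `forward()` now returns a 1-tuple. The `nf`/`nb` values below are illustrative rather than mandated by the patch, and the import assumes `codes/` is on `PYTHONPATH`.

```python
import torch
from models.archs.RRDBNet_arch import RRDBNet

# 4x generator; nf=64 and nb=23 are common ESRGAN-style settings, assumed here.
net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, scale=4)
lr = torch.randn(1, 3, 32, 32)  # fake low-res batch

out, = net(lr)  # forward() now returns a 1-tuple, hence the unpacking comma
# scale=4 runs both 2x upsample stages: 32 -> 64 -> 128.
assert out.shape == (1, 3, 128, 128)
# scale=2 would run only the first stage, yielding 64x64 output.
```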
diff --git a/codes/models/networks.py b/codes/models/networks.py
index efe9ed1c..446322bc 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -25,9 +25,8 @@ def define_G(opt, net_key='network_G'):
                                       nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
     elif which_model == 'RRDBNet':
         # RRDB does scaling in two steps, so take the sqrt of the scale we actually want to achieve and feed it to RRDB.
-        scale_per_step = math.sqrt(scale)
         netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
-                                    nf=opt_net['nf'], nb=opt_net['nb'], interpolation_scale_factor=scale_per_step)
+                                    nf=opt_net['nf'], nb=opt_net['nb'], scale=scale)
     elif which_model == 'RRDBNetXL':
         scale_per_step = math.sqrt(scale)
         netG = RRDBNetXL_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],

diff --git a/codes/train.py b/codes/train.py
index f81a3484..9a19a558 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vix_resgenv2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vix_rrdb_v2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -147,7 +147,7 @@ def main():
         current_step = resume_state['iter']
         model.resume_training(resume_state)  # handle optimizers and schedulers
     else:
-        current_step = 0
+        current_step = -1
         start_epoch = 0
 
     #### training
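The `current_step = -1` change reads as an off-by-one fix: assuming the training loop increments the counter at the top of each iteration before using it (the loop itself is not shown in this patch), a fresh run previously executed its first iteration as step 1 while a resumed run continued from `resume_state['iter']`. Starting at -1 makes the first fresh iteration run as step 0. A runnable sketch of that assumed loop shape:

```python
# Illustrative stand-ins; the real loop lives in codes/train.py.
train_loader = range(4)  # placeholder for the real DataLoader
total_iters = 3

current_step = -1  # fresh run starts one below zero, per this patch
for train_data in train_loader:
    current_step += 1  # increment-before-use: first iteration is step 0
    if current_step > total_iters:
        break
    print(f"training step {current_step}")  # prints steps 0..3
```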