diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index 1a0a52e6..6dae09e2 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -72,7 +72,7 @@ class RRDB(nn.Module): else: self.reducer = None - def forward(self, x): + def forward(self, x, return_residual=False): """Forward function. Args: @@ -88,8 +88,12 @@ class RRDB(nn.Module): out = self.reducer(out) b, f, h, w = out.shape out = torch.cat([out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1) - # Emperically, we use 0.2 to scale the residual for better performance - return out * 0.2 + x + + if return_residual: + return 0.2 * out + else: + # Empirically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x class RRDBWithBypass(nn.Module): @@ -173,6 +177,7 @@ class RRDBNet(nn.Module): headless=False, feature_channels=64, # Only applicable when headless=True. How many channels are used at the trunk level. output_mode="hq_only", # Options: "hq_only", "hq+features", "features_only" + initial_stride=1, ): super(RRDBNet, self).__init__() assert output_mode in ['hq_only', 'hq+features', 'features_only'] @@ -182,7 +187,7 @@ class RRDBNet(nn.Module): self.scale = scale self.in_channels = in_channels self.output_mode = output_mode - first_conv_stride = 1 if in_channels <= 4 else scale + first_conv_stride = initial_stride if in_channels <= 4 else scale first_conv_ksize = 3 if first_conv_stride == 1 else 7 first_conv_padding = 1 if first_conv_stride == 1 else 3 if headless: diff --git a/codes/models/archs/multi_res_rrdb.py b/codes/models/archs/multi_res_rrdb.py new file mode 100644 index 00000000..d95a678e --- /dev/null +++ b/codes/models/archs/multi_res_rrdb.py @@ -0,0 +1,83 @@ +import torch.nn as nn +import torch.nn.functional as F + +from models.archs.RRDBNet_arch import RRDB +from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu +from utils.util import checkpoint + + +class MultiLevelRRDB(nn.Module): + def __init__(self, nf, gc, levels): + super().__init__() + self.levels = levels + self.level_rrdbs = nn.ModuleList([RRDB(nf, growth_channels=gc) for i in range(levels)]) + + # Trunks should be fed in in order HR->LR + def forward(self, trunk): + for i in reversed(range(self.levels)): + lvl_scale = (2**i) + lvl_res = self.level_rrdbs[i](F.interpolate(trunk, scale_factor=1/lvl_scale, mode="area"), return_residual=True) + trunk = trunk + F.interpolate(lvl_res, scale_factor=lvl_scale, mode="nearest") + return trunk + + +class MultiResRRDBNet(nn.Module): + def __init__(self, + in_channels, + out_channels, + mid_channels=64, + l1_blocks=3, + l2_blocks=4, + l3_blocks=6, + growth_channels=32, + scale=4, + ): + super().__init__() + self.scale = scale + self.in_channels = in_channels + + self.conv_first = nn.Conv2d(in_channels, mid_channels, 7, stride=1, padding=3) + + self.l3_blocks = nn.ModuleList([MultiLevelRRDB(mid_channels, growth_channels, 3) for _ in range(l1_blocks)]) + self.l2_blocks = nn.ModuleList([MultiLevelRRDB(mid_channels, growth_channels, 2) for _ in range(l2_blocks)]) + self.l1_blocks = nn.ModuleList([MultiLevelRRDB(mid_channels, growth_channels, 1) for _ in range(l3_blocks)]) + self.block_levels = [self.l3_blocks, self.l2_blocks, self.l1_blocks] + + self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) + self.conv_up2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) + self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) + self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + for m in [ + self.conv_first, self.conv_first, self.conv_body, self.conv_up1, + self.conv_up2, self.conv_hr, self.conv_last + ]: + if m is not None: + default_init_weights(m, 0.1) + + def forward(self, x): + trunk = self.conv_first(x) + for block_set in self.block_levels: + for block in block_set: + trunk = checkpoint(block, trunk) + + body_feat = self.conv_body(trunk) + feat = trunk + body_feat + + # upsample + out = self.lrelu( + self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + if self.scale == 4: + out = self.lrelu( + self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest'))) + else: + out = self.lrelu(self.conv_up2(out)) + out = self.conv_last(self.lrelu(self.conv_hr(out))) + + return out + + def visual_dbg(self, step, path): + pass diff --git a/codes/models/networks.py b/codes/models/networks.py index 638bcbd1..4c4b9dfe 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -47,9 +47,18 @@ def define_G(opt, opt_net, scale=None): block = RRDBNet_arch.RRDB additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not' output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only' + gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32 + initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1 netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode, - output_mode=output_mode, body_block=block, scale=opt_net['scale']) + output_mode=output_mode, body_block=block, scale=opt_net['scale'], growth_channels=gc, + initial_stride=initial_stride) + elif which_model == "multires_rrdb": + from models.archs.multi_res_rrdb import MultiResRRDBNet + netG = MultiResRRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], + mid_channels=opt_net['nf'], l1_blocks=opt_net['l1'], + l2_blocks=opt_net['l2'], l3_blocks=opt_net['l3'], + growth_channels=opt_net['gc'], scale=opt_net['scale']) elif which_model == 'rcan': #args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats opt_net['rgb_range'] = 255 diff --git a/codes/train.py b/codes/train.py index 2f18ef64..71af8037 100644 --- a/codes/train.py +++ b/codes/train.py @@ -291,7 +291,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_rrdbdisc.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb4x_6bl_multires.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()