diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index 6dae09e2..7c17f17a 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -4,10 +4,9 @@ import torch import torch.nn as nn import torch.nn.functional as F import torchvision -from torch.utils.checkpoint import checkpoint_sequential from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu -from utils.util import checkpoint +from utils.util import checkpoint, sequential_checkpoint class ResidualDenseBlock(nn.Module): @@ -251,7 +250,7 @@ class RRDBNet(nn.Module): else: x_lg = x feat = self.conv_first(x_lg) - feat = checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat) + feat = sequential_checkpoint(self.body, self.num_blocks // self.blocks_per_checkpoint, feat) feat = feat[:, :self.reduce_ch] body_feat = self.conv_body(feat) feat = feat + body_feat @@ -353,7 +352,7 @@ class RRDBDiscriminator(nn.Module): def forward(self, x): feat = self.conv_first(x) - feat = checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat) + feat = sequential_checkpoint(self.body, self.num_blocks // self.blocks_per_checkpoint, feat) pred = checkpoint(self.tail, feat) self.pred_ = pred.detach().clone() return pred diff --git a/codes/models/archs/srflow_orig/RRDBNet_arch.py b/codes/models/archs/srflow_orig/RRDBNet_arch.py index 7028e163..1fff45c1 100644 --- a/codes/models/archs/srflow_orig/RRDBNet_arch.py +++ b/codes/models/archs/srflow_orig/RRDBNet_arch.py @@ -119,7 +119,7 @@ class RRDBWithBypass(nn.Module): class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None): + def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, initial_conv_stride=1, opt=None): self.opt = opt super(RRDBNet, self).__init__() @@ -130,7 +130,10 @@ class RRDBNet(nn.Module): RRDB_block_f = functools.partial(RRDB, mid_channels=nf, growth_channels=gc) self.scale = scale - self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + if initial_conv_stride == 1: + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + else: + self.conv_first = nn.Conv2d(in_nc, nf, 7, stride=initial_conv_stride, padding=3, bias=True) self.body = mutil.make_layer(RRDB_block_f, nb) self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) #### upsampling @@ -204,26 +207,6 @@ class RRDBNet(nn.Module): fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False if fea_upn1_en: results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True) - elif self.scale == 2: - # "Pretend" this is is 4x by shuffling around the inputs a bit. - half = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True) - quarter = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True) - eighth = F.interpolate(last_lr_fea, scale_factor=1/8, mode='bilinear', align_corners=False, recompute_scale_factor=True) - results = {'last_lr_fea': half, - 'fea_up1': half, - 'fea_up2': last_lr_fea, - 'fea_up4': fea_up2, - 'fea_up8': fea_up4, - 'fea_up16': fea_up8, - 'fea_up32': fea_up16, - 'out': out} - - fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False - if fea_up0_en: - results['fea_up0'] = quarter - fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False - if fea_upn1_en: - results['fea_up-1'] = eighth else: raise NotImplementedError diff --git a/codes/models/networks.py b/codes/models/networks.py index 4c4b9dfe..9ce6d909 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -171,7 +171,8 @@ def define_G(opt, opt_net, scale=None): elif which_model == 'rrdb_srflow': from models.archs.srflow_orig.RRDBNet_arch import RRDBNet netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], - nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale']) + nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'], + initial_conv_stride=opt_net['initial_stride']) else: raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) return netG diff --git a/codes/test.py b/codes/test.py index f625a7fa..71f58c85 100644 --- a/codes/test.py +++ b/codes/test.py @@ -21,7 +21,7 @@ import models.networks as networks def forward_pass(model, output_dir, alteration_suffix=''): - model.feed_data(data, need_GT=need_GT) + model.feed_data(data, 0, need_GT=need_GT) model.test() visuals = model.get_current_visuals(need_GT)['rlt'].cpu() @@ -53,7 +53,7 @@ def forward_pass(model, output_dir, alteration_suffix=''): if __name__ == "__main__": #### options torch.backends.cudnn.benchmark = True - srg_analyze = False + want_metrics = False parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_4x_psnr.yml') opt = option.parse(parser.parse_args().opt, is_train=False) @@ -95,6 +95,7 @@ if __name__ == "__main__": tq = tqdm(test_loader) for data in tq: need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True + need_GT = need_GT and want_metrics fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt['name']) fea_loss += fea_loss diff --git a/codes/train.py b/codes/train.py index 71af8037..118155d2 100644 --- a/codes/train.py +++ b/codes/train.py @@ -291,7 +291,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb4x_6bl_multires.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb_2stride.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/train2.py b/codes/train2.py index 95e8c32d..10a79d7e 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -291,7 +291,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_frompsnr.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_exd_imgsetext_rrdb4x_6bl_2stride/train_exd_imgsetext_rrdb4x_6bl_2stride.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()