Misc RRDB changes

This commit is contained in:
James Betker 2020-11-29 12:21:31 -07:00
parent f2422f1d75
commit da604752e6
6 changed files with 15 additions and 31 deletions

View File

@ -4,10 +4,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.checkpoint import checkpoint_sequential
from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
from utils.util import checkpoint
from utils.util import checkpoint, sequential_checkpoint
class ResidualDenseBlock(nn.Module):
@ -251,7 +250,7 @@ class RRDBNet(nn.Module):
else:
x_lg = x
feat = self.conv_first(x_lg)
feat = checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
feat = sequential_checkpoint(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
feat = feat[:, :self.reduce_ch]
body_feat = self.conv_body(feat)
feat = feat + body_feat
@ -353,7 +352,7 @@ class RRDBDiscriminator(nn.Module):
def forward(self, x):
feat = self.conv_first(x)
feat = checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
feat = sequential_checkpoint(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
pred = checkpoint(self.tail, feat)
self.pred_ = pred.detach().clone()
return pred

View File

@ -119,7 +119,7 @@ class RRDBWithBypass(nn.Module):
class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, initial_conv_stride=1, opt=None):
self.opt = opt
super(RRDBNet, self).__init__()
@ -130,7 +130,10 @@ class RRDBNet(nn.Module):
RRDB_block_f = functools.partial(RRDB, mid_channels=nf, growth_channels=gc)
self.scale = scale
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
if initial_conv_stride == 1:
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
else:
self.conv_first = nn.Conv2d(in_nc, nf, 7, stride=initial_conv_stride, padding=3, bias=True)
self.body = mutil.make_layer(RRDB_block_f, nb)
self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling
@ -204,26 +207,6 @@ class RRDBNet(nn.Module):
fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
if fea_upn1_en:
results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
elif self.scale == 2:
# "Pretend" this is is 4x by shuffling around the inputs a bit.
half = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
quarter = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
eighth = F.interpolate(last_lr_fea, scale_factor=1/8, mode='bilinear', align_corners=False, recompute_scale_factor=True)
results = {'last_lr_fea': half,
'fea_up1': half,
'fea_up2': last_lr_fea,
'fea_up4': fea_up2,
'fea_up8': fea_up4,
'fea_up16': fea_up8,
'fea_up32': fea_up16,
'out': out}
fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False
if fea_up0_en:
results['fea_up0'] = quarter
fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
if fea_upn1_en:
results['fea_up-1'] = eighth
else:
raise NotImplementedError

View File

@ -171,7 +171,8 @@ def define_G(opt, opt_net, scale=None):
elif which_model == 'rrdb_srflow':
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'])
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
initial_conv_stride=opt_net['initial_stride'])
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG

View File

@ -21,7 +21,7 @@ import models.networks as networks
def forward_pass(model, output_dir, alteration_suffix=''):
model.feed_data(data, need_GT=need_GT)
model.feed_data(data, 0, need_GT=need_GT)
model.test()
visuals = model.get_current_visuals(need_GT)['rlt'].cpu()
@ -53,7 +53,7 @@ def forward_pass(model, output_dir, alteration_suffix=''):
if __name__ == "__main__":
#### options
torch.backends.cudnn.benchmark = True
srg_analyze = False
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_4x_psnr.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
@ -95,6 +95,7 @@ if __name__ == "__main__":
tq = tqdm(test_loader)
for data in tq:
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
need_GT = need_GT and want_metrics
fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt['name'])
fea_loss += fea_loss

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb4x_6bl_multires.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb_2stride.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_frompsnr.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_exd_imgsetext_rrdb4x_6bl_2stride/train_exd_imgsetext_rrdb4x_6bl_2stride.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()