Misc RRDB changes
This commit is contained in:
parent
f2422f1d75
commit
da604752e6
|
@ -4,10 +4,9 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torch.utils.checkpoint import checkpoint_sequential
|
||||
|
||||
from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
|
||||
from utils.util import checkpoint
|
||||
from utils.util import checkpoint, sequential_checkpoint
|
||||
|
||||
|
||||
class ResidualDenseBlock(nn.Module):
|
||||
|
@ -251,7 +250,7 @@ class RRDBNet(nn.Module):
|
|||
else:
|
||||
x_lg = x
|
||||
feat = self.conv_first(x_lg)
|
||||
feat = checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
|
||||
feat = sequential_checkpoint(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
|
||||
feat = feat[:, :self.reduce_ch]
|
||||
body_feat = self.conv_body(feat)
|
||||
feat = feat + body_feat
|
||||
|
@ -353,7 +352,7 @@ class RRDBDiscriminator(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
feat = self.conv_first(x)
|
||||
feat = checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
|
||||
feat = sequential_checkpoint(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
|
||||
pred = checkpoint(self.tail, feat)
|
||||
self.pred_ = pred.detach().clone()
|
||||
return pred
|
||||
|
|
|
@ -119,7 +119,7 @@ class RRDBWithBypass(nn.Module):
|
|||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, initial_conv_stride=1, opt=None):
|
||||
self.opt = opt
|
||||
super(RRDBNet, self).__init__()
|
||||
|
||||
|
@ -130,7 +130,10 @@ class RRDBNet(nn.Module):
|
|||
RRDB_block_f = functools.partial(RRDB, mid_channels=nf, growth_channels=gc)
|
||||
self.scale = scale
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
if initial_conv_stride == 1:
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
else:
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 7, stride=initial_conv_stride, padding=3, bias=True)
|
||||
self.body = mutil.make_layer(RRDB_block_f, nb)
|
||||
self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
|
@ -204,26 +207,6 @@ class RRDBNet(nn.Module):
|
|||
fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
|
||||
if fea_upn1_en:
|
||||
results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||
elif self.scale == 2:
|
||||
# "Pretend" this is is 4x by shuffling around the inputs a bit.
|
||||
half = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||
quarter = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||
eighth = F.interpolate(last_lr_fea, scale_factor=1/8, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||
results = {'last_lr_fea': half,
|
||||
'fea_up1': half,
|
||||
'fea_up2': last_lr_fea,
|
||||
'fea_up4': fea_up2,
|
||||
'fea_up8': fea_up4,
|
||||
'fea_up16': fea_up8,
|
||||
'fea_up32': fea_up16,
|
||||
'out': out}
|
||||
|
||||
fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False
|
||||
if fea_up0_en:
|
||||
results['fea_up0'] = quarter
|
||||
fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
|
||||
if fea_upn1_en:
|
||||
results['fea_up-1'] = eighth
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -171,7 +171,8 @@ def define_G(opt, opt_net, scale=None):
|
|||
elif which_model == 'rrdb_srflow':
|
||||
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
|
||||
netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'])
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
|
||||
initial_conv_stride=opt_net['initial_stride'])
|
||||
else:
|
||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||
return netG
|
||||
|
|
|
@ -21,7 +21,7 @@ import models.networks as networks
|
|||
|
||||
|
||||
def forward_pass(model, output_dir, alteration_suffix=''):
|
||||
model.feed_data(data, need_GT=need_GT)
|
||||
model.feed_data(data, 0, need_GT=need_GT)
|
||||
model.test()
|
||||
|
||||
visuals = model.get_current_visuals(need_GT)['rlt'].cpu()
|
||||
|
@ -53,7 +53,7 @@ def forward_pass(model, output_dir, alteration_suffix=''):
|
|||
if __name__ == "__main__":
|
||||
#### options
|
||||
torch.backends.cudnn.benchmark = True
|
||||
srg_analyze = False
|
||||
want_metrics = False
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_4x_psnr.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
|
@ -95,6 +95,7 @@ if __name__ == "__main__":
|
|||
tq = tqdm(test_loader)
|
||||
for data in tq:
|
||||
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
|
||||
need_GT = need_GT and want_metrics
|
||||
|
||||
fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt['name'])
|
||||
fea_loss += fea_loss
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb4x_6bl_multires.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb_2stride.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_frompsnr.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_exd_imgsetext_rrdb4x_6bl_2stride/train_exd_imgsetext_rrdb4x_6bl_2stride.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user