Misc RRDB changes
commit da604752e6
parent f2422f1d75
@@ -4,10 +4,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
-from torch.utils.checkpoint import checkpoint_sequential
 
 from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
-from utils.util import checkpoint
+from utils.util import checkpoint, sequential_checkpoint
 
 
 class ResidualDenseBlock(nn.Module):
@@ -251,7 +250,7 @@ class RRDBNet(nn.Module):
         else:
             x_lg = x
         feat = self.conv_first(x_lg)
-        feat = checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
+        feat = sequential_checkpoint(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
         feat = feat[:, :self.reduce_ch]
         body_feat = self.conv_body(feat)
         feat = feat + body_feat
@@ -353,7 +352,7 @@ class RRDBDiscriminator(nn.Module):
 
     def forward(self, x):
         feat = self.conv_first(x)
-        feat = checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
+        feat = sequential_checkpoint(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
         pred = checkpoint(self.tail, feat)
         self.pred_ = pred.detach().clone()
         return pred
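Both forward passes now route through a project-local sequential_checkpoint helper imported from utils.util instead of torch's checkpoint_sequential. The helper itself is not part of this diff; as a rough sketch, assuming it simply splits an nn.Sequential into roughly equal segments and checkpoints each one with the same (module, segments, input) calling convention as the old torch helper, it might look like this:

    # Hypothetical sketch only -- the real utils.util.sequential_checkpoint is not shown in this commit.
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    def sequential_checkpoint(module, partitions, x):
        # Split the children of an nn.Sequential into `partitions` chunks and run
        # each chunk under gradient checkpointing to trade recomputation for memory.
        children = list(module.children())
        chunk = max(len(children) // partitions, 1)
        for i in range(0, len(children), chunk):
            segment = nn.Sequential(*children[i:i + chunk])
            x = checkpoint(segment, x)
        return x

It is invoked exactly as before, e.g. sequential_checkpoint(self.body, self.num_blocks // self.blocks_per_checkpoint, feat).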
@@ -119,7 +119,7 @@ class RRDBWithBypass(nn.Module):
 
 
 class RRDBNet(nn.Module):
-    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
+    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, initial_conv_stride=1, opt=None):
         self.opt = opt
         super(RRDBNet, self).__init__()
 
@@ -130,7 +130,10 @@ class RRDBNet(nn.Module):
         RRDB_block_f = functools.partial(RRDB, mid_channels=nf, growth_channels=gc)
         self.scale = scale
 
-        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
+        if initial_conv_stride == 1:
+            self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
+        else:
+            self.conv_first = nn.Conv2d(in_nc, nf, 7, stride=initial_conv_stride, padding=3, bias=True)
         self.body = mutil.make_layer(RRDB_block_f, nb)
         self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
         #### upsampling
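For any stride other than 1, the first convolution switches to a 7x7 kernel with padding 3, so the feature map is downsampled immediately at the input. A quick shape check, purely illustrative (the channel counts below are made up, not taken from this commit):

    import torch
    import torch.nn as nn

    # Kernel 7, stride 2, padding 3: output size is floor((H + 6 - 7) / 2) + 1,
    # so an even-sized input is halved in each spatial dimension.
    conv_first = nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=True)
    print(conv_first(torch.zeros(1, 3, 64, 64)).shape)  # torch.Size([1, 64, 32, 32])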
@@ -204,26 +207,6 @@ class RRDBNet(nn.Module):
             fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
             if fea_upn1_en:
                 results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
-        elif self.scale == 2:
-            # "Pretend" this is is 4x by shuffling around the inputs a bit.
-            half = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
-            quarter = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
-            eighth = F.interpolate(last_lr_fea, scale_factor=1/8, mode='bilinear', align_corners=False, recompute_scale_factor=True)
-            results = {'last_lr_fea': half,
-                       'fea_up1': half,
-                       'fea_up2': last_lr_fea,
-                       'fea_up4': fea_up2,
-                       'fea_up8': fea_up4,
-                       'fea_up16': fea_up8,
-                       'fea_up32': fea_up16,
-                       'out': out}
-
-            fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False
-            if fea_up0_en:
-                results['fea_up0'] = quarter
-            fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
-            if fea_upn1_en:
-                results['fea_up-1'] = eighth
         else:
             raise NotImplementedError
 
@@ -171,7 +171,8 @@ def define_G(opt, opt_net, scale=None):
     elif which_model == 'rrdb_srflow':
         from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
         netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
-                       nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'])
+                       nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
+                       initial_conv_stride=opt_net['initial_stride'])
     else:
         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
     return netG
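Since define_G now reads opt_net['initial_stride'] unconditionally for the rrdb_srflow generator, option files for that model presumably need an initial_stride entry. Constructing the network directly with the new argument would look roughly like this (the concrete values are placeholders for illustration, not taken from this commit):

    from models.archs.srflow_orig.RRDBNet_arch import RRDBNet

    # nf/nb/scale values are placeholders; initial_conv_stride=2 downsamples at the first conv.
    netG = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, scale=4, initial_conv_stride=2)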
@@ -21,7 +21,7 @@ import models.networks as networks
 
 
 def forward_pass(model, output_dir, alteration_suffix=''):
-    model.feed_data(data, need_GT=need_GT)
+    model.feed_data(data, 0, need_GT=need_GT)
     model.test()
 
     visuals = model.get_current_visuals(need_GT)['rlt'].cpu()
@@ -53,7 +53,7 @@ def forward_pass(model, output_dir, alteration_suffix=''):
 if __name__ == "__main__":
     #### options
     torch.backends.cudnn.benchmark = True
-    srg_analyze = False
+    want_metrics = False
     parser = argparse.ArgumentParser()
     parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_4x_psnr.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
@@ -95,6 +95,7 @@ if __name__ == "__main__":
     tq = tqdm(test_loader)
     for data in tq:
         need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
+        need_GT = need_GT and want_metrics
 
         fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt['name'])
         fea_loss += fea_loss
@@ -291,7 +291,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb4x_6bl_multires.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb_2stride.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -291,7 +291,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_frompsnr.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_exd_imgsetext_rrdb4x_6bl_2stride/train_exd_imgsetext_rrdb4x_6bl_2stride.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()