Play with lambdas

This commit is contained in:
James Betker 2020-11-26 20:30:55 -07:00
parent 0c6d7971b9
commit fd356580c0
7 changed files with 61 additions and 76 deletions

View File

@ -19,7 +19,7 @@ class ResidualDenseBlock(nn.Module):
growth_channels (int): Channels for each growth.
"""
def __init__(self, mid_channels=64, growth_channels=32):
def __init__(self, mid_channels=64, growth_channels=32, init_weight=.1):
super(ResidualDenseBlock, self).__init__()
for i in range(5):
out_channels = mid_channels if i == 4 else growth_channels
@ -29,7 +29,7 @@ class ResidualDenseBlock(nn.Module):
1, 1))
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
for i in range(5):
default_init_weights(getattr(self, f'conv{i+1}'), 0.1)
default_init_weights(getattr(self, f'conv{i+1}'), init_weight)
def forward(self, x):

View File

@ -0,0 +1,42 @@
import torch
from torch import nn
from lambda_networks import LambdaLayer
from torch.nn import GroupNorm
from models.archs.RRDBNet_arch import ResidualDenseBlock
class LambdaRRDB(nn.Module):
    """Residual block built from two dense blocks followed by a lambda layer.

    The input runs through two ``ResidualDenseBlock`` modules, a
    ``LambdaLayer`` and a ``GroupNorm``. The normalised result is multiplied
    by a learnable scalar and added back onto the input.

    Args:
        mid_channels (int): Channel number of intermediate features.
        growth_channels (int): Channels for each growth.
        reduce_to (int | None): ``dim_out`` handed to the lambda layer.
            Defaults to ``mid_channels``. ``forward`` adds the block input
            to the block output, so a value other than ``mid_channels`` only
            works if the two tensors still broadcast against each other.
    """

    def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
        super(LambdaRRDB, self).__init__()
        # init_weight=1 overrides ResidualDenseBlock's default of .1.
        self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
        self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
        if reduce_to is None:
            reduce_to = mid_channels
        self.lam = LambdaLayer(dim=mid_channels, dim_out=reduce_to, r=23, dim_k=16, heads=4, dim_u=4)
        # The norm is applied to the lambda layer's output, so it has to be
        # sized from reduce_to (the layer's dim_out), not from mid_channels.
        self.gn = GroupNorm(num_groups=8, num_channels=reduce_to)
        # Learnable residual gain; starts at 1/256 so the branch contributes
        # very little at initialisation.
        self.scale = nn.Parameter(torch.full((1,), 1/256))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.lam(out)
        out = self.gn(out)
        # Scaled branch output plus the identity skip connection.
        return out * self.scale + x

View File

@ -22,7 +22,7 @@ class SRFlowNet(nn.Module):
self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
if 'pretrain_rrdb' in opt['networks']['generator'].keys():
rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb'])
self.RRDB.load_state_dict(rrdb_state_dict, strict=True)
self.RRDB.load_state_dict(rrdb_state_dict, strict=False)
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
hidden_channels = hidden_channels or 64
@ -140,7 +140,7 @@ class SRFlowNet(nn.Module):
def rrdbPreprocessing(self, lr):
rrdbResults = self.RRDB(lr, get_steps=True)
block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
block_idxs = opt_get(self.opt, ['networks', 'generator', 'flow', 'stackRRDB', 'blocks']) or []
if len(block_idxs) > 0:
concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)

View File

@ -37,19 +37,19 @@ def define_G(opt, opt_net, scale=None):
if which_model == 'MSRResNet':
netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
elif which_model == 'RRDBNet':
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not_additive'
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
output_mode=output_mode)
elif which_model == 'RRDBNetBypass':
elif 'RRDBNet' in which_model:
if which_model == 'RRDBNetBypass':
from models.archs.lambda_rrdb import LambdaRRDB
block = LambdaRRDB
elif which_model == 'RRDBNetLambda':
block = RRDBNet_arch.RRDBWithBypass
else:
block = RRDBNet_arch.RRDB
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass,
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'],
additive_mode=additive_mode, output_mode=output_mode)
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
output_mode=output_mode, body_block=block)
elif which_model == 'rcan':
#args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
opt_net['rgb_range'] = 255

View File

@ -20,45 +20,6 @@ import torch
import models.networks as networks
# Concepts: Swap transformations around. Normalize attention. Disable individual switches, both randomly and one at
# a time, starting at the last switch. Pick random regions in an image and print out the full attention vector for
# each switch. Yield an output directory name for each alteration; the generator finishes after the last alteration.
def alter_srg(srg: srg.ConfigurableSwitchedResidualGenerator2):
    """Step a generator network through a fixed series of switch alterations.

    Each yielded string labels the state the network is in at that moment.
    The first label is produced before anything is modified; every later
    label is produced right after the switches have been changed in place.

    Args:
        srg: Network whose ``switches`` are mutated in place.

    Yields:
        str: ``"naked"``, ``"specific"``, ``"normalized"`` and ``"no_anorm"``,
        in that order.
    """
    def _set_all_temperatures(value):
        # Apply one temperature to every switch of the network.
        for switch in srg.switches:
            switch.set_temperature(value)

    # Untouched network.
    yield "naked"

    # A "strip switches off one at a time" alteration (labelled
    # "stripped-<i>") used to sit here; it is currently disabled.

    _set_all_temperatures(.001)
    yield "specific"

    _set_all_temperatures(1000)
    yield "normalized"

    # Back to temperature 1, with the attention norm removed from each switch.
    for switch in srg.switches:
        switch.set_temperature(1)
        switch.switch.attention_norm = None
    yield "no_anorm"
def analyze_srg(srg: srg.ConfigurableSwitchedResidualGenerator2, path, alteration_suffix):
    """Print per-switch attention statistics and save the attention maps.

    For every entry of ``srg.attentions`` this calls
    ``compute_attention_specificity(att, 2)``, prints the first element of
    the result and a normalised histogram of the second element, then writes
    each attention map out through ``save_attention_to_image_rgb``.

    Args:
        srg: Network exposing ``attentions`` and ``transformation_counts``.
        path: Destination passed through to ``save_attention_to_image_rgb``.
        alteration_suffix: Label used as the prefix of the printed lines and
            passed through to ``save_attention_to_image_rgb``.
    """
    specificities = [compute_attention_specificity(att, 2) for att in srg.attentions]

    means = [entry[0] for entry in specificities]
    hists = []
    for entry in specificities:
        # Histogram a detached CPU float copy so the original tensor is untouched.
        flat = entry[1].clone().detach().cpu().flatten().float()
        hist = torch.histc(flat, bins=srg.transformation_counts)
        # Normalise so the bins sum to 1.
        hists.append(hist / torch.sum(hist))

    for i, (mean, hist) in enumerate(zip(means, hists)):
        print("%s - switch_%i_specificity" % (alteration_suffix, i), mean)
        print("%s - switch_%i_histogram" % (alteration_suffix, i), hist)

    # Images are saved only after all statistics are printed, as before.
    for i, attention in enumerate(srg.attentions):
        save_attention_to_image_rgb(path, attention, srg.transformation_counts, alteration_suffix, i)
def forward_pass(model, output_dir, alteration_suffix=''):
model.feed_data(data, need_GT=need_GT)
model.test()
@ -135,27 +96,9 @@ if __name__ == "__main__":
for data in tq:
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
if srg_analyze:
orig_model = model.netG
model_copy = networks.define_G(opt).to(model.device)
model_copy.load_state_dict(orig_model.state_dict())
model.netG = model_copy
for alteration_suffix in alter_srg(model_copy):
alt_path = osp.join(dataset_dir, alteration_suffix)
img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
img_name = osp.splitext(osp.basename(img_path))[0] + opt['name']
alteration_suffix += img_name
os.makedirs(alt_path, exist_ok=True)
forward_pass(model, dataset_dir, alteration_suffix)
analyze_srg(model_copy, alt_path, alteration_suffix)
# Reset model and do next alteration.
model_copy = networks.define_G(opt).to(model.device)
model_copy.load_state_dict(orig_model.state_dict())
model.netG = model_copy
else:
fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt['name'])
fea_loss += fea_loss
psnr_loss += psnr_loss
fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt['name'])
fea_loss += fea_loss
psnr_loss += psnr_loss
# log
logger.info('# Validation # Fea: {:.4e}, PSNR: {:.4e}'.format(fea_loss / len(test_loader), psnr_loss / len(test_loader)))

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_adrianna_srflow8x.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_using_rrdb_features.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_imgsetext_rrdb8x.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_lambda.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()