Play with lambdas
This commit is contained in:
parent
0c6d7971b9
commit
fd356580c0
|
@ -19,7 +19,7 @@ class ResidualDenseBlock(nn.Module):
|
|||
growth_channels (int): Channels for each growth.
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels=64, growth_channels=32):
|
||||
def __init__(self, mid_channels=64, growth_channels=32, init_weight=.1):
|
||||
super(ResidualDenseBlock, self).__init__()
|
||||
for i in range(5):
|
||||
out_channels = mid_channels if i == 4 else growth_channels
|
||||
|
@ -29,7 +29,7 @@ class ResidualDenseBlock(nn.Module):
|
|||
1, 1))
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
for i in range(5):
|
||||
default_init_weights(getattr(self, f'conv{i+1}'), 0.1)
|
||||
default_init_weights(getattr(self, f'conv{i+1}'), init_weight)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
|
|
42
codes/models/archs/lambda_rrdb.py
Normal file
42
codes/models/archs/lambda_rrdb.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from lambda_networks import LambdaLayer
|
||||
from torch.nn import GroupNorm
|
||||
|
||||
from models.archs.RRDBNet_arch import ResidualDenseBlock
|
||||
|
||||
|
||||
class LambdaRRDB(nn.Module):
    """Residual block built from two dense blocks and a lambda layer.

    The input runs through two ``ResidualDenseBlock`` modules, a
    ``LambdaLayer`` and a ``GroupNorm``. The normalized result is multiplied
    by a learned scalar and added back onto the input.

    Args:
        mid_channels (int): Channel number of intermediate features.
        growth_channels (int): Channels for each growth.
        reduce_to (int, optional): Output channel count of the lambda layer.
            Defaults to ``mid_channels``. Because the block output is added
            to the input, any other value is rejected.

    Raises:
        ValueError: If ``reduce_to`` is given and differs from
            ``mid_channels``.
    """

    def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
        super().__init__()
        if reduce_to is None:
            reduce_to = mid_channels
        if reduce_to != mid_channels:
            # forward() adds the block output onto the input, so both must
            # have the same channel count. Fail here rather than with a
            # shape error in the middle of a forward pass.
            raise ValueError(
                'reduce_to ({}) must equal mid_channels ({}) because the '
                'block output is added to its input'.format(reduce_to, mid_channels))
        # Submodules are created in the same order as before so that
        # parameter registration order is unchanged.
        self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
        self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
        self.lam = LambdaLayer(dim=mid_channels, dim_out=reduce_to, r=23, dim_k=16, heads=4, dim_u=4)
        # The norm is applied to the lambda layer's output, so it is sized by
        # that layer's output channel count.
        self.gn = GroupNorm(num_groups=8, num_channels=reduce_to)
        # Learned scalar applied to the residual branch; starts at 1/256.
        self.scale = nn.Parameter(torch.full((1,), 1/256))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: ``x`` plus the scaled output of the residual branch.
        """
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.lam(out)
        out = self.gn(out)
        return out * self.scale + x
|
|
@ -22,7 +22,7 @@ class SRFlowNet(nn.Module):
|
|||
self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
|
||||
if 'pretrain_rrdb' in opt['networks']['generator'].keys():
|
||||
rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb'])
|
||||
self.RRDB.load_state_dict(rrdb_state_dict, strict=True)
|
||||
self.RRDB.load_state_dict(rrdb_state_dict, strict=False)
|
||||
|
||||
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
|
||||
hidden_channels = hidden_channels or 64
|
||||
|
@ -140,7 +140,7 @@ class SRFlowNet(nn.Module):
|
|||
|
||||
def rrdbPreprocessing(self, lr):
|
||||
rrdbResults = self.RRDB(lr, get_steps=True)
|
||||
block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
|
||||
block_idxs = opt_get(self.opt, ['networks', 'generator', 'flow', 'stackRRDB', 'blocks']) or []
|
||||
if len(block_idxs) > 0:
|
||||
concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
|
||||
|
||||
|
|
|
@ -37,19 +37,19 @@ def define_G(opt, opt_net, scale=None):
|
|||
if which_model == 'MSRResNet':
|
||||
netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
|
||||
elif which_model == 'RRDBNet':
|
||||
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not_additive'
|
||||
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
|
||||
output_mode=output_mode)
|
||||
elif which_model == 'RRDBNetBypass':
|
||||
elif 'RRDBNet' in which_model:
|
||||
# Pick the body block class for the RRDBNet variant named by which_model.
if which_model == 'RRDBNetLambda':
    # Imported here so that lambda_rrdb (and its lambda_networks dependency)
    # is only loaded when the lambda variant is actually requested.
    from models.archs.lambda_rrdb import LambdaRRDB
    block = LambdaRRDB
elif which_model == 'RRDBNetBypass':
    block = RRDBNet_arch.RRDBWithBypass
else:
    # Plain 'RRDBNet' (or any other name containing it) uses the stock block.
    block = RRDBNet_arch.RRDB
|
||||
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
|
||||
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass,
|
||||
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'],
|
||||
additive_mode=additive_mode, output_mode=output_mode)
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
|
||||
output_mode=output_mode, body_block=block)
|
||||
elif which_model == 'rcan':
|
||||
#args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
|
||||
opt_net['rgb_range'] = 255
|
||||
|
|
|
@ -20,45 +20,6 @@ import torch
|
|||
import models.networks as networks
|
||||
|
||||
|
||||
# Concepts: Swap transformations around. Normalize attention. Disable individual switches, both randomly and one at
|
||||
# a time, starting at the last switch. Pick random regions in an image and print out the full attention vector for
|
||||
# each switch. Yield an output directory name for each alteration and None when last alteration is completed.
|
||||
def alter_srg(srg: srg.ConfigurableSwitchedResidualGenerator2):
    """Yield a label for each successive alteration applied to ``srg``.

    The generator mutates ``srg`` in place between yields, so the caller is
    expected to use the model after each yielded label and before resuming.

    Labels, in order:
        "naked": no change has been made yet.
        "specific": every switch had ``set_temperature(.001)`` called.
        "normalized": every switch had ``set_temperature(1000)`` called.
        "no_anorm": every switch had ``set_temperature(1)`` called and its
            ``switch.attention_norm`` set to ``None``.

    Args:
        srg: Object exposing an iterable ``switches`` attribute.
    """
    # The unaltered model comes first.
    yield "naked"

    # Disabled experiment, kept for reference: strip switches one at a time.
    #   for i in range(1, len(srg.switches)):
    #       srg.switches = srg.switches[:-i]
    #       yield "stripped-%i" % (i,)

    for temperature, label in ((.001, "specific"), (1000, "normalized")):
        for switch in srg.switches:
            switch.set_temperature(temperature)
        yield label

    for switch in srg.switches:
        switch.set_temperature(1)
        switch.switch.attention_norm = None
    yield "no_anorm"
|
||||
|
||||
def analyze_srg(srg: srg.ConfigurableSwitchedResidualGenerator2, path, alteration_suffix):
    """Print per-switch attention statistics and save the attention maps.

    For every entry of ``srg.attentions`` this computes
    ``compute_attention_specificity(att, 2)``, prints element 0 of the result
    and a histogram of element 1 normalized to sum to 1, then hands the
    attention to ``save_attention_to_image_rgb``.

    Args:
        srg: Object exposing ``attentions`` and ``transformation_counts``.
        path: Forwarded unchanged to ``save_attention_to_image_rgb``
            (presumably the output directory - confirm against that helper).
        alteration_suffix: Label used as the prefix of each printed line and
            forwarded to ``save_attention_to_image_rgb``.
    """
    mean_hists = [compute_attention_specificity(att, 2) for att in srg.attentions]
    means = [entry[0] for entry in mean_hists]

    hists = []
    for entry in mean_hists:
        # One histogram bin per transformation, computed on a detached CPU copy.
        hist = torch.histc(entry[1].clone().detach().cpu().flatten().float(), bins=srg.transformation_counts)
        hists.append(hist / torch.sum(hist))

    # All statistics are printed before anything is written, as before.
    for i, (mean, hist) in enumerate(zip(means, hists)):
        print("%s - switch_%i_specificity" % (alteration_suffix, i), mean)
        print("%s - switch_%i_histogram" % (alteration_suffix, i), hist)

    for i, attention in enumerate(srg.attentions):
        save_attention_to_image_rgb(path, attention, srg.transformation_counts, alteration_suffix, i)
|
||||
|
||||
|
||||
def forward_pass(model, output_dir, alteration_suffix=''):
|
||||
model.feed_data(data, need_GT=need_GT)
|
||||
model.test()
|
||||
|
@ -135,27 +96,9 @@ if __name__ == "__main__":
|
|||
for data in tq:
|
||||
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
|
||||
|
||||
if srg_analyze:
|
||||
orig_model = model.netG
|
||||
model_copy = networks.define_G(opt).to(model.device)
|
||||
model_copy.load_state_dict(orig_model.state_dict())
|
||||
model.netG = model_copy
|
||||
for alteration_suffix in alter_srg(model_copy):
|
||||
alt_path = osp.join(dataset_dir, alteration_suffix)
|
||||
img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
|
||||
img_name = osp.splitext(osp.basename(img_path))[0] + opt['name']
|
||||
alteration_suffix += img_name
|
||||
os.makedirs(alt_path, exist_ok=True)
|
||||
forward_pass(model, dataset_dir, alteration_suffix)
|
||||
analyze_srg(model_copy, alt_path, alteration_suffix)
|
||||
# Reset model and do next alteration.
|
||||
model_copy = networks.define_G(opt).to(model.device)
|
||||
model_copy.load_state_dict(orig_model.state_dict())
|
||||
model.netG = model_copy
|
||||
else:
|
||||
fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt['name'])
|
||||
fea_loss += fea_loss
|
||||
psnr_loss += psnr_loss
|
||||
fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt['name'])
|
||||
fea_loss += fea_loss
|
||||
psnr_loss += psnr_loss
|
||||
|
||||
# log
|
||||
logger.info('# Validation # Fea: {:.4e}, PSNR: {:.4e}'.format(fea_loss / len(test_loader), psnr_loss / len(test_loader)))
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_adrianna_srflow8x.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_using_rrdb_features.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_imgsetext_rrdb8x.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_lambda.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user