forked from mrq/DL-Art-School
Play with lambdas
This commit is contained in:
parent
0c6d7971b9
commit
fd356580c0
|
@ -19,7 +19,7 @@ class ResidualDenseBlock(nn.Module):
|
||||||
growth_channels (int): Channels for each growth.
|
growth_channels (int): Channels for each growth.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, mid_channels=64, growth_channels=32):
|
def __init__(self, mid_channels=64, growth_channels=32, init_weight=.1):
|
||||||
super(ResidualDenseBlock, self).__init__()
|
super(ResidualDenseBlock, self).__init__()
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
out_channels = mid_channels if i == 4 else growth_channels
|
out_channels = mid_channels if i == 4 else growth_channels
|
||||||
|
@ -29,7 +29,7 @@ class ResidualDenseBlock(nn.Module):
|
||||||
1, 1))
|
1, 1))
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
default_init_weights(getattr(self, f'conv{i+1}'), 0.1)
|
default_init_weights(getattr(self, f'conv{i+1}'), init_weight)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
42
codes/models/archs/lambda_rrdb.py
Normal file
42
codes/models/archs/lambda_rrdb.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from lambda_networks import LambdaLayer
|
||||||
|
from torch.nn import GroupNorm
|
||||||
|
|
||||||
|
from models.archs.RRDBNet_arch import ResidualDenseBlock
|
||||||
|
|
||||||
|
|
||||||
|
class LambdaRRDB(nn.Module):
|
||||||
|
"""Residual in Residual Dense Block.
|
||||||
|
|
||||||
|
Used in RRDB-Net in ESRGAN.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mid_channels (int): Channel number of intermediate features.
|
||||||
|
growth_channels (int): Channels for each growth.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
|
||||||
|
super(LambdaRRDB, self).__init__()
|
||||||
|
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
|
||||||
|
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
|
||||||
|
if reduce_to is None:
|
||||||
|
reduce_to = mid_channels
|
||||||
|
self.lam = LambdaLayer(dim=mid_channels, dim_out=reduce_to, r=23, dim_k=16, heads=4, dim_u=4)
|
||||||
|
self.gn = GroupNorm(num_groups=8, num_channels=mid_channels)
|
||||||
|
self.scale = nn.Parameter(torch.full((1,), 1/256))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Forward function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor with shape (n, c, h, w).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Forward results.
|
||||||
|
"""
|
||||||
|
out = self.rdb1(x)
|
||||||
|
out = self.rdb2(out)
|
||||||
|
out = self.lam(out)
|
||||||
|
out = self.gn(out)
|
||||||
|
return out * self.scale + x
|
|
@ -22,7 +22,7 @@ class SRFlowNet(nn.Module):
|
||||||
self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
|
self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
|
||||||
if 'pretrain_rrdb' in opt['networks']['generator'].keys():
|
if 'pretrain_rrdb' in opt['networks']['generator'].keys():
|
||||||
rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb'])
|
rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb'])
|
||||||
self.RRDB.load_state_dict(rrdb_state_dict, strict=True)
|
self.RRDB.load_state_dict(rrdb_state_dict, strict=False)
|
||||||
|
|
||||||
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
|
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
|
||||||
hidden_channels = hidden_channels or 64
|
hidden_channels = hidden_channels or 64
|
||||||
|
@ -140,7 +140,7 @@ class SRFlowNet(nn.Module):
|
||||||
|
|
||||||
def rrdbPreprocessing(self, lr):
|
def rrdbPreprocessing(self, lr):
|
||||||
rrdbResults = self.RRDB(lr, get_steps=True)
|
rrdbResults = self.RRDB(lr, get_steps=True)
|
||||||
block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
|
block_idxs = opt_get(self.opt, ['networks', 'generator', 'flow', 'stackRRDB', 'blocks']) or []
|
||||||
if len(block_idxs) > 0:
|
if len(block_idxs) > 0:
|
||||||
concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
|
concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
|
||||||
|
|
||||||
|
|
|
@ -37,19 +37,19 @@ def define_G(opt, opt_net, scale=None):
|
||||||
if which_model == 'MSRResNet':
|
if which_model == 'MSRResNet':
|
||||||
netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||||
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
|
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
|
||||||
elif which_model == 'RRDBNet':
|
elif 'RRDBNet' in which_model:
|
||||||
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not_additive'
|
if which_model == 'RRDBNetBypass':
|
||||||
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
from models.archs.lambda_rrdb import LambdaRRDB
|
||||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
block = LambdaRRDB
|
||||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
|
elif which_model == 'RRDBNetLambda':
|
||||||
output_mode=output_mode)
|
block = RRDBNet_arch.RRDBWithBypass
|
||||||
elif which_model == 'RRDBNetBypass':
|
else:
|
||||||
|
block = RRDBNet_arch.RRDB
|
||||||
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
|
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
|
||||||
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
||||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass,
|
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
|
||||||
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'],
|
output_mode=output_mode, body_block=block)
|
||||||
additive_mode=additive_mode, output_mode=output_mode)
|
|
||||||
elif which_model == 'rcan':
|
elif which_model == 'rcan':
|
||||||
#args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
|
#args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
|
||||||
opt_net['rgb_range'] = 255
|
opt_net['rgb_range'] = 255
|
||||||
|
|
|
@ -20,45 +20,6 @@ import torch
|
||||||
import models.networks as networks
|
import models.networks as networks
|
||||||
|
|
||||||
|
|
||||||
# Concepts: Swap transformations around. Normalize attention. Disable individual switches, both randomly and one at
|
|
||||||
# a time, starting at the last switch. Pick random regions in an image and print out the full attention vector for
|
|
||||||
# each switch. Yield an output directory name for each alteration and None when last alteration is completed.
|
|
||||||
def alter_srg(srg: srg.ConfigurableSwitchedResidualGenerator2):
|
|
||||||
# First alteration, strip off switches one at a time.
|
|
||||||
yield "naked"
|
|
||||||
|
|
||||||
'''
|
|
||||||
for i in range(1, len(srg.switches)):
|
|
||||||
srg.switches = srg.switches[:-i]
|
|
||||||
yield "stripped-%i" % (i,)
|
|
||||||
'''
|
|
||||||
|
|
||||||
for sw in srg.switches:
|
|
||||||
sw.set_temperature(.001)
|
|
||||||
yield "specific"
|
|
||||||
|
|
||||||
for sw in srg.switches:
|
|
||||||
sw.set_temperature(1000)
|
|
||||||
yield "normalized"
|
|
||||||
|
|
||||||
for sw in srg.switches:
|
|
||||||
sw.set_temperature(1)
|
|
||||||
sw.switch.attention_norm = None
|
|
||||||
yield "no_anorm"
|
|
||||||
return None
|
|
||||||
|
|
||||||
def analyze_srg(srg: srg.ConfigurableSwitchedResidualGenerator2, path, alteration_suffix):
|
|
||||||
mean_hists = [compute_attention_specificity(att, 2) for att in srg.attentions]
|
|
||||||
means = [i[0] for i in mean_hists]
|
|
||||||
hists = [torch.histc(i[1].clone().detach().cpu().flatten().float(), bins=srg.transformation_counts) for i in mean_hists]
|
|
||||||
hists = [h / torch.sum(h) for h in hists]
|
|
||||||
for i in range(len(means)):
|
|
||||||
print("%s - switch_%i_specificity" % (alteration_suffix, i), means[i])
|
|
||||||
print("%s - switch_%i_histogram" % (alteration_suffix, i), hists[i])
|
|
||||||
|
|
||||||
[save_attention_to_image_rgb(path, srg.attentions[i], srg.transformation_counts, alteration_suffix, i) for i in range(len(srg.attentions))]
|
|
||||||
|
|
||||||
|
|
||||||
def forward_pass(model, output_dir, alteration_suffix=''):
|
def forward_pass(model, output_dir, alteration_suffix=''):
|
||||||
model.feed_data(data, need_GT=need_GT)
|
model.feed_data(data, need_GT=need_GT)
|
||||||
model.test()
|
model.test()
|
||||||
|
@ -135,27 +96,9 @@ if __name__ == "__main__":
|
||||||
for data in tq:
|
for data in tq:
|
||||||
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
|
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
|
||||||
|
|
||||||
if srg_analyze:
|
fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt['name'])
|
||||||
orig_model = model.netG
|
fea_loss += fea_loss
|
||||||
model_copy = networks.define_G(opt).to(model.device)
|
psnr_loss += psnr_loss
|
||||||
model_copy.load_state_dict(orig_model.state_dict())
|
|
||||||
model.netG = model_copy
|
|
||||||
for alteration_suffix in alter_srg(model_copy):
|
|
||||||
alt_path = osp.join(dataset_dir, alteration_suffix)
|
|
||||||
img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
|
|
||||||
img_name = osp.splitext(osp.basename(img_path))[0] + opt['name']
|
|
||||||
alteration_suffix += img_name
|
|
||||||
os.makedirs(alt_path, exist_ok=True)
|
|
||||||
forward_pass(model, dataset_dir, alteration_suffix)
|
|
||||||
analyze_srg(model_copy, alt_path, alteration_suffix)
|
|
||||||
# Reset model and do next alteration.
|
|
||||||
model_copy = networks.define_G(opt).to(model.device)
|
|
||||||
model_copy.load_state_dict(orig_model.state_dict())
|
|
||||||
model.netG = model_copy
|
|
||||||
else:
|
|
||||||
fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt['name'])
|
|
||||||
fea_loss += fea_loss
|
|
||||||
psnr_loss += psnr_loss
|
|
||||||
|
|
||||||
# log
|
# log
|
||||||
logger.info('# Validation # Fea: {:.4e}, PSNR: {:.4e}'.format(fea_loss / len(test_loader), psnr_loss / len(test_loader)))
|
logger.info('# Validation # Fea: {:.4e}, PSNR: {:.4e}'.format(fea_loss / len(test_loader), psnr_loss / len(test_loader)))
|
||||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_adrianna_srflow8x.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_using_rrdb_features.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_imgsetext_rrdb8x.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_lambda.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user