diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py
index 29c1718f..e0c99338 100644
--- a/codes/models/archs/RRDBNet_arch.py
+++ b/codes/models/archs/RRDBNet_arch.py
@@ -19,7 +19,7 @@ class ResidualDenseBlock(nn.Module):
         growth_channels (int): Channels for each growth.
     """
 
-    def __init__(self, mid_channels=64, growth_channels=32):
+    def __init__(self, mid_channels=64, growth_channels=32, init_weight=.1):
         super(ResidualDenseBlock, self).__init__()
         for i in range(5):
             out_channels = mid_channels if i == 4 else growth_channels
@@ -29,7 +29,7 @@ class ResidualDenseBlock(nn.Module):
                           1, 1))
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
         for i in range(5):
-            default_init_weights(getattr(self, f'conv{i+1}'), 0.1)
+            default_init_weights(getattr(self, f'conv{i+1}'), init_weight)
 
 
     def forward(self, x):
diff --git a/codes/models/archs/lambda_rrdb.py b/codes/models/archs/lambda_rrdb.py
new file mode 100644
index 00000000..9493a6b9
--- /dev/null
+++ b/codes/models/archs/lambda_rrdb.py
@@ -0,0 +1,47 @@
+import torch
+from torch import nn
+from lambda_networks import LambdaLayer
+from torch.nn import GroupNorm
+
+from models.archs.RRDBNet_arch import ResidualDenseBlock
+
+
+class LambdaRRDB(nn.Module):
+    """RRDB-style block whose final stage is a lambda layer.
+
+    Two ResidualDenseBlocks are followed by a LambdaLayer and a GroupNorm; the
+    result is multiplied by a learned scalar and added back to the input.
+
+    Args:
+        mid_channels (int): Channel number of intermediate features.
+        growth_channels (int): Channels for each growth.
+        reduce_to (int | None): Output channels of the lambda layer; falls
+            back to ``mid_channels`` when None. forward() adds the result to
+            its input, so the two channel counts must be broadcast-compatible.
+    """
+
+    def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
+        super(LambdaRRDB, self).__init__()
+        self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
+        self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
+        if reduce_to is None:
+            reduce_to = mid_channels
+        self.lam = LambdaLayer(dim=mid_channels, dim_out=reduce_to, r=23, dim_k=16, heads=4, dim_u=4)
+        # Normalize over the lambda layer's output channels, not its input channels.
+        self.gn = GroupNorm(num_groups=8, num_channels=reduce_to)
+        self.scale = nn.Parameter(torch.full((1,), 1/256))
+
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor with shape (n, c, h, w).
+
+        Returns:
+            Tensor: Forward results.
+        """
+        out = self.rdb1(x)
+        out = self.rdb2(out)
+        out = self.lam(out)
+        out = self.gn(out)
+        return out * self.scale + x
\ No newline at end of file
diff --git a/codes/models/archs/srflow_orig/SRFlowNet_arch.py b/codes/models/archs/srflow_orig/SRFlowNet_arch.py
index 66a52f26..93a729ef 100644
--- a/codes/models/archs/srflow_orig/SRFlowNet_arch.py
+++ b/codes/models/archs/srflow_orig/SRFlowNet_arch.py
@@ -22,7 +22,7 @@ class SRFlowNet(nn.Module):
         self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
         if 'pretrain_rrdb' in opt['networks']['generator'].keys():
             rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb'])
-            self.RRDB.load_state_dict(rrdb_state_dict, strict=True)
+            self.RRDB.load_state_dict(rrdb_state_dict, strict=False)
 
         hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
         hidden_channels = hidden_channels or 64
@@ -140,7 +140,7 @@ class SRFlowNet(nn.Module):
 
     def rrdbPreprocessing(self, lr):
         rrdbResults = self.RRDB(lr, get_steps=True)
-        block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
+        block_idxs = opt_get(self.opt, ['networks', 'generator', 'flow', 'stackRRDB', 'blocks']) or []
         if len(block_idxs) > 0:
             concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
 
diff --git a/codes/models/networks.py b/codes/models/networks.py
index a900c374..e4a76f73 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -37,19 +37,19 @@ def define_G(opt, opt_net, scale=None):
     if which_model == 'MSRResNet':
         netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                        nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
-    elif which_model == 'RRDBNet':
-        additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not_additive'
-        output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
-        netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
-                                    mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
-                                    output_mode=output_mode)
-    elif which_model == 'RRDBNetBypass':
+    elif 'RRDBNet' in which_model:
+        if which_model == 'RRDBNetLambda':
+            from models.archs.lambda_rrdb import LambdaRRDB
+            block = LambdaRRDB
+        elif which_model == 'RRDBNetBypass':
+            block = RRDBNet_arch.RRDBWithBypass
+        else:
+            block = RRDBNet_arch.RRDB
         additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
         output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
         netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
-                                    mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass,
-                                    blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'],
-                                    additive_mode=additive_mode, output_mode=output_mode)
+                                    mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
+                                    output_mode=output_mode, body_block=block)
     elif which_model == 'rcan':
         #args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
         opt_net['rgb_range'] = 255
diff --git a/codes/test.py b/codes/test.py
index 861ff83c..f625a7fa 100644
--- a/codes/test.py
+++ b/codes/test.py
@@ -20,45 +20,6 @@ import torch
 import models.networks as networks
 
 
-# Concepts: Swap transformations around. Normalize attention. Disable individual switches, both randomly and one at
-# a time, starting at the last switch. Pick random regions in an image and print out the full attention vector for
-# each switch. Yield an output directory name for each alteration and None when last alteration is completed.
-def alter_srg(srg: srg.ConfigurableSwitchedResidualGenerator2):
-    # First alteration, strip off switches one at a time.
-    yield "naked"
-
-    '''
-    for i in range(1, len(srg.switches)):
-        srg.switches = srg.switches[:-i]
-        yield "stripped-%i" % (i,)
-    '''
-
-    for sw in srg.switches:
-        sw.set_temperature(.001)
-    yield "specific"
-
-    for sw in srg.switches:
-        sw.set_temperature(1000)
-    yield "normalized"
-
-    for sw in srg.switches:
-        sw.set_temperature(1)
-        sw.switch.attention_norm = None
-    yield "no_anorm"
-    return None
-
-def analyze_srg(srg: srg.ConfigurableSwitchedResidualGenerator2, path, alteration_suffix):
-    mean_hists = [compute_attention_specificity(att, 2) for att in srg.attentions]
-    means = [i[0] for i in mean_hists]
-    hists = [torch.histc(i[1].clone().detach().cpu().flatten().float(), bins=srg.transformation_counts) for i in mean_hists]
-    hists = [h / torch.sum(h) for h in hists]
-    for i in range(len(means)):
-        print("%s - switch_%i_specificity" % (alteration_suffix, i), means[i])
-        print("%s - switch_%i_histogram" % (alteration_suffix, i), hists[i])
-
-    [save_attention_to_image_rgb(path, srg.attentions[i], srg.transformation_counts, alteration_suffix, i) for i in range(len(srg.attentions))]
-
-
 def forward_pass(model, output_dir, alteration_suffix=''):
     model.feed_data(data, need_GT=need_GT)
     model.test()
@@ -135,27 +96,9 @@ if __name__ == "__main__":
         for data in tq:
             need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
 
-            if srg_analyze:
-                orig_model = model.netG
-                model_copy = networks.define_G(opt).to(model.device)
-                model_copy.load_state_dict(orig_model.state_dict())
-                model.netG = model_copy
-                for alteration_suffix in alter_srg(model_copy):
-                    alt_path = osp.join(dataset_dir, alteration_suffix)
-                    img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
-                    img_name = osp.splitext(osp.basename(img_path))[0] + opt['name']
-                    alteration_suffix += img_name
-                    os.makedirs(alt_path, exist_ok=True)
-                    forward_pass(model, dataset_dir, alteration_suffix)
-                    analyze_srg(model_copy, alt_path, alteration_suffix)
-                    # Reset model and do next alteration.
-                    model_copy = networks.define_G(opt).to(model.device)
-                    model_copy.load_state_dict(orig_model.state_dict())
-                    model.netG = model_copy
-            else:
-                fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt['name'])
-                fea_loss += fea_loss
-                psnr_loss += psnr_loss
+            fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt['name'])
+            fea_loss += fea_loss
+            psnr_loss += psnr_loss
 
         # log
         logger.info('# Validation # Fea: {:.4e}, PSNR: {:.4e}'.format(fea_loss / len(test_loader), psnr_loss / len(test_loader)))
diff --git a/codes/train.py b/codes/train.py
index cc83a3d2..251a079d 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -291,7 +291,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_adrianna_srflow8x.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_using_rrdb_features.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/train2.py b/codes/train2.py
index 51f6b14a..ff16bdc4 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -291,7 +291,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_imgsetext_rrdb8x.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_lambda.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()