diff --git a/codes/models/RRDBNet_arch.py b/codes/models/RRDBNet_arch.py index ba4d576f..1fe25f0f 100644 --- a/codes/models/RRDBNet_arch.py +++ b/codes/models/RRDBNet_arch.py @@ -1,4 +1,6 @@ +import functools import os +import random import torch import torch.nn as nn @@ -7,7 +9,7 @@ import torchvision from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu from trainer.networks import register_model -from utils.util import checkpoint, sequential_checkpoint +from utils.util import checkpoint, sequential_checkpoint, opt_get class ResidualDenseBlock(nn.Module): @@ -106,7 +108,7 @@ class RRDBWithBypass(nn.Module): growth_channels (int): Channels for each growth. """ - def __init__(self, mid_channels, growth_channels=32, reduce_to=None): + def __init__(self, mid_channels, growth_channels=32, reduce_to=None, randomly_add_noise_to_bypass=False): super(RRDBWithBypass, self).__init__() self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels) self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels) @@ -122,6 +124,7 @@ class RRDBWithBypass(nn.Module): ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False), ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False), nn.Sigmoid()) + self.randomly_add_bypass_noise = randomly_add_noise_to_bypass def forward(self, x): """Forward function. @@ -142,6 +145,11 @@ class RRDBWithBypass(nn.Module): out = torch.cat([out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1) bypass = self.bypass(torch.cat([x, out], dim=1)) + # The purpose of random noise is to induce usage of bypass maps that would otherwise be "dead". Theoretically + # if these maps provide value, the noise should trigger gradients to flow into the bypass conv network again. + if self.randomly_add_bypass_noise and random.random() < .2: + rnoise = torch.rand_like(bypass) * .02 + bypass = (bypass + rnoise).clamp(0, 1) self.bypass_map = bypass.detach().clone() # Empirically, we use 0.2 to scale the residual for better performance @@ -379,9 +387,11 @@ def register_RRDBNetBypass(opt_net, opt): output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only' gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32 initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1 + bypass_noise = opt_get(opt_net, ['bypass_noise'], False) + block = functools.partial(RRDBWithBypass, randomly_add_noise_to_bypass=bypass_noise) return RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode, - output_mode=output_mode, body_block=RRDBWithBypass, scale=opt_net['scale'], growth_channels=gc, + output_mode=output_mode, body_block=block, scale=opt_net['scale'], growth_channels=gc, initial_stride=initial_stride)