Add injected noise into bypass maps

This commit is contained in:
James Betker 2021-01-07 16:31:12 -07:00
parent 04961b91cf
commit de10c7246a

View File

@ -1,4 +1,6 @@
import functools
import os
import random
import torch
import torch.nn as nn
@ -7,7 +9,7 @@ import torchvision
from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
from trainer.networks import register_model
from utils.util import checkpoint, sequential_checkpoint
from utils.util import checkpoint, sequential_checkpoint, opt_get
class ResidualDenseBlock(nn.Module):
@ -106,7 +108,7 @@ class RRDBWithBypass(nn.Module):
growth_channels (int): Channels for each growth.
"""
def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
def __init__(self, mid_channels, growth_channels=32, reduce_to=None, randomly_add_noise_to_bypass=False):
super(RRDBWithBypass, self).__init__()
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
@ -122,6 +124,7 @@ class RRDBWithBypass(nn.Module):
ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False),
ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
nn.Sigmoid())
self.randomly_add_bypass_noise = randomly_add_noise_to_bypass
def forward(self, x):
"""Forward function.
@ -142,6 +145,11 @@ class RRDBWithBypass(nn.Module):
out = torch.cat([out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1)
bypass = self.bypass(torch.cat([x, out], dim=1))
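# The purpose of the random noise is to induce usage of bypass maps that would otherwise be "dead". When enabled,
# roughly 20% of forward passes add uniform noise in [0, 0.02) to the bypass map, which is then clamped to [0, 1].
# Theoretically, if these maps provide value, the noise should trigger gradients to flow into the bypass conv
# network again.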
if self.randomly_add_bypass_noise and random.random() < .2:
rnoise = torch.rand_like(bypass) * .02
bypass = (bypass + rnoise).clamp(0, 1)
self.bypass_map = bypass.detach().clone()
# Empirically, we use 0.2 to scale the residual for better performance
@ -379,9 +387,11 @@ def register_RRDBNetBypass(opt_net, opt):
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32
initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1
bypass_noise = opt_get(opt_net, ['bypass_noise'], False)
block = functools.partial(RRDBWithBypass, randomly_add_noise_to_bypass=bypass_noise)
return RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
output_mode=output_mode, body_block=RRDBWithBypass, scale=opt_net['scale'], growth_channels=gc,
output_mode=output_mode, body_block=block, scale=opt_net['scale'], growth_channels=gc,
initial_stride=initial_stride)