Add injected noise into bypass maps
This commit is contained in:
parent
04961b91cf
commit
de10c7246a
|
@ -1,4 +1,6 @@
|
|||
import functools
|
||||
import os
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -7,7 +9,7 @@ import torchvision
|
|||
|
||||
from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, sequential_checkpoint
|
||||
from utils.util import checkpoint, sequential_checkpoint, opt_get
|
||||
|
||||
|
||||
class ResidualDenseBlock(nn.Module):
|
||||
|
@ -106,7 +108,7 @@ class RRDBWithBypass(nn.Module):
|
|||
growth_channels (int): Channels for each growth.
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
|
||||
def __init__(self, mid_channels, growth_channels=32, reduce_to=None, randomly_add_noise_to_bypass=False):
|
||||
super(RRDBWithBypass, self).__init__()
|
||||
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
|
@ -122,6 +124,7 @@ class RRDBWithBypass(nn.Module):
|
|||
ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False),
|
||||
ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
|
||||
nn.Sigmoid())
|
||||
self.randomly_add_bypass_noise = randomly_add_noise_to_bypass
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function.
|
||||
|
@ -142,6 +145,11 @@ class RRDBWithBypass(nn.Module):
|
|||
out = torch.cat([out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1)
|
||||
|
||||
bypass = self.bypass(torch.cat([x, out], dim=1))
|
||||
# The purpose of the random noise is to induce usage of bypass maps that would otherwise be "dead". Theoretically,
|
||||
# if these maps provide value, the noise should trigger gradients to flow into the bypass conv network again.
|
||||
if self.randomly_add_bypass_noise and random.random() < .2:
|
||||
rnoise = torch.rand_like(bypass) * .02
|
||||
bypass = (bypass + rnoise).clamp(0, 1)
|
||||
self.bypass_map = bypass.detach().clone()
|
||||
|
||||
# Empirically, we use 0.2 to scale the residual for better performance
|
||||
|
@ -379,9 +387,11 @@ def register_RRDBNetBypass(opt_net, opt):
|
|||
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
||||
gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32
|
||||
initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1
|
||||
bypass_noise = opt_get(opt_net, ['bypass_noise'], False)
|
||||
block = functools.partial(RRDBWithBypass, randomly_add_noise_to_bypass=bypass_noise)
|
||||
return RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
|
||||
output_mode=output_mode, body_block=RRDBWithBypass, scale=opt_net['scale'], growth_channels=gc,
|
||||
output_mode=output_mode, body_block=block, scale=opt_net['scale'], growth_channels=gc,
|
||||
initial_stride=initial_stride)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user