forked from mrq/DL-Art-School
Add injected noise into bypass maps
This commit is contained in:
parent
04961b91cf
commit
de10c7246a
|
@ -1,4 +1,6 @@
|
||||||
|
import functools
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -7,7 +9,7 @@ import torchvision
|
||||||
|
|
||||||
from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
|
from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import checkpoint, sequential_checkpoint
|
from utils.util import checkpoint, sequential_checkpoint, opt_get
|
||||||
|
|
||||||
|
|
||||||
class ResidualDenseBlock(nn.Module):
|
class ResidualDenseBlock(nn.Module):
|
||||||
|
@ -106,7 +108,7 @@ class RRDBWithBypass(nn.Module):
|
||||||
growth_channels (int): Channels for each growth.
|
growth_channels (int): Channels for each growth.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
|
def __init__(self, mid_channels, growth_channels=32, reduce_to=None, randomly_add_noise_to_bypass=False):
|
||||||
super(RRDBWithBypass, self).__init__()
|
super(RRDBWithBypass, self).__init__()
|
||||||
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
|
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||||
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
|
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||||
|
@ -122,6 +124,7 @@ class RRDBWithBypass(nn.Module):
|
||||||
ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False),
|
ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False),
|
||||||
ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
|
ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
|
||||||
nn.Sigmoid())
|
nn.Sigmoid())
|
||||||
|
self.randomly_add_bypass_noise = randomly_add_noise_to_bypass
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""Forward function.
|
"""Forward function.
|
||||||
|
@ -142,6 +145,11 @@ class RRDBWithBypass(nn.Module):
|
||||||
out = torch.cat([out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1)
|
out = torch.cat([out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1)
|
||||||
|
|
||||||
bypass = self.bypass(torch.cat([x, out], dim=1))
|
bypass = self.bypass(torch.cat([x, out], dim=1))
|
||||||
|
# The purpose of random noise is to induce usage of bypass maps that would otherwise be "dead". Theoretically
|
||||||
|
# if these maps provide value, the noise should trigger gradients to flow into the bypass conv network again.
|
||||||
|
if self.randomly_add_bypass_noise and random.random() < .2:
|
||||||
|
rnoise = torch.rand_like(bypass) * .02
|
||||||
|
bypass = (bypass + rnoise).clamp(0, 1)
|
||||||
self.bypass_map = bypass.detach().clone()
|
self.bypass_map = bypass.detach().clone()
|
||||||
|
|
||||||
# Empirically, we use 0.2 to scale the residual for better performance
|
# Empirically, we use 0.2 to scale the residual for better performance
|
||||||
|
@ -379,9 +387,11 @@ def register_RRDBNetBypass(opt_net, opt):
|
||||||
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
||||||
gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32
|
gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32
|
||||||
initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1
|
initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1
|
||||||
|
bypass_noise = opt_get(opt_net, ['bypass_noise'], False)
|
||||||
|
block = functools.partial(RRDBWithBypass, randomly_add_noise_to_bypass=bypass_noise)
|
||||||
return RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
return RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
|
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
|
||||||
output_mode=output_mode, body_block=RRDBWithBypass, scale=opt_net['scale'], growth_channels=gc,
|
output_mode=output_mode, body_block=block, scale=opt_net['scale'], growth_channels=gc,
|
||||||
initial_stride=initial_stride)
|
initial_stride=initial_stride)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user