Add residual blocks to pyramid disc
parent b4136d766a
commit 12b57bbd03

@@ -139,6 +139,30 @@ class ResidualBlock_noBN(nn.Module):
         return identity + out
 
 
+class ResidualBlockGN(nn.Module):
+    '''Residual block with GroupNorm
+    ---Conv-GN-ReLU-Conv-+-
+     |________________|
+    '''
+
+    def __init__(self, nf=64):
+        super(ResidualBlockGN, self).__init__()
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.BN1 = nn.GroupNorm(8, nf)
+        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.BN2 = nn.GroupNorm(8, nf)
+
+        # initialization
+        initialize_weights([self.conv1, self.conv2], 0.1)
+
+    def forward(self, x):
+        identity = x
+        out = self.lrelu(self.BN1(self.conv1(x)))
+        out = self.BN2(self.conv2(out))
+        return identity + out
+
+
 def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
     """Warp an image or feature map with optical flow
     Args:
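
For reference, a minimal stand-alone sketch of the Conv-GN-LReLU-Conv residual pattern added above, using plain PyTorch only (the repo's initialize_weights helper is omitted, and GroupNorm(8, nf) assumes nf is divisible by 8):

# Minimal stand-alone sketch mirroring ResidualBlockGN above.
# Assumptions: nf divisible by 8 so GroupNorm(8, nf) is valid; weight init omitted.
import torch
import torch.nn as nn


class _ResidualBlockGNSketch(nn.Module):
    def __init__(self, nf=64):
        super().__init__()
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.gn1 = nn.GroupNorm(8, nf)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.gn2 = nn.GroupNorm(8, nf)

    def forward(self, x):
        out = self.lrelu(self.gn1(self.conv1(x)))
        out = self.gn2(self.conv2(out))
        return x + out  # skip connection; 3x3/stride-1/pad-1 convs keep the input shape


if __name__ == '__main__':
    x = torch.randn(2, 64, 32, 32)
    y = _ResidualBlockGNSketch(64)(x)
    print(y.shape)  # torch.Size([2, 64, 32, 32]) -- spatial size preserved
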
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 
 from models.archs.RRDBNet_arch import RRDB, RRDBWithBypass
-from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu
+from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
 import torch.nn.functional as F
 from models.archs.SwitchedResidualGenerator_arch import gather_2d
 from models.archs.pyramid_arch import Pyramid

@@ -666,10 +666,15 @@ class PyramidDiscriminator(nn.Module):
     def __init__(self, in_nc, nf, block=ConvGnLelu):
         super(PyramidDiscriminator, self).__init__()
         self.initial_conv = block(in_nc, nf, kernel_size=3, stride=2, bias=True, norm=False, activation=True)
-        self.top_proc = nn.Sequential(*[ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False, norm=True, activation=True)])
+        self.top_proc = nn.Sequential(*[ResidualBlockGN(nf),
+                                        ResidualBlockGN(nf),
+                                        ResidualBlockGN(nf)])
         self.pyramid = Pyramid(nf, depth=3, processing_convs_per_layer=2, processing_at_point=2,
                                scale_per_level=1.5, norm=True, return_outlevels=False)
-        self.bottom_proc = nn.Sequential(*[
+        self.bottom_proc = nn.Sequential(*[ResidualBlockGN(nf),
+                                           ResidualBlockGN(nf),
+                                           ResidualBlockGN(nf),
+                                           ResidualBlockGN(nf),
                                            ConvGnLelu(nf, nf // 2, kernel_size=1, activation=True, norm=True, bias=True),
                                            ConvGnLelu(nf // 2, nf // 4, kernel_size=1, activation=True, norm=True, bias=True),
                                            ConvGnLelu(nf // 4, 1, activation=False, norm=False, bias=True)])
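
Note on behavior: the old top_proc was a single stride-2 ConvGnLelu, so it halved the feature map before the Pyramid; the three ResidualBlockGN modules that replace it are shape-preserving, so top_proc no longer downsamples, and bottom_proc simply gains four of the same shape-preserving blocks ahead of the existing 1x1 reductions. A shape-only sketch of that difference, using plain PyTorch stand-ins and assuming ConvGnLelu with stride=2 downsamples like an ordinary strided conv:

# Shape-only sketch of the top_proc change (stand-ins only, not the repo modules).
import torch
import torch.nn as nn

nf = 64
x = torch.randn(1, nf, 64, 64)

# Stand-in for the old top_proc: one stride-2 conv -> halves H and W.
old_top_proc = nn.Conv2d(nf, nf, 3, stride=2, padding=1)
print(old_top_proc(x).shape)   # torch.Size([1, 64, 32, 32])

# Stand-in for the new top_proc: three 3x3 / stride-1 / pad-1 convs in place of
# the three ResidualBlockGN modules -> H and W unchanged.
new_top_proc = nn.Sequential(*[nn.Conv2d(nf, nf, 3, stride=1, padding=1) for _ in range(3)])
print(new_top_proc(x).shape)   # torch.Size([1, 64, 64, 64])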