Revert "..."

This reverts commit 4b92191880.
James Betker 2020-11-11 17:24:27 -07:00
parent 4b92191880
commit 1c065c41b4


@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 from models.archs.RRDBNet_arch import RRDB, RRDBWithBypass
-from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu
+from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
 import torch.nn.functional as F
 from models.archs.SwitchedResidualGenerator_arch import gather_2d
 from models.archs.pyramid_arch import Pyramid
@@ -666,10 +666,15 @@ class PyramidDiscriminator(nn.Module):
     def __init__(self, in_nc, nf, block=ConvGnLelu):
         super(PyramidDiscriminator, self).__init__()
         self.initial_conv = block(in_nc, nf, kernel_size=3, stride=2, bias=True, norm=False, activation=True)
-        self.top_proc = nn.Sequential(*[ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False, norm=True, activation=True)])
+        self.top_proc = nn.Sequential(*[ResidualBlockGN(nf),
+                                        ResidualBlockGN(nf),
+                                        ResidualBlockGN(nf)])
         self.pyramid = Pyramid(nf, depth=3, processing_convs_per_layer=2, processing_at_point=2,
                                scale_per_level=1.5, norm=True, return_outlevels=False)
-        self.bottom_proc = nn.Sequential(*[
+        self.bottom_proc = nn.Sequential(*[ResidualBlockGN(nf),
+                                           ResidualBlockGN(nf),
+                                           ResidualBlockGN(nf),
+                                           ResidualBlockGN(nf),
                                            ConvGnLelu(nf, nf // 2, kernel_size=1, activation=True, norm=True, bias=True),
                                            ConvGnLelu(nf // 2, nf // 4, kernel_size=1, activation=True, norm=True, bias=True),
                                            ConvGnLelu(nf // 4, 1, activation=False, norm=False, bias=True)])
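
For reference, ResidualBlockGN is imported from models/archs/arch_util.py, but its implementation is not part of this diff. A minimal sketch of what a GroupNorm-based residual block of this kind might look like; the group count, kernel sizes, and activation choice are assumptions, not taken from the repository:

    # Hypothetical sketch -- the real ResidualBlockGN lives in models/archs/arch_util.py
    # and is not shown here; groups, kernel sizes, and activation are assumptions.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ResidualBlockGN(nn.Module):
        # Two 3x3 convolutions with GroupNorm and an identity skip connection.
        # The channel count (nf) is preserved, so blocks can be stacked freely.
        def __init__(self, nf, groups=8):
            super().__init__()
            self.conv1 = nn.Conv2d(nf, nf, kernel_size=3, padding=1, bias=False)
            self.gn1 = nn.GroupNorm(groups, nf)
            self.conv2 = nn.Conv2d(nf, nf, kernel_size=3, padding=1, bias=False)
            self.gn2 = nn.GroupNorm(groups, nf)

        def forward(self, x):
            out = F.leaky_relu(self.gn1(self.conv1(x)), negative_slope=0.2)
            out = self.gn2(self.conv2(out))
            return x + out  # residual (skip) connection

Because blocks of this shape keep the spatial size and channel count unchanged, the reverted top_proc (3 blocks) and bottom_proc (4 blocks) simply deepen the discriminator around the Pyramid stage, before the trailing ConvGnLelu layers reduce the channels from nf down to a single output channel.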