forked from mrq/DL-Art-School
err3
This commit is contained in:
parent
b3a2c21250
commit
020b3361fa
|
@ -368,12 +368,14 @@ class ConvGnSilu(nn.Module):
|
|||
# Block that upsamples 2x and reduces incoming filters by 2x. It preserves structure by taking a passthrough feed
|
||||
# along with the feature representation.
|
||||
class ExpansionBlock(nn.Module):
|
||||
def __init__(self, filters, block=ConvGnSilu):
|
||||
def __init__(self, filters_in, filters_out=None, block=ConvGnSilu):
|
||||
super(ExpansionBlock, self).__init__()
|
||||
self.decimate = block(filters, filters // 2, kernel_size=1, bias=False, activation=False, norm=True)
|
||||
self.process_passthrough = block(filters // 2, filters // 2, kernel_size=3, bias=True, activation=False, norm=True)
|
||||
self.conjoin = block(filters, filters // 2, kernel_size=3, bias=False, activation=True, norm=False)
|
||||
self.process = block(filters // 2, filters // 2, kernel_size=3, bias=False, activation=True, norm=True)
|
||||
if filters_out is None:
|
||||
filters_out = filters_in // 2
|
||||
self.decimate = block(filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True)
|
||||
self.process_passthrough = block(filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True)
|
||||
self.conjoin = block(filters_in, filters_out, kernel_size=3, bias=False, activation=True, norm=False)
|
||||
self.process = block(filters_out, filters_out, kernel_size=3, bias=False, activation=True, norm=True)
|
||||
|
||||
# input is the feature signal with shape (b, f, w, h)
|
||||
# passthrough is the structure signal with shape (b, f/2, w*2, h*2)
|
||||
|
|
|
@ -192,17 +192,17 @@ class Discriminator_UNet(nn.Module):
|
|||
self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
|
||||
self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
|
||||
|
||||
self.up1 = ExpansionBlock(nf * 8, block=ConvGnLelu)
|
||||
self.proc1 = ConvGnLelu(nf * 4, nf * 4, bias=False)
|
||||
self.collapse1 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
|
||||
self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
|
||||
self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False)
|
||||
self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False)
|
||||
|
||||
self.up2 = ExpansionBlock(nf * 4, block=ConvGnLelu)
|
||||
self.proc2 = ConvGnLelu(nf * 2, nf * 2, bias=False)
|
||||
self.collapse2 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
|
||||
self.up2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu)
|
||||
self.proc2 = ConvGnLelu(nf * 4, nf * 4, bias=False)
|
||||
self.collapse2 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
|
||||
|
||||
self.up3 = ExpansionBlock(nf * 2, block=ConvGnLelu)
|
||||
self.proc3 = ConvGnLelu(nf, nf, bias=False)
|
||||
self.collapse3 = ConvGnLelu(nf, 1, bias=True, norm=False, activation=False)
|
||||
self.up3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu)
|
||||
self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
|
||||
self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
|
||||
|
||||
def forward(self, x, flatten=True):
|
||||
x = x[0]
|
||||
|
|
Loading…
Reference in New Issue
Block a user