Update feature discriminator further
Move the feature/disc losses closer and add a feature computation layer.
This commit is contained in:
parent
46aa776fbb
commit
7f7e17e291
|
@ -259,16 +259,9 @@ class Discriminator_UNet_FeaOut(nn.Module):
|
|||
|
||||
self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
|
||||
self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False)
|
||||
self.fea_proc = ConvGnLelu(nf * 8, nf * 8, bias=True, norm=False, activation=False)
|
||||
self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False)
|
||||
|
||||
self.up2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu)
|
||||
self.proc2 = ConvGnLelu(nf * 4, nf * 4, bias=False)
|
||||
self.collapse2 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
|
||||
|
||||
self.up3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu)
|
||||
self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
|
||||
self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
|
||||
|
||||
self.feature_mode = feature_mode
|
||||
|
||||
def forward(self, x, output_feature_vector=False):
|
||||
|
@ -284,33 +277,19 @@ class Discriminator_UNet_FeaOut(nn.Module):
|
|||
fea3 = self.conv3_0(fea2)
|
||||
fea3 = self.conv3_1(fea3)
|
||||
|
||||
feat = self.conv4_0(fea3)
|
||||
fea4 = self.conv4_1(feat)
|
||||
fea4 = self.conv4_0(fea3)
|
||||
fea4 = self.conv4_1(fea4)
|
||||
|
||||
# And the pyramid network!
|
||||
u1 = self.up1(fea4, fea3)
|
||||
loss1 = self.collapse1(self.proc1(u1))
|
||||
u2 = self.up2(u1, fea2)
|
||||
loss2 = self.collapse2(self.proc2(u2))
|
||||
u3 = self.up3(u2, fea1)
|
||||
loss3 = self.collapse3(self.proc3(u3))
|
||||
res = loss3.shape[2:]
|
||||
fea_out = self.fea_proc(u1)
|
||||
|
||||
if self.feature_mode:
|
||||
combined_losses = F.interpolate(loss1, scale_factor=4)
|
||||
else:
|
||||
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will
|
||||
# then know how to handle them.
|
||||
combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
|
||||
F.interpolate(loss2, scale_factor=2),
|
||||
F.interpolate(loss3, scale_factor=1)], dim=1)
|
||||
combined_losses = F.interpolate(loss1, scale_factor=4)
|
||||
if output_feature_vector:
|
||||
return combined_losses.view(-1, 1), feat
|
||||
return combined_losses.view(-1, 1), fea_out
|
||||
else:
|
||||
return combined_losses.view(-1, 1)
|
||||
|
||||
def pixgan_parameters(self):
|
||||
if self.feature_mode:
|
||||
return 1, 4
|
||||
else:
|
||||
return 3, 4
|
||||
return 1, 4
|
Loading…
Reference in New Issue
Block a user