Update feature discriminator further
Move the feature and discriminator losses closer together and add a feature computation layer.
parent 46aa776fbb
commit 7f7e17e291
@@ -259,16 +259,9 @@ class Discriminator_UNet_FeaOut(nn.Module):
 
         self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
         self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False)
+        self.fea_proc = ConvGnLelu(nf * 8, nf * 8, bias=True, norm=False, activation=False)
         self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False)
 
-        self.up2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu)
-        self.proc2 = ConvGnLelu(nf * 4, nf * 4, bias=False)
-        self.collapse2 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
-
-        self.up3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu)
-        self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
-        self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
-
         self.feature_mode = feature_mode
 
     def forward(self, x, output_feature_vector=False):
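The new fea_proc layer is constructed with norm=False and activation=False; assuming ConvGnLelu follows the usual conv, optional GroupNorm, optional LeakyReLU layout, that configuration reduces to a single learned convolution applied to u1. A minimal sketch of the resulting layer, with the kernel size and padding assumed (they are not visible in this diff):

import torch.nn as nn

# Stand-in for ConvGnLelu(nf * 8, nf * 8, bias=True, norm=False, activation=False):
# with norm and activation disabled it amounts to one learned conv over u1.
# The 3x3 kernel and padding=1 are assumptions, not shown in this commit.
class FeaProcSketch(nn.Module):
    def __init__(self, nf=64):
        super().__init__()
        self.conv = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, padding=1, bias=True)

    def forward(self, u1):
        return self.conv(u1)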
@@ -284,33 +277,19 @@ class Discriminator_UNet_FeaOut(nn.Module):
         fea3 = self.conv3_0(fea2)
         fea3 = self.conv3_1(fea3)
 
-        feat = self.conv4_0(fea3)
-        fea4 = self.conv4_1(feat)
+        fea4 = self.conv4_0(fea3)
+        fea4 = self.conv4_1(fea4)
 
         # And the pyramid network!
         u1 = self.up1(fea4, fea3)
         loss1 = self.collapse1(self.proc1(u1))
-        u2 = self.up2(u1, fea2)
-        loss2 = self.collapse2(self.proc2(u2))
-        u3 = self.up3(u2, fea1)
-        loss3 = self.collapse3(self.proc3(u3))
-        res = loss3.shape[2:]
+        fea_out = self.fea_proc(u1)
 
-        if self.feature_mode:
-            combined_losses = F.interpolate(loss1, scale_factor=4)
-        else:
-            # Compress all of the loss values into the batch dimension. The actual loss attached to this output will
-            # then know how to handle them.
-            combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
-                                         F.interpolate(loss2, scale_factor=2),
-                                         F.interpolate(loss3, scale_factor=1)], dim=1)
+        combined_losses = F.interpolate(loss1, scale_factor=4)
         if output_feature_vector:
-            return combined_losses.view(-1, 1), feat
+            return combined_losses.view(-1, 1), fea_out
         else:
             return combined_losses.view(-1, 1)
 
     def pixgan_parameters(self):
-        if self.feature_mode:
-            return 1, 4
-        else:
-            return 3, 4
+        return 1, 4
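With this change, forward() always produces a single flattened loss map from the first pyramid level, and output_feature_vector=True additionally returns the fea_proc feature tensor. A hedged usage sketch follows; the constructor arguments, input shapes, and the L1 feature-matching term are illustrative assumptions, not part of this commit:

import torch
import torch.nn.functional as F

# Discriminator_UNet_FeaOut is assumed to be in scope from this repo; the
# constructor arguments here (in_nc, nf) are guesses for illustration.
disc = Discriminator_UNet_FeaOut(in_nc=3, nf=64)

fake = torch.randn(4, 3, 128, 128)  # stand-in generator output
real = torch.randn(4, 3, 128, 128)  # stand-in ground truth

# Each call returns the flattened per-pixel loss logits plus the feature
# tensor computed by the new fea_proc layer.
fake_logits, fake_fea = disc(fake, output_feature_vector=True)
real_logits, real_fea = disc(real, output_feature_vector=True)

# One plausible way to use the feature output: an L1 feature-matching term
# between real and fake features (assumed usage; not defined in this commit).
fea_loss = F.l1_loss(fake_fea, real_fea.detach())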