Update feature discriminator further

Move the feature/disc losses closer and add a feature computation layer.
This commit is contained in:
James Betker 2020-07-20 20:54:45 -06:00
parent 46aa776fbb
commit 7f7e17e291

View File

@ -259,16 +259,9 @@ class Discriminator_UNet_FeaOut(nn.Module):
self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu) self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False) self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False)
self.fea_proc = ConvGnLelu(nf * 8, nf * 8, bias=True, norm=False, activation=False)
self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False) self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False)
self.up2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu)
self.proc2 = ConvGnLelu(nf * 4, nf * 4, bias=False)
self.collapse2 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
self.up3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu)
self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
self.feature_mode = feature_mode self.feature_mode = feature_mode
def forward(self, x, output_feature_vector=False): def forward(self, x, output_feature_vector=False):
@ -284,33 +277,19 @@ class Discriminator_UNet_FeaOut(nn.Module):
fea3 = self.conv3_0(fea2) fea3 = self.conv3_0(fea2)
fea3 = self.conv3_1(fea3) fea3 = self.conv3_1(fea3)
feat = self.conv4_0(fea3) fea4 = self.conv4_0(fea3)
fea4 = self.conv4_1(feat) fea4 = self.conv4_1(fea4)
# And the pyramid network! # And the pyramid network!
u1 = self.up1(fea4, fea3) u1 = self.up1(fea4, fea3)
loss1 = self.collapse1(self.proc1(u1)) loss1 = self.collapse1(self.proc1(u1))
u2 = self.up2(u1, fea2) fea_out = self.fea_proc(u1)
loss2 = self.collapse2(self.proc2(u2))
u3 = self.up3(u2, fea1)
loss3 = self.collapse3(self.proc3(u3))
res = loss3.shape[2:]
if self.feature_mode: combined_losses = F.interpolate(loss1, scale_factor=4)
combined_losses = F.interpolate(loss1, scale_factor=4)
else:
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will
# then know how to handle them.
combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
F.interpolate(loss2, scale_factor=2),
F.interpolate(loss3, scale_factor=1)], dim=1)
if output_feature_vector: if output_feature_vector:
return combined_losses.view(-1, 1), feat return combined_losses.view(-1, 1), fea_out
else: else:
return combined_losses.view(-1, 1) return combined_losses.view(-1, 1)
def pixgan_parameters(self): def pixgan_parameters(self):
if self.feature_mode: return 1, 4
return 1, 4
else:
return 3, 4