diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index b1abc9d0..5dc8df23 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -259,16 +259,9 @@ class Discriminator_UNet_FeaOut(nn.Module): self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu) self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False) + self.fea_proc = ConvGnLelu(nf * 8, nf * 8, bias=True, norm=False, activation=False) self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False) - self.up2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu) - self.proc2 = ConvGnLelu(nf * 4, nf * 4, bias=False) - self.collapse2 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False) - - self.up3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu) - self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False) - self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False) - self.feature_mode = feature_mode def forward(self, x, output_feature_vector=False): @@ -284,33 +277,19 @@ class Discriminator_UNet_FeaOut(nn.Module): fea3 = self.conv3_0(fea2) fea3 = self.conv3_1(fea3) - feat = self.conv4_0(fea3) - fea4 = self.conv4_1(feat) + fea4 = self.conv4_0(fea3) + fea4 = self.conv4_1(fea4) # And the pyramid network! u1 = self.up1(fea4, fea3) loss1 = self.collapse1(self.proc1(u1)) - u2 = self.up2(u1, fea2) - loss2 = self.collapse2(self.proc2(u2)) - u3 = self.up3(u2, fea1) - loss3 = self.collapse3(self.proc3(u3)) - res = loss3.shape[2:] + fea_out = self.fea_proc(u1) - if self.feature_mode: - combined_losses = F.interpolate(loss1, scale_factor=4) - else: - # Compress all of the loss values into the batch dimension. The actual loss attached to this output will - # then know how to handle them. - combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4), - F.interpolate(loss2, scale_factor=2), - F.interpolate(loss3, scale_factor=1)], dim=1) + combined_losses = F.interpolate(loss1, scale_factor=4) if output_feature_vector: - return combined_losses.view(-1, 1), feat + return combined_losses.view(-1, 1), fea_out else: return combined_losses.view(-1, 1) def pixgan_parameters(self): - if self.feature_mode: - return 1, 4 - else: - return 3, 4 \ No newline at end of file + return 1, 4 \ No newline at end of file