diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index 3de79f0d..7eab9987 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -3,6 +3,7 @@ import torch.nn as nn from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu import torch.nn.functional as F from models.archs.SwitchedResidualGenerator_arch import gather_2d +from utils.util import checkpoint class Discriminator_VGG_128(nn.Module): @@ -79,8 +80,10 @@ class Discriminator_VGG_128(nn.Module): class Discriminator_VGG_128_GN(nn.Module): # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported. - def __init__(self, in_nc, nf, input_img_factor=1): + def __init__(self, in_nc, nf, input_img_factor=1, do_checkpointing=False): super(Discriminator_VGG_128_GN, self).__init__() + self.do_checkpointing = do_checkpointing + # [64, 128, 128] self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) @@ -113,7 +116,7 @@ class Discriminator_VGG_128_GN(nn.Module): self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100) self.linear2 = nn.Linear(100, 1) - def forward(self, x): + def compute_body(self, x): fea = self.lrelu(self.conv0_0(x)) fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) @@ -130,63 +133,13 @@ class Discriminator_VGG_128_GN(nn.Module): fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) - - fea = fea.contiguous().view(fea.size(0), -1) - fea = self.lrelu(self.linear1(fea)) - out = self.linear2(fea) - return out - - -from utils.util import checkpoint -class Discriminator_VGG_128_GN_Checkpointed(nn.Module): - # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported. - def __init__(self, in_nc, nf, input_img_factor=1): - super(Discriminator_VGG_128_GN_Checkpointed, self).__init__() - # [64, 128, 128] - conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) - conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) - bn0_1 = nn.GroupNorm(8, nf, affine=True) - # [64, 64, 64] - conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) - bn1_0 = nn.GroupNorm(8, nf * 2, affine=True) - conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) - bn1_1 = nn.GroupNorm(8, nf * 2, affine=True) - # [128, 32, 32] - conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) - bn2_0 = nn.GroupNorm(8, nf * 4, affine=True) - conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) - bn2_1 = nn.GroupNorm(8, nf * 4, affine=True) - # [256, 16, 16] - conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) - bn3_0 = nn.GroupNorm(8, nf * 8, affine=True) - conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) - bn3_1 = nn.GroupNorm(8, nf * 8, affine=True) - # [512, 8, 8] - conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) - bn4_0 = nn.GroupNorm(8, nf * 8, affine=True) - conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) - bn4_1 = nn.GroupNorm(8, nf * 8, affine=True) - final_nf = nf * 8 - - # activation function - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - self.body = nn.Sequential(conv0_0, self.lrelu, - conv0_1, bn0_1, self.lrelu, - conv1_0, bn1_0, self.lrelu, - conv1_1, bn1_1, self.lrelu, - conv2_0, bn2_0, self.lrelu, - conv2_1, bn2_1, self.lrelu, - conv3_0, bn3_0, self.lrelu, - conv3_1, bn3_1, self.lrelu, - conv4_0, bn4_0, self.lrelu, - conv4_1, bn4_1, self.lrelu) - - self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100) - self.linear2 = nn.Linear(100, 1) + return fea def forward(self, x): - fea = checkpoint(self.body, x) + if self.do_checkpointing: + fea = checkpoint(self.compute_body, x) + else: + fea = self.compute_body(x) fea = fea.contiguous().view(fea.size(0), -1) fea = self.lrelu(self.linear1(fea)) out = self.linear2(fea) diff --git a/codes/models/networks.py b/codes/models/networks.py index 531380f1..50f0851f 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -172,7 +172,7 @@ def define_D_net(opt_net, img_sz=None, wrap=False): if wrap: netD = GradDiscWrapper(netD) elif which_model == 'discriminator_vgg_128_gn_checkpointed': - netD = SRGAN_arch.Discriminator_VGG_128_GN_Checkpointed(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128) + netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, do_checkpointing=True) elif which_model == 'discriminator_resnet': netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) elif which_model == 'discriminator_resnet_50': diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 798264c0..5b2cb41c 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -169,8 +169,10 @@ class GeneratorGanLoss(ConfigurableLoss): if self.detach_real: pred_d_real = pred_d_real.detach() pred_g_fake = netD(*fake) + d_fake_diff = self.criterion(pred_g_fake - torch.mean(pred_d_real), True) + self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff))) loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) + - self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2 + d_fake_diff) / 2 else: raise NotImplementedError if self.min_loss != 0: @@ -234,10 +236,10 @@ class DiscriminatorGanLoss(ConfigurableLoss): if self.min_loss != 0: self.loss_rotating_buffer[self.rb_ptr] = loss.item() self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0] + self.metrics.append(("loss_counter", self.losses_computed)) if torch.mean(self.loss_rotating_buffer) < self.min_loss: return 0 self.losses_computed += 1 - self.metrics.append(("loss_counter", self.losses_computed)) return loss