From 193cdc6636b4be1adc200451c482045dc6ce5eab Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 1 Jan 2021 15:56:09 -0700 Subject: [PATCH] Move discriminators to the create_model paradigm Also cleans up a lot of old discriminator models that I have no intention of using again. --- codes/data/byol_attachment.py | 8 +- codes/models/discriminator_vgg_arch.py | 526 +----------------- .../models/stylegan/Discriminator_StyleGAN.py | 10 +- codes/models/stylegan/stylegan2_lucidrains.py | 30 +- codes/requirements.txt | 2 +- codes/scripts/test_dataloader.py | 104 ---- codes/train.py | 2 +- codes/trainer/ExtensibleTrainer.py | 4 +- codes/trainer/networks.py | 99 +--- codes/utils/distill_torchscript.py | 177 ------ 10 files changed, 54 insertions(+), 908 deletions(-) delete mode 100644 codes/scripts/test_dataloader.py delete mode 100644 codes/utils/distill_torchscript.py diff --git a/codes/data/byol_attachment.py b/codes/data/byol_attachment.py index 5574f327..5303cdce 100644 --- a/codes/data/byol_attachment.py +++ b/codes/data/byol_attachment.py @@ -37,13 +37,15 @@ class ByolDatasetWrapper(Dataset): self.cropped_img_size = opt['crop_size'] self.key1 = opt_get(opt, ['key1'], 'hq') self.key2 = opt_get(opt, ['key2'], 'lq') + for_sr = opt_get(opt, ['for_sr'], False) # When set, color alterations and blurs are disabled. augmentations = [ \ - RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), - augs.RandomGrayscale(p=0.2), augs.RandomHorizontalFlip(), - RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1), augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))] + if not for_sr: + augmentations.extend([RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), + augs.RandomGrayscale(p=0.2), + RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]) if opt['normalize']: # The paper calls for normalization. Most datasets/models in this repo don't use this. # Recommend setting true if you want to train exactly like the paper. diff --git a/codes/models/discriminator_vgg_arch.py b/codes/models/discriminator_vgg_arch.py index 1f98bd0e..ebae1284 100644 --- a/codes/models/discriminator_vgg_arch.py +++ b/codes/models/discriminator_vgg_arch.py @@ -3,7 +3,9 @@ import torch.nn as nn from models.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN import torch.nn.functional as F -from utils.util import checkpoint + +from trainer.networks import register_model +from utils.util import checkpoint, opt_get class Discriminator_VGG_128(nn.Module): @@ -79,6 +81,12 @@ class Discriminator_VGG_128(nn.Module): return out +@register_model +def register_discriminator_vgg_128(opt_net, opt): + return Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=opt_net['image_size'] / 128, + extra_conv=opt_net['extra_conv']) + + class Discriminator_VGG_128_GN(nn.Module): # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported. 
def __init__(self, in_nc, nf, input_img_factor=1, do_checkpointing=False, extra_conv=False): @@ -159,514 +167,8 @@ class Discriminator_VGG_128_GN(nn.Module): return out -class CrossCompareBlock(nn.Module): - def __init__(self, nf_in, nf_out): - super(CrossCompareBlock, self).__init__() - self.conv_hr_merge = ConvGnLelu(nf_in * 2, nf_in, kernel_size=1, bias=False, activation=False, norm=True) - self.proc_hr = ConvGnLelu(nf_in, nf_out, kernel_size=3, bias=False, activation=True, norm=True) - self.proc_lr = ConvGnLelu(nf_in, nf_out, kernel_size=3, bias=False, activation=True, norm=True) - self.reduce_hr = ConvGnLelu(nf_out, nf_out, kernel_size=3, stride=2, bias=False, activation=True, norm=True) - self.reduce_lr = ConvGnLelu(nf_out, nf_out, kernel_size=3, stride=2, bias=False, activation=True, norm=True) - - def forward(self, hr, lr): - hr = self.conv_hr_merge(torch.cat([hr, lr], dim=1)) - hr = self.proc_hr(hr) - hr = self.reduce_hr(hr) - - lr = self.proc_lr(lr) - lr = self.reduce_lr(lr) - - return hr, lr - - -class CrossCompareDiscriminator(nn.Module): - def __init__(self, in_nc, ref_channels, nf, scale=4): - super(CrossCompareDiscriminator, self).__init__() - assert scale == 2 or scale == 4 - - self.init_conv_hr = ConvGnLelu(in_nc, nf, stride=2, norm=False, bias=True, activation=True) - self.init_conv_lr = ConvGnLelu(ref_channels, nf, stride=1, norm=False, bias=True, activation=True) - if scale == 4: - strd_2 = 2 - else: - strd_2 = 1 - self.second_conv = ConvGnLelu(nf, nf, stride=strd_2, norm=True, bias=False, activation=True) - - self.cross1 = CrossCompareBlock(nf, nf * 2) - self.cross2 = CrossCompareBlock(nf * 2, nf * 4) - self.cross3 = CrossCompareBlock(nf * 4, nf * 8) - self.cross4 = CrossCompareBlock(nf * 8, nf * 8) - self.fproc_conv = ConvGnLelu(nf * 8, nf, norm=True, bias=True, activation=True) - self.out_conv = ConvGnLelu(nf, 1, norm=False, bias=False, activation=False) - - self.scale = scale * 16 - - def forward(self, hr, lr): - hr = self.init_conv_hr(hr) - hr = self.second_conv(hr) - lr = self.init_conv_lr(lr) - - hr, lr = self.cross1(hr, lr) - hr, lr = self.cross2(hr, lr) - hr, lr = self.cross3(hr, lr) - hr, _ = self.cross4(hr, lr) - - return self.out_conv(self.fproc_conv(hr)).view(-1, 1) - - # Returns tuple of (number_output_channels, scale_of_output_reduction (1/n)) - def pixgan_parameters(self): - return 3, self.scale - - -class Discriminator_VGG_PixLoss(nn.Module): - def __init__(self, in_nc, nf): - super(Discriminator_VGG_PixLoss, self).__init__() - # [64, 128, 128] - self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) - self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) - self.bn0_1 = nn.GroupNorm(8, nf, affine=True) - # [64, 64, 64] - self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) - self.bn1_0 = nn.GroupNorm(8, nf * 2, affine=True) - self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) - self.bn1_1 = nn.GroupNorm(8, nf * 2, affine=True) - # [128, 32, 32] - self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) - self.bn2_0 = nn.GroupNorm(8, nf * 4, affine=True) - self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) - self.bn2_1 = nn.GroupNorm(8, nf * 4, affine=True) - # [256, 16, 16] - self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) - self.bn3_0 = nn.GroupNorm(8, nf * 8, affine=True) - self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) - self.bn3_1 = nn.GroupNorm(8, nf * 8, affine=True) - # [512, 8, 8] - self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) - self.bn4_0 = nn.GroupNorm(8, nf * 8, 
affine=True) - self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) - self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True) - - self.reduce_1 = ConvGnLelu(nf * 8, nf * 4, bias=False) - self.pix_loss_collapse = ConvGnLelu(nf * 4, 1, bias=False, norm=False, activation=False) - - # Pyramid network: upsample with residuals and produce losses at multiple resolutions. - self.up3_decimate = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=True, activation=False) - self.up3_converge = ConvGnLelu(nf * 16, nf * 8, kernel_size=3, bias=False) - self.up3_proc = ConvGnLelu(nf * 8, nf * 8, bias=False) - self.up3_reduce = ConvGnLelu(nf * 8, nf * 4, bias=False) - self.up3_pix = ConvGnLelu(nf * 4, 1, bias=False, norm=False, activation=False) - - self.up2_decimate = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, bias=True, activation=False) - self.up2_converge = ConvGnLelu(nf * 8, nf * 4, kernel_size=3, bias=False) - self.up2_proc = ConvGnLelu(nf * 4, nf * 4, bias=False) - self.up2_reduce = ConvGnLelu(nf * 4, nf * 2, bias=False) - self.up2_pix = ConvGnLelu(nf * 2, 1, bias=False, norm=False, activation=False) - - # activation function - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - def forward(self, x, flatten=True): - fea0 = self.lrelu(self.conv0_0(x)) - fea0 = self.lrelu(self.bn0_1(self.conv0_1(fea0))) - - fea1 = self.lrelu(self.bn1_0(self.conv1_0(fea0))) - fea1 = self.lrelu(self.bn1_1(self.conv1_1(fea1))) - - fea2 = self.lrelu(self.bn2_0(self.conv2_0(fea1))) - fea2 = self.lrelu(self.bn2_1(self.conv2_1(fea2))) - - fea3 = self.lrelu(self.bn3_0(self.conv3_0(fea2))) - fea3 = self.lrelu(self.bn3_1(self.conv3_1(fea3))) - - fea4 = self.lrelu(self.bn4_0(self.conv4_0(fea3))) - fea4 = self.lrelu(self.bn4_1(self.conv4_1(fea4))) - - loss = self.reduce_1(fea4) - # "Weight" all losses the same by interpolating them to the highest dimension. - loss = self.pix_loss_collapse(loss) - loss = F.interpolate(loss, scale_factor=4, mode="nearest") - - # And the pyramid network! - dec3 = self.up3_decimate(F.interpolate(fea4, scale_factor=2, mode="nearest")) - dec3 = torch.cat([dec3, fea3], dim=1) - dec3 = self.up3_converge(dec3) - dec3 = self.up3_proc(dec3) - loss3 = self.up3_reduce(dec3) - loss3 = self.up3_pix(loss3) - loss3 = F.interpolate(loss3, scale_factor=2, mode="nearest") - - dec2 = self.up2_decimate(F.interpolate(dec3, scale_factor=2, mode="nearest")) - dec2 = torch.cat([dec2, fea2], dim=1) - dec2 = self.up2_converge(dec2) - dec2 = self.up2_proc(dec2) - dec2 = self.up2_reduce(dec2) - loss2 = self.up2_pix(dec2) - - # Compress all of the loss values into the batch dimension. The actual loss attached to this output will - # then know how to handle them. 
- combined_losses = torch.cat([loss, loss3, loss2], dim=1) - return combined_losses.view(-1, 1) - - def pixgan_parameters(self): - return 3, 8 - - -class Discriminator_UNet(nn.Module): - def __init__(self, in_nc, nf): - super(Discriminator_UNet, self).__init__() - # [64, 128, 128] - self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False) - self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False) - # [64, 64, 64] - self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False) - self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False) - # [128, 32, 32] - self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False) - self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False) - # [256, 16, 16] - self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False) - self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False) - # [512, 8, 8] - self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False) - self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False) - - self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu) - self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False) - self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False) - - self.up2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu) - self.proc2 = ConvGnLelu(nf * 4, nf * 4, bias=False) - self.collapse2 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False) - - self.up3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu) - self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False) - self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False) - - def forward(self, x, flatten=True): - fea0 = self.conv0_0(x) - fea0 = self.conv0_1(fea0) - - fea1 = self.conv1_0(fea0) - fea1 = self.conv1_1(fea1) - - fea2 = self.conv2_0(fea1) - fea2 = self.conv2_1(fea2) - - fea3 = self.conv3_0(fea2) - fea3 = self.conv3_1(fea3) - - fea4 = self.conv4_0(fea3) - fea4 = self.conv4_1(fea4) - - # And the pyramid network! - u1 = self.up1(fea4, fea3) - loss1 = self.collapse1(self.proc1(u1)) - u2 = self.up2(u1, fea2) - loss2 = self.collapse2(self.proc2(u2)) - u3 = self.up3(u2, fea1) - loss3 = self.collapse3(self.proc3(u3)) - res = loss3.shape[2:] - - # Compress all of the loss values into the batch dimension. The actual loss attached to this output will - # then know how to handle them. 
- combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4), - F.interpolate(loss2, scale_factor=2), - F.interpolate(loss3, scale_factor=1)], dim=1) - return combined_losses.view(-1, 1) - - def pixgan_parameters(self): - return 3, 4 - - -class Discriminator_UNet_FeaOut(nn.Module): - def __init__(self, in_nc, nf, feature_mode=False): - super(Discriminator_UNet_FeaOut, self).__init__() - # [64, 128, 128] - self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False) - self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False) - # [64, 64, 64] - self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False) - self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False) - # [128, 32, 32] - self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False) - self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False) - # [256, 16, 16] - self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False) - self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False) - # [512, 8, 8] - self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False) - self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False) - - self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu) - self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False) - self.fea_proc = ConvGnLelu(nf * 8, nf * 8, bias=True, norm=False, activation=False) - self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False) - - self.feature_mode = feature_mode - - def forward(self, x, output_feature_vector=False): - fea0 = self.conv0_0(x) - fea0 = self.conv0_1(fea0) - - fea1 = self.conv1_0(fea0) - fea1 = self.conv1_1(fea1) - - fea2 = self.conv2_0(fea1) - fea2 = self.conv2_1(fea2) - - fea3 = self.conv3_0(fea2) - fea3 = self.conv3_1(fea3) - - fea4 = self.conv4_0(fea3) - fea4 = self.conv4_1(fea4) - - # And the pyramid network! - u1 = self.up1(fea4, fea3) - loss1 = self.collapse1(self.proc1(u1)) - fea_out = self.fea_proc(u1) - - combined_losses = F.interpolate(loss1, scale_factor=4) - if output_feature_vector: - return combined_losses.view(-1, 1), fea_out - else: - return combined_losses.view(-1, 1) - - def pixgan_parameters(self): - return 1, 4 - - -class Vgg128GnHead(nn.Module): - def __init__(self, in_nc, nf, depth=5): - super(Vgg128GnHead, self).__init__() - assert depth == 4 or depth == 5 # Nothing stopping others from being implemented, just not done yet. 
- self.depth = depth - - # [64, 128, 128] - self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) - self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) - self.bn0_1 = nn.GroupNorm(8, nf, affine=True) - # [64, 64, 64] - self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) - self.bn1_0 = nn.GroupNorm(8, nf * 2, affine=True) - self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) - self.bn1_1 = nn.GroupNorm(8, nf * 2, affine=True) - # [128, 32, 32] - self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) - self.bn2_0 = nn.GroupNorm(8, nf * 4, affine=True) - self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) - self.bn2_1 = nn.GroupNorm(8, nf * 4, affine=True) - # [256, 16, 16] - self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) - self.bn3_0 = nn.GroupNorm(8, nf * 8, affine=True) - self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) - self.bn3_1 = nn.GroupNorm(8, nf * 8, affine=True) - if depth > 4: - # [512, 8, 8] - self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) - self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True) - self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) - self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True) - - # activation function - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - def forward(self, x): - fea = self.lrelu(self.conv0_0(x)) - fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) - - fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) - fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) - - fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) - fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) - - fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) - fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) - - if self.depth > 4: - fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) - fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) - return fea - - -class RefDiscriminatorVgg128(nn.Module): - # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported. - def __init__(self, in_nc, nf, input_img_factor=1): - super(RefDiscriminatorVgg128, self).__init__() - - # activation function - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - self.feature_head = Vgg128GnHead(in_nc, nf) - self.ref_head = Vgg128GnHead(in_nc+1, nf, depth=4) - final_nf = nf * 8 - - self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 512) - self.ref_linear = nn.Linear(nf * 8, 128) - - self.output_linears = nn.Sequential( - nn.Linear(128+512, 512), - self.lrelu, - nn.Linear(512, 256), - self.lrelu, - nn.Linear(256, 128), - self.lrelu, - nn.Linear(128, 1) - ) - - def forward(self, x, ref, ref_center_point): - ref = self.ref_head(ref) - ref_center_point = ref_center_point // 16 - from models.SwitchedResidualGenerator_arch import gather_2d - ref_vector = gather_2d(ref, ref_center_point) - ref_vector = self.ref_linear(ref_vector) - - fea = self.feature_head(x) - fea = fea.contiguous().view(fea.size(0), -1) - fea = self.lrelu(self.linear1(fea)) - - out = self.output_linears(torch.cat([fea, ref_vector], dim=1)) - return out - - -class PsnrApproximator(nn.Module): - # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported. 
- def __init__(self, nf, input_img_factor=1): - super(PsnrApproximator, self).__init__() - - # [64, 128, 128] - self.fake_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True) - self.fake_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) - self.fake_bn0_1 = nn.BatchNorm2d(nf, affine=True) - # [64, 64, 64] - self.fake_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) - self.fake_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) - self.fake_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) - self.fake_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) - # [128, 32, 32] - self.fake_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) - self.fake_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) - self.fake_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) - self.fake_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) - - # [64, 128, 128] - self.real_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True) - self.real_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) - self.real_bn0_1 = nn.BatchNorm2d(nf, affine=True) - # [64, 64, 64] - self.real_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) - self.real_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) - self.real_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) - self.real_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) - # [128, 32, 32] - self.real_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) - self.real_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) - self.real_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) - self.real_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) - - # [512, 16, 16] - self.conv3_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) - self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) - self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) - self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) - # [512, 8, 8] - self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) - self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) - self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) - self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) - final_nf = nf * 8 - - # activation function - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 1024) - self.linear2 = nn.Linear(1024, 512) - self.linear3 = nn.Linear(512, 128) - self.linear4 = nn.Linear(128, 1) - - def compute_body1(self, real): - fea = self.lrelu(self.real_conv0_0(real)) - fea = self.lrelu(self.real_bn0_1(self.real_conv0_1(fea))) - fea = self.lrelu(self.real_bn1_0(self.real_conv1_0(fea))) - fea = self.lrelu(self.real_bn1_1(self.real_conv1_1(fea))) - fea = self.lrelu(self.real_bn2_0(self.real_conv2_0(fea))) - fea = self.lrelu(self.real_bn2_1(self.real_conv2_1(fea))) - return fea - - def compute_body2(self, fake): - fea = self.lrelu(self.fake_conv0_0(fake)) - fea = self.lrelu(self.fake_bn0_1(self.fake_conv0_1(fea))) - fea = self.lrelu(self.fake_bn1_0(self.fake_conv1_0(fea))) - fea = self.lrelu(self.fake_bn1_1(self.fake_conv1_1(fea))) - fea = self.lrelu(self.fake_bn2_0(self.fake_conv2_0(fea))) - fea = self.lrelu(self.fake_bn2_1(self.fake_conv2_1(fea))) - return fea - - def forward(self, real, fake): - real_fea = checkpoint(self.compute_body1, real) - fake_fea = checkpoint(self.compute_body2, fake) - fea = torch.cat([real_fea, fake_fea], dim=1) - - fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) - fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) - fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) - fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) - - fea = 
fea.contiguous().view(fea.size(0), -1)
-        fea = self.lrelu(self.linear1(fea))
-        fea = self.lrelu(self.linear2(fea))
-        fea = self.lrelu(self.linear3(fea))
-        out = self.linear4(fea)
-        return out.squeeze()
-
-
-class SingleImageQualityEstimator(nn.Module):
-    # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
-    def __init__(self, nf, input_img_factor=1):
-        super(SingleImageQualityEstimator, self).__init__()
-
-        # [64, 128, 128]
-        self.fake_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
-        self.fake_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
-        self.fake_bn0_1 = nn.BatchNorm2d(nf, affine=True)
-        # [64, 64, 64]
-        self.fake_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
-        self.fake_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
-        self.fake_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
-        self.fake_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
-        # [128, 32, 32]
-        self.fake_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
-        self.fake_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
-        self.fake_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
-        self.fake_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
-
-        # [512, 16, 16]
-        self.conv3_0 = nn.Conv2d(nf * 4, nf * 4, 3, 1, 1, bias=False)
-        self.bn3_0 = nn.BatchNorm2d(nf * 4, affine=True)
-        self.conv3_1 = nn.Conv2d(nf * 4, nf * 8, 4, 2, 1, bias=False)
-        self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
-        # [512, 8, 8]
-        self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=True)
-        self.conv4_1 = nn.Conv2d(nf * 8, nf * 2, 3, 1, 1, bias=True)
-        self.conv4_2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
-        self.conv4_3 = nn.Conv2d(nf, 3, 3, 1, 1, bias=True)
-        self.sigmoid = nn.Sigmoid()
-        self.lrelu = nn.LeakyReLU(negative_slope=.2, inplace=True)
-
-    def compute_body(self, fake):
-        fea = self.lrelu(self.fake_conv0_0(fake))
-        fea = self.lrelu(self.fake_bn0_1(self.fake_conv0_1(fea)))
-        fea = self.lrelu(self.fake_bn1_0(self.fake_conv1_0(fea)))
-        fea = self.lrelu(self.fake_bn1_1(self.fake_conv1_1(fea)))
-        fea = self.lrelu(self.fake_bn2_0(self.fake_conv2_0(fea)))
-        fea = self.lrelu(self.fake_bn2_1(self.fake_conv2_1(fea)))
-        return fea
-
-    def forward(self, fake):
-        fea = checkpoint(self.compute_body, fake)
-        fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
-        fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
-        fea = self.lrelu(self.conv4_0(fea))
-        fea = self.lrelu(self.conv4_1(fea))
-        fea = self.lrelu(self.conv4_2(fea))
-        fea = self.sigmoid(self.conv4_3(fea))
-        return fea
+@register_model
+def register_discriminator_vgg_128_gn(opt_net, opt):
+    return Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=opt_net['image_size'] / 128,
+                                    extra_conv=opt_get(opt_net, ['extra_conv'], False),
+                                    do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False))
diff --git a/codes/models/stylegan/Discriminator_StyleGAN.py b/codes/models/stylegan/Discriminator_StyleGAN.py
index cab5b278..5722ecd8 100644
--- a/codes/models/stylegan/Discriminator_StyleGAN.py
+++ b/codes/models/stylegan/Discriminator_StyleGAN.py
@@ -5,6 +5,9 @@ from torch import nn
 import torch.nn.functional as F
 import numpy as np
 
+from trainer.networks import register_model
+from utils.util import opt_get
+
 
 class BlurLayer(nn.Module):
     def __init__(self, kernel=None, normalize=True, flip=False, stride=1):
@@ -372,4 +375,9 @@ class StyleGanDiscriminator(nn.Module):
         else:
             raise KeyError("Unknown structure: ", self.structure)
 
-        return scores_out
\ No newline at end of file
+        return scores_out
+
+
+@register_model
+def 
register_stylegan_vgg(opt_net, opt): + return StyleGanDiscriminator(opt_get(opt_net, ['image_size'], 128)) \ No newline at end of file diff --git a/codes/models/stylegan/stylegan2_lucidrains.py b/codes/models/stylegan/stylegan2_lucidrains.py index 1310cb58..f61f2b0f 100644 --- a/codes/models/stylegan/stylegan2_lucidrains.py +++ b/codes/models/stylegan/stylegan2_lucidrains.py @@ -18,7 +18,7 @@ from torch.autograd import grad as torch_grad from vector_quantize_pytorch import VectorQuantize from trainer.networks import register_model -from utils.util import checkpoint +from utils.util import checkpoint, opt_get try: from apex import amp @@ -763,7 +763,7 @@ class DiscriminatorBlock(nn.Module): class StyleGan2Discriminator(nn.Module): def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[], - transparent=False, fmap_max=512, input_filters=3): + transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False): super().__init__() num_layers = int(log2(image_size) - 1) @@ -789,12 +789,16 @@ class StyleGan2Discriminator(nn.Module): attn_blocks.append(attn_fn) - quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None - quantize_blocks.append(quantize_fn) + if quantize: + quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None + quantize_blocks.append(quantize_fn) + else: + quantize_blocks.append(None) self.blocks = nn.ModuleList(blocks) self.attn_blocks = nn.ModuleList(attn_blocks) self.quantize_blocks = nn.ModuleList(quantize_blocks) + self.do_checkpointing = do_checkpointing chan_last = filters[-1] latent_dim = 2 * 2 * chan_last @@ -811,7 +815,10 @@ class StyleGan2Discriminator(nn.Module): quantize_loss = torch.zeros(1).to(x) for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks): - x = block(x) + if self.do_checkpointing: + x = checkpoint(block, x) + else: + x = block(x) if exists(attn_block): x = attn_block(x) @@ -862,7 +869,6 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss): # Apply gradient penalty. TODO: migrate this elsewhere. 
if self.env['step'] % self.gp_frequency == 0: - from models.stylegan.stylegan2_lucidrains import gradient_penalty gp = gradient_penalty(real_input, real) self.metrics.append(("gradient_penalty", gp.clone().detach())) divergence_loss = divergence_loss + gp @@ -877,17 +883,14 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss): self.w_styles = opt['w_styles'] self.gen = opt['gen'] self.pl_mean = None - from models.archs.stylegan.stylegan2_lucidrains import EMA self.pl_length_ma = EMA(.99) def forward(self, net, state): w_styles = state[self.w_styles] gen = state[self.gen] - from models.stylegan.stylegan2_lucidrains import calc_pl_lengths pl_lengths = calc_pl_lengths(w_styles, gen) avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy()) - from models.stylegan.stylegan2_lucidrains import is_empty if not is_empty(self.pl_mean): pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean() if not torch.isnan(pl_loss): @@ -906,3 +909,12 @@ def register_stylegan2_lucidrains(opt_net, opt): return StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'], style_depth=opt_net['style_depth'], structure_input=is_structured, attn_layers=attn) + + +@register_model +def register_stylegan2_discriminator(opt_net, opt): + attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else [] + disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn, + do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False), + quantize=opt_get(opt_net, ['quantize'], False)) + return StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability']) diff --git a/codes/requirements.txt b/codes/requirements.txt index 6684409f..b8b63f26 100644 --- a/codes/requirements.txt +++ b/codes/requirements.txt @@ -11,7 +11,7 @@ munch tqdm scp tensorboard -pytorch_fid +pytorch_fid==0.1.1 kornia linear_attention_transformer vector_quantize_pytorch diff --git a/codes/scripts/test_dataloader.py b/codes/scripts/test_dataloader.py deleted file mode 100644 index 642c6ef7..00000000 --- a/codes/scripts/test_dataloader.py +++ /dev/null @@ -1,104 +0,0 @@ -import sys -import os.path as osp -import math -import torchvision.utils - -sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) -from data import create_dataloader, create_dataset # noqa: E402 -from utils import util # noqa: E402 - - -def main(): - dataset = 'DIV2K800_sub' # REDS | Vimeo90K | DIV2K800_sub - opt = {} - opt['dist'] = False - opt['gpu_ids'] = [0] - if dataset == 'REDS': - opt['name'] = 'test_REDS' - opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb' - opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb' - opt['mode'] = 'REDS' - opt['N_frames'] = 5 - opt['phase'] = 'train' - opt['use_shuffle'] = True - opt['n_workers'] = 8 - opt['batch_size'] = 16 - opt['target_size'] = 256 - opt['LQ_size'] = 64 - opt['scale'] = 4 - opt['use_flip'] = True - opt['use_rot'] = True - opt['interval_list'] = [1] - opt['random_reverse'] = False - opt['border_mode'] = False - opt['cache_keys'] = None - opt['data_type'] = 'lmdb' # img | lmdb | mc - elif dataset == 'Vimeo90K': - opt['name'] = 'test_Vimeo90K' - opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb' - opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb' - opt['mode'] = 'Vimeo90K' - opt['N_frames'] = 7 - opt['phase'] = 'train' - opt['use_shuffle'] = True - opt['n_workers'] = 8 - opt['batch_size'] = 16 - 
opt['target_size'] = 256 - opt['LQ_size'] = 64 - opt['scale'] = 4 - opt['use_flip'] = True - opt['use_rot'] = True - opt['interval_list'] = [1] - opt['random_reverse'] = False - opt['border_mode'] = False - opt['cache_keys'] = None - opt['data_type'] = 'lmdb' # img | lmdb | mc - elif dataset == 'DIV2K800_sub': - opt['name'] = 'DIV2K800' - opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb' - opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb' - opt['mode'] = 'LQGT' - opt['phase'] = 'train' - opt['use_shuffle'] = True - opt['n_workers'] = 8 - opt['batch_size'] = 16 - opt['target_size'] = 128 - opt['scale'] = 4 - opt['use_flip'] = True - opt['use_rot'] = True - opt['color'] = 'RGB' - opt['data_type'] = 'lmdb' # img | lmdb - else: - raise ValueError('Please implement by yourself.') - - util.mkdir('tmp') - train_set = create_dataset(opt) - train_loader = create_dataloader(train_set, opt, opt, None) - nrow = int(math.sqrt(opt['batch_size'])) - padding = 2 if opt['phase'] == 'train' else 0 - - print('start...') - for i, data in enumerate(train_loader): - if i > 5: - break - print(i) - if dataset == 'REDS' or dataset == 'Vimeo90K': - LQs = data['LQs'] - else: - LQ = data['lq'] - GT = data['hq'] - - if dataset == 'REDS' or dataset == 'Vimeo90K': - for j in range(LQs.size(1)): - torchvision.utils.save_image(LQs[:, j, :, :, :], - 'tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow, - padding=padding, normalize=False) - else: - torchvision.utils.save_image(LQ, 'tmp/LQ_{:03d}.png'.format(i), nrow=nrow, - padding=padding, normalize=False) - torchvision.utils.save_image(GT, 'tmp/GT_{:03d}.png'.format(i), nrow=nrow, padding=padding, - normalize=False) - - -if __name__ == "__main__": - main() diff --git a/codes/train.py b/codes/train.py index 91c8f75f..5343abe0 100644 --- a/codes/train.py +++ b/codes/train.py @@ -293,7 +293,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_faces_styled_sr.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_discriminator_diffimage.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index caab8373..1ce50c73 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -57,11 +57,11 @@ class ExtensibleTrainer(BaseModel): new_net = None if net['type'] == 'generator': if new_net is None: - new_net = networks.create_model(opt, net, opt['scale']).to(self.device) + new_net = networks.create_model(opt, net).to(self.device) self.netsG[name] = new_net elif net['type'] == 'discriminator': if new_net is None: - new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device) + new_net = networks.create_model(opt, net).to(self.device) self.netsD[name] = new_net else: raise NotImplementedError("Can only handle generators and discriminators") diff --git a/codes/trainer/networks.py b/codes/trainer/networks.py index 3a192f6f..0ef9beb9 100644 --- a/codes/trainer/networks.py +++ b/codes/trainer/networks.py @@ -1,18 +1,11 @@ -import functools import importlib import logging import pkgutil import sys from collections import OrderedDict from inspect import isfunction, getmembers - import torch -import 
torchvision - -import models.discriminator_vgg_arch as SRGAN_arch import models.feature_arch as feature_arch -import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch -from models.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator logger = logging.getLogger('base') @@ -63,7 +56,7 @@ class CreateModelError(Exception): f'{available}') -def create_model(opt, opt_net, scale=None): +def create_model(opt, opt_net): which_model = opt_net['which_model'] # For backwards compatibility. if not which_model: @@ -76,96 +69,6 @@ def create_model(opt, opt_net, scale=None): return registered_fns[which_model](opt_net, opt) -class GradDiscWrapper(torch.nn.Module): - def __init__(self, m): - super(GradDiscWrapper, self).__init__() - logger.info("Wrapping a discriminator..") - self.m = m - - def forward(self, x): - return self.m(x) - -def define_D_net(opt_net, img_sz=None, wrap=False): - which_model = opt_net['which_model_D'] - - if 'image_size' in opt_net.keys(): - img_sz = opt_net['image_size'] - - if which_model == 'discriminator_vgg_128': - netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, extra_conv=opt_net['extra_conv']) - elif which_model == 'discriminator_vgg_128_gn': - extra_conv = opt_net['extra_conv'] if 'extra_conv' in opt_net.keys() else False - netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], - input_img_factor=img_sz / 128, extra_conv=extra_conv) - if wrap: - netD = GradDiscWrapper(netD) - elif which_model == 'discriminator_vgg_128_gn_checkpointed': - netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, do_checkpointing=True) - elif which_model == 'stylegan_vgg': - netD = StyleGanDiscriminator(128) - elif which_model == 'discriminator_resnet': - netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) - elif which_model == 'discriminator_resnet_50': - netD = DiscriminatorResnet_arch.fixup_resnet50(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) - elif which_model == 'resnext': - netD = torchvision.models.resnext50_32x4d(norm_layer=functools.partial(torch.nn.GroupNorm, 8)) - #state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True) - #netD.load_state_dict(state_dict, strict=False) - netD.fc = torch.nn.Linear(512 * 4, 1) - elif which_model == 'discriminator_pix': - netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf']) - elif which_model == "discriminator_unet": - netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf']) - elif which_model == "discriminator_unet_fea": - netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'], feature_mode=opt_net['feature_mode']) - elif which_model == "discriminator_switched": - netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'], - final_temperature_step=opt_net['final_temperature_step']) - elif which_model == "cross_compare_vgg128": - netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'] if 'ref_channels' in opt_net.keys() else 3, nf=opt_net['nf'], scale=opt_net['scale']) - elif which_model == "discriminator_refvgg": - netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128) 
- elif which_model == "psnr_approximator": - netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128) - elif which_model == "stylegan2_discriminator": - attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else [] - from models.stylegan.stylegan2_lucidrains import StyleGan2Discriminator - disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn) - from models.stylegan.stylegan2_lucidrains import StyleGan2Augmentor - netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability']) - else: - raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) - return netD - -# Discriminator -def define_D(opt, wrap=False): - img_sz = opt['datasets']['train']['target_size'] - opt_net = opt['network_D'] - return define_D_net(opt_net, img_sz, wrap=wrap) - -def define_fixed_D(opt): - # Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_size parameter is added. - net = define_D_net(opt) - - # Load the model parameters: - load_net = torch.load(opt['pretrained_path']) - load_net_clean = OrderedDict() # remove unnecessary 'module.' - for k, v in load_net.items(): - if k.startswith('module.'): - load_net_clean[k[7:]] = v - else: - load_net_clean[k] = v - net.load_state_dict(load_net_clean) - - # Put into eval mode, freeze the parameters and set the 'weight' field. - net.eval() - for k, v in net.named_parameters(): - v.requires_grad = False - net.fdisc_weight = opt['weight'] - - return net - - # Define network used for perceptual loss def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None): if which_model == 'vgg': diff --git a/codes/utils/distill_torchscript.py b/codes/utils/distill_torchscript.py deleted file mode 100644 index 819cb358..00000000 --- a/codes/utils/distill_torchscript.py +++ /dev/null @@ -1,177 +0,0 @@ -import argparse -import functools -import torch -from utils import options as option -from trainer.networks import create_model - - -class TracedModule: - def __init__(self, idname): - self.idname = idname - self.traced_outputs = [] - self.traced_inputs = [] - - -class TorchCustomTrace: - def __init__(self): - self.module_name_counter = {} - self.modules = {} - self.graph = {} - self.module_map_by_inputs = {} - self.module_map_by_outputs = {} - self.inputs_to_func_output_tuple = {} - - def add_tracked_module(self, mod: torch.nn.Module): - modname = type(mod).__name__ - if modname not in self.module_name_counter.keys(): - self.module_name_counter[modname] = 0 - self.module_name_counter[modname] += 1 - idname = "%s(%03d)" % (modname, self.module_name_counter[modname]) - self.modules[idname] = TracedModule(idname) - return idname - - # Only called for nn.Modules since those are the only things we can access. Filling in the gaps will be done in - # the backwards pass. 
- def mem_forward_hook(self, module: torch.nn.Module, inputs, outputs, trace: str, mod_id: str): - mod = self.modules[mod_id] - ''' - for li in inputs: - if type(li) == torch.Tensor: - li = [li] - if type(li) == list: - for i in li: - if i.data_ptr() in self.module_map_by_inputs.keys(): - self.module_map_by_inputs[i.data_ptr()].append(mod) - else: - self.module_map_by_inputs[i.data_ptr()] = [mod] - for o in outputs: - if o.data_ptr() in self.module_map_by_inputs.keys(): - self.module_map_by_inputs[o.data_ptr()].append(mod) - else: - self.module_map_by_inputs[o.data_ptr()] = [mod] - ''' - print(trace) - - def mem_backward_hook(self, inputs, outputs, op): - if len(inputs) == 0: - print("No inputs.. %s" % (op,)) - outs = [o.data_ptr() for o in outputs] - tup = (outs, op) - #print(tup) - for li in inputs: - if type(li) == torch.Tensor: - li = [li] - if type(li) == list: - for i in li: - if i.data_ptr() in self.module_map_by_inputs.keys(): - print("%i: [%s] {%s}" % (i.data_ptr(), op, [n.idname for n in self.module_map_by_inputs[i.data_ptr()]])) - if i.data_ptr() in self.inputs_to_func_output_tuple.keys(): - self.inputs_to_func_output_tuple[i.data_ptr()].append(tup) - else: - self.inputs_to_func_output_tuple[i.data_ptr()] = [tup] - - def install_hooks(self, mod: torch.nn.Module, trace=""): - mod_id = self.add_tracked_module(mod) - my_trace = trace + "->" + mod_id - # If this module has parameters, it also has a state worth tracking. - #if next(mod.parameters(recurse=False), None) is not None: - mod.register_forward_hook(functools.partial(self.mem_forward_hook, trace=my_trace, mod_id=mod_id)) - - for m in mod.children(): - self.install_hooks(m, my_trace) - - def install_backward_hooks(self, grad_fn): - # AccumulateGrad simply pushes a gradient into the specified variable, and isn't useful for the purposes of - # tracing the graph. - if grad_fn is None or "AccumulateGrad" in str(grad_fn): - return - grad_fn.register_hook(functools.partial(self.mem_backward_hook, op=str(grad_fn))) - for g, _ in grad_fn.next_functions: - self.install_backward_hooks(g) - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_div2k_pixgan_srg2.yml') - opt = option.parse(parser.parse_args().opt, is_train=False) - opt = option.dict_to_nonedict(opt) - - netG = create_model(opt) - dummyInput = torch.rand(1,3,32,32) - - mode = 'onnx' - if mode == 'torchscript': - print("Tracing generator network..") - traced_netG = torch.jit.trace(netG, dummyInput) - traced_netG.save('../results/ts_generator.zip') - - print(traced_netG.code) - for i, module in enumerate(traced_netG.RRDB_trunk.modules()): - print(i, str(module)) - elif mode == 'onnx': - print("Performing onnx trace") - input_names = ["lr_input"] - output_names = ["hr_image"] - dynamic_axes = {'lr_input': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}, 'hr_image': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}} - - torch.onnx.export(netG, dummyInput, "../results/gen.onnx", verbose=True, input_names=input_names, - output_names=output_names, dynamic_axes=dynamic_axes, opset_version=12) - elif mode == 'memtrace': - criterion = torch.nn.MSELoss() - tracer = TorchCustomTrace() - tracer.install_hooks(netG) - out, = netG(dummyInput) - tracer.install_backward_hooks(out.grad_fn) - target = torch.zeros_like(out) - loss = criterion(out, target) - loss.backward() - elif mode == 'trace': - out = netG.forward(dummyInput)[0] - print(out.shape) - # Build the graph backwards. 
- graph = build_graph(out, 'output') - -def get_unique_id_for_fn(fn): - return (str(fn).split(" object at ")[1])[:-1] - -class GraphNode: - def __init__(self, fn): - self.name = (str(fn).split(" object at ")[0])[1:] - self.fn = fn - self.children = {} - self.parents = {} - - def add_parent(self, parent): - self.parents[get_unique_id_for_fn(parent)] = parent - - def add_child(self, child): - self.children[get_unique_id_for_fn(child)] = child - -class TorchGraph: - def __init__(self): - self.tensor_map = {} - - def get_node_for_tensor(self, t): - return self.tensor_map[get_unique_id_for_fn(t)] - - def init(self, output_tensor): - self.build_graph_backwards(output_tensor.grad_fn, None) - # Find inputs - self.inputs = [] - for v in self.tensor_map.values(): - # Is an input if the parents dict is empty. - if bool(v.parents): - self.inputs.append(v) - - def build_graph_backwards(self, fn, previous_fn): - id = get_unique_id_for_fn(fn) - if id in self.tensor_map: - node = self.tensor_map[id] - node.add_child(previous_fn) - else: - node = GraphNode(fn) - self.tensor_map[id] = node - # Propagate to children - for child_fn in fn.next_functions: - node.add_parent(self.build_graph_backwards(child_fn, fn)) - return node \ No newline at end of file
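
Usage sketch for the create_model paradigm this patch moves discriminators onto. This is a minimal, self-contained approximation inferred from create_model() and the register_* naming convention in trainer/networks.py above, not verbatim repo internals: the registry dict, the key derivation (function name with the 'register_' prefix stripped), and the names MyDiscriminator/register_my_discriminator are all assumptions for illustration.

import torch.nn as nn

_registered_fns = {}

def register_model(fn):
    # Assumption: a builder named 'register_<name>' is exposed under '<name>',
    # which is the value the 'which_model' config key selects.
    _registered_fns[fn.__name__.replace('register_', '', 1)] = fn
    return fn

class MyDiscriminator(nn.Module):  # hypothetical stand-in for a real arch
    def __init__(self, in_nc, nf):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(in_nc, nf, 3, 2, 1),
                                 nn.LeakyReLU(0.2, inplace=True),
                                 nn.Conv2d(nf, 1, 3, 2, 1))

    def forward(self, x):
        # Collapse to one realness score per batch element.
        return self.net(x).mean(dim=(1, 2, 3))

@register_model
def register_my_discriminator(opt_net, opt):
    return MyDiscriminator(in_nc=opt_net['in_nc'], nf=opt_net['nf'])

def create_model(opt, opt_net):
    # Mirrors trainer/networks.py: resolve 'which_model' and build from opts.
    return _registered_fns[opt_net['which_model']](opt_net, opt)

# ExtensibleTrainer now reaches discriminators through the same path as
# generators, driven by a network block in the YAML options, e.g.:
#   networks:
#     feature_discriminator:
#       type: discriminator
#       which_model: my_discriminator
#       in_nc: 3
#       nf: 64
disc = create_model({}, {'which_model': 'my_discriminator', 'in_nc': 3, 'nf': 64})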