Move discriminators to the create_model paradigm

Also cleans up a lot of old discriminator models that I have no intention
of using again.
James Betker 2021-01-01 15:56:09 -07:00
parent 7976a5825d
commit 193cdc6636
10 changed files with 54 additions and 908 deletions
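
For context on the "create_model paradigm" named in the title: every network, discriminator or generator alike, is now built through a single registry keyed by which_model. A minimal usage sketch, assuming the registry key is the register_* function name minus its prefix; the config values below are invented for illustration:

# Hypothetical usage of the unified factory (trainer.networks.create_model).
# Option keys mirror the ones visible in the diff below; the values are made up.
from trainer.networks import create_model

opt = {}  # full experiment options (datasets, steps, ...)
opt_net = {
    'type': 'discriminator',
    'which_model': 'discriminator_vgg_128',  # dispatches to a @register_model fn
    'in_nc': 3,
    'nf': 64,
    'image_size': 128,
    'extra_conv': False,
}
netD = create_model(opt, opt_net)  # replaces define_D_net and its img_sz/scale plumbing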

View File

@@ -37,13 +37,15 @@ class ByolDatasetWrapper(Dataset):
self.cropped_img_size = opt['crop_size']
self.key1 = opt_get(opt, ['key1'], 'hq')
self.key2 = opt_get(opt, ['key2'], 'lq')
for_sr = opt_get(opt, ['for_sr'], False) # When set, color alterations and blurs are disabled.
augmentations = [ \
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2),
augs.RandomHorizontalFlip(),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
if not for_sr:
augmentations.extend([RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)])
if opt['normalize']:
# The paper calls for normalization. Most datasets/models in this repo don't use this.
# Recommend setting true if you want to train exactly like the paper.
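
The opt_get helper used in this hunk is a nested-key lookup with a default. The real helper lives in utils.util; this reimplementation is a sketch inferred from its call sites:

def opt_get(opt, keys, default=None):
    # Walk nested dict keys; fall back to `default` if any level is missing.
    for k in keys:
        if not isinstance(opt, dict) or k not in opt:
            return default
        opt = opt[k]
    return opt

opt_get({'for_sr': True}, ['for_sr'], False)  # -> True
opt_get({}, ['key1'], 'hq')                   # -> 'hq' (the default)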

View File

@@ -3,7 +3,9 @@ import torch.nn as nn
from models.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
import torch.nn.functional as F
from utils.util import checkpoint
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
class Discriminator_VGG_128(nn.Module):
@@ -79,6 +81,12 @@ class Discriminator_VGG_128(nn.Module):
return out
@register_model
def register_discriminator_vgg_128(opt_net, opt):
return Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=opt_net['image_size'] / 128,
extra_conv=opt_net['extra_conv'])
class Discriminator_VGG_128_GN(nn.Module):
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
def __init__(self, in_nc, nf, input_img_factor=1, do_checkpointing=False, extra_conv=False):
@@ -159,514 +167,8 @@ class Discriminator_VGG_128_GN(nn.Module):
return out
class CrossCompareBlock(nn.Module):
def __init__(self, nf_in, nf_out):
super(CrossCompareBlock, self).__init__()
self.conv_hr_merge = ConvGnLelu(nf_in * 2, nf_in, kernel_size=1, bias=False, activation=False, norm=True)
self.proc_hr = ConvGnLelu(nf_in, nf_out, kernel_size=3, bias=False, activation=True, norm=True)
self.proc_lr = ConvGnLelu(nf_in, nf_out, kernel_size=3, bias=False, activation=True, norm=True)
self.reduce_hr = ConvGnLelu(nf_out, nf_out, kernel_size=3, stride=2, bias=False, activation=True, norm=True)
self.reduce_lr = ConvGnLelu(nf_out, nf_out, kernel_size=3, stride=2, bias=False, activation=True, norm=True)
def forward(self, hr, lr):
hr = self.conv_hr_merge(torch.cat([hr, lr], dim=1))
hr = self.proc_hr(hr)
hr = self.reduce_hr(hr)
lr = self.proc_lr(lr)
lr = self.reduce_lr(lr)
return hr, lr
class CrossCompareDiscriminator(nn.Module):
def __init__(self, in_nc, ref_channels, nf, scale=4):
super(CrossCompareDiscriminator, self).__init__()
assert scale == 2 or scale == 4
self.init_conv_hr = ConvGnLelu(in_nc, nf, stride=2, norm=False, bias=True, activation=True)
self.init_conv_lr = ConvGnLelu(ref_channels, nf, stride=1, norm=False, bias=True, activation=True)
if scale == 4:
strd_2 = 2
else:
strd_2 = 1
self.second_conv = ConvGnLelu(nf, nf, stride=strd_2, norm=True, bias=False, activation=True)
self.cross1 = CrossCompareBlock(nf, nf * 2)
self.cross2 = CrossCompareBlock(nf * 2, nf * 4)
self.cross3 = CrossCompareBlock(nf * 4, nf * 8)
self.cross4 = CrossCompareBlock(nf * 8, nf * 8)
self.fproc_conv = ConvGnLelu(nf * 8, nf, norm=True, bias=True, activation=True)
self.out_conv = ConvGnLelu(nf, 1, norm=False, bias=False, activation=False)
self.scale = scale * 16
def forward(self, hr, lr):
hr = self.init_conv_hr(hr)
hr = self.second_conv(hr)
lr = self.init_conv_lr(lr)
hr, lr = self.cross1(hr, lr)
hr, lr = self.cross2(hr, lr)
hr, lr = self.cross3(hr, lr)
hr, _ = self.cross4(hr, lr)
return self.out_conv(self.fproc_conv(hr)).view(-1, 1)
# Returns tuple of (number_output_channels, scale_of_output_reduction (1/n))
def pixgan_parameters(self):
return 3, self.scale
class Discriminator_VGG_PixLoss(nn.Module):
def __init__(self, in_nc, nf):
super(Discriminator_VGG_PixLoss, self).__init__()
# [64, 128, 128]
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
self.bn0_1 = nn.GroupNorm(8, nf, affine=True)
# [64, 64, 64]
self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
self.bn1_0 = nn.GroupNorm(8, nf * 2, affine=True)
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.bn1_1 = nn.GroupNorm(8, nf * 2, affine=True)
# [128, 32, 32]
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.bn2_0 = nn.GroupNorm(8, nf * 4, affine=True)
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
self.bn2_1 = nn.GroupNorm(8, nf * 4, affine=True)
# [256, 16, 16]
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
self.bn3_0 = nn.GroupNorm(8, nf * 8, affine=True)
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn3_1 = nn.GroupNorm(8, nf * 8, affine=True)
# [512, 8, 8]
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True)
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
self.reduce_1 = ConvGnLelu(nf * 8, nf * 4, bias=False)
self.pix_loss_collapse = ConvGnLelu(nf * 4, 1, bias=False, norm=False, activation=False)
# Pyramid network: upsample with residuals and produce losses at multiple resolutions.
self.up3_decimate = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=True, activation=False)
self.up3_converge = ConvGnLelu(nf * 16, nf * 8, kernel_size=3, bias=False)
self.up3_proc = ConvGnLelu(nf * 8, nf * 8, bias=False)
self.up3_reduce = ConvGnLelu(nf * 8, nf * 4, bias=False)
self.up3_pix = ConvGnLelu(nf * 4, 1, bias=False, norm=False, activation=False)
self.up2_decimate = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, bias=True, activation=False)
self.up2_converge = ConvGnLelu(nf * 8, nf * 4, kernel_size=3, bias=False)
self.up2_proc = ConvGnLelu(nf * 4, nf * 4, bias=False)
self.up2_reduce = ConvGnLelu(nf * 4, nf * 2, bias=False)
self.up2_pix = ConvGnLelu(nf * 2, 1, bias=False, norm=False, activation=False)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x, flatten=True):
fea0 = self.lrelu(self.conv0_0(x))
fea0 = self.lrelu(self.bn0_1(self.conv0_1(fea0)))
fea1 = self.lrelu(self.bn1_0(self.conv1_0(fea0)))
fea1 = self.lrelu(self.bn1_1(self.conv1_1(fea1)))
fea2 = self.lrelu(self.bn2_0(self.conv2_0(fea1)))
fea2 = self.lrelu(self.bn2_1(self.conv2_1(fea2)))
fea3 = self.lrelu(self.bn3_0(self.conv3_0(fea2)))
fea3 = self.lrelu(self.bn3_1(self.conv3_1(fea3)))
fea4 = self.lrelu(self.bn4_0(self.conv4_0(fea3)))
fea4 = self.lrelu(self.bn4_1(self.conv4_1(fea4)))
loss = self.reduce_1(fea4)
# "Weight" all losses the same by interpolating them to the highest dimension.
loss = self.pix_loss_collapse(loss)
loss = F.interpolate(loss, scale_factor=4, mode="nearest")
# And the pyramid network!
dec3 = self.up3_decimate(F.interpolate(fea4, scale_factor=2, mode="nearest"))
dec3 = torch.cat([dec3, fea3], dim=1)
dec3 = self.up3_converge(dec3)
dec3 = self.up3_proc(dec3)
loss3 = self.up3_reduce(dec3)
loss3 = self.up3_pix(loss3)
loss3 = F.interpolate(loss3, scale_factor=2, mode="nearest")
dec2 = self.up2_decimate(F.interpolate(dec3, scale_factor=2, mode="nearest"))
dec2 = torch.cat([dec2, fea2], dim=1)
dec2 = self.up2_converge(dec2)
dec2 = self.up2_proc(dec2)
dec2 = self.up2_reduce(dec2)
loss2 = self.up2_pix(dec2)
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will
# then know how to handle them.
combined_losses = torch.cat([loss, loss3, loss2], dim=1)
return combined_losses.view(-1, 1)
def pixgan_parameters(self):
return 3, 8
class Discriminator_UNet(nn.Module):
def __init__(self, in_nc, nf):
super(Discriminator_UNet, self).__init__()
# [64, 128, 128]
self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
# [64, 64, 64]
self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
# [128, 32, 32]
self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
# [256, 16, 16]
self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
# [512, 8, 8]
self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False)
self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False)
self.up2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu)
self.proc2 = ConvGnLelu(nf * 4, nf * 4, bias=False)
self.collapse2 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
self.up3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu)
self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
def forward(self, x, flatten=True):
fea0 = self.conv0_0(x)
fea0 = self.conv0_1(fea0)
fea1 = self.conv1_0(fea0)
fea1 = self.conv1_1(fea1)
fea2 = self.conv2_0(fea1)
fea2 = self.conv2_1(fea2)
fea3 = self.conv3_0(fea2)
fea3 = self.conv3_1(fea3)
fea4 = self.conv4_0(fea3)
fea4 = self.conv4_1(fea4)
# And the pyramid network!
u1 = self.up1(fea4, fea3)
loss1 = self.collapse1(self.proc1(u1))
u2 = self.up2(u1, fea2)
loss2 = self.collapse2(self.proc2(u2))
u3 = self.up3(u2, fea1)
loss3 = self.collapse3(self.proc3(u3))
res = loss3.shape[2:]
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will
# then know how to handle them.
combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
F.interpolate(loss2, scale_factor=2),
F.interpolate(loss3, scale_factor=1)], dim=1)
return combined_losses.view(-1, 1)
def pixgan_parameters(self):
return 3, 4
class Discriminator_UNet_FeaOut(nn.Module):
def __init__(self, in_nc, nf, feature_mode=False):
super(Discriminator_UNet_FeaOut, self).__init__()
# [64, 128, 128]
self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
# [64, 64, 64]
self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
# [128, 32, 32]
self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
# [256, 16, 16]
self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
# [512, 8, 8]
self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False)
self.fea_proc = ConvGnLelu(nf * 8, nf * 8, bias=True, norm=False, activation=False)
self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False)
self.feature_mode = feature_mode
def forward(self, x, output_feature_vector=False):
fea0 = self.conv0_0(x)
fea0 = self.conv0_1(fea0)
fea1 = self.conv1_0(fea0)
fea1 = self.conv1_1(fea1)
fea2 = self.conv2_0(fea1)
fea2 = self.conv2_1(fea2)
fea3 = self.conv3_0(fea2)
fea3 = self.conv3_1(fea3)
fea4 = self.conv4_0(fea3)
fea4 = self.conv4_1(fea4)
# And the pyramid network!
u1 = self.up1(fea4, fea3)
loss1 = self.collapse1(self.proc1(u1))
fea_out = self.fea_proc(u1)
combined_losses = F.interpolate(loss1, scale_factor=4)
if output_feature_vector:
return combined_losses.view(-1, 1), fea_out
else:
return combined_losses.view(-1, 1)
def pixgan_parameters(self):
return 1, 4
class Vgg128GnHead(nn.Module):
def __init__(self, in_nc, nf, depth=5):
super(Vgg128GnHead, self).__init__()
assert depth == 4 or depth == 5 # Nothing stopping others from being implemented, just not done yet.
self.depth = depth
# [64, 128, 128]
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
self.bn0_1 = nn.GroupNorm(8, nf, affine=True)
# [64, 64, 64]
self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
self.bn1_0 = nn.GroupNorm(8, nf * 2, affine=True)
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.bn1_1 = nn.GroupNorm(8, nf * 2, affine=True)
# [128, 32, 32]
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.bn2_0 = nn.GroupNorm(8, nf * 4, affine=True)
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
self.bn2_1 = nn.GroupNorm(8, nf * 4, affine=True)
# [256, 16, 16]
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
self.bn3_0 = nn.GroupNorm(8, nf * 8, affine=True)
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn3_1 = nn.GroupNorm(8, nf * 8, affine=True)
if depth > 4:
# [512, 8, 8]
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True)
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.lrelu(self.conv0_0(x))
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
if self.depth > 4:
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
return fea
class RefDiscriminatorVgg128(nn.Module):
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
def __init__(self, in_nc, nf, input_img_factor=1):
super(RefDiscriminatorVgg128, self).__init__()
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.feature_head = Vgg128GnHead(in_nc, nf)
self.ref_head = Vgg128GnHead(in_nc+1, nf, depth=4)
final_nf = nf * 8
self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 512)
self.ref_linear = nn.Linear(nf * 8, 128)
self.output_linears = nn.Sequential(
nn.Linear(128+512, 512),
self.lrelu,
nn.Linear(512, 256),
self.lrelu,
nn.Linear(256, 128),
self.lrelu,
nn.Linear(128, 1)
)
def forward(self, x, ref, ref_center_point):
ref = self.ref_head(ref)
ref_center_point = ref_center_point // 16
from models.SwitchedResidualGenerator_arch import gather_2d
ref_vector = gather_2d(ref, ref_center_point)
ref_vector = self.ref_linear(ref_vector)
fea = self.feature_head(x)
fea = fea.contiguous().view(fea.size(0), -1)
fea = self.lrelu(self.linear1(fea))
out = self.output_linears(torch.cat([fea, ref_vector], dim=1))
return out
class PsnrApproximator(nn.Module):
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
def __init__(self, nf, input_img_factor=1):
super(PsnrApproximator, self).__init__()
# [64, 128, 128]
self.fake_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
self.fake_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
self.fake_bn0_1 = nn.BatchNorm2d(nf, affine=True)
# [64, 64, 64]
self.fake_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
self.fake_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
self.fake_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.fake_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
# [128, 32, 32]
self.fake_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.fake_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
self.fake_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
self.fake_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
# [64, 128, 128]
self.real_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
self.real_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
self.real_bn0_1 = nn.BatchNorm2d(nf, affine=True)
# [64, 64, 64]
self.real_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
self.real_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
self.real_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.real_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
# [128, 32, 32]
self.real_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.real_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
self.real_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
self.real_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
# [512, 16, 16]
self.conv3_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
# [512, 8, 8]
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
final_nf = nf * 8
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 1024)
self.linear2 = nn.Linear(1024, 512)
self.linear3 = nn.Linear(512, 128)
self.linear4 = nn.Linear(128, 1)
def compute_body1(self, real):
fea = self.lrelu(self.real_conv0_0(real))
fea = self.lrelu(self.real_bn0_1(self.real_conv0_1(fea)))
fea = self.lrelu(self.real_bn1_0(self.real_conv1_0(fea)))
fea = self.lrelu(self.real_bn1_1(self.real_conv1_1(fea)))
fea = self.lrelu(self.real_bn2_0(self.real_conv2_0(fea)))
fea = self.lrelu(self.real_bn2_1(self.real_conv2_1(fea)))
return fea
def compute_body2(self, fake):
fea = self.lrelu(self.fake_conv0_0(fake))
fea = self.lrelu(self.fake_bn0_1(self.fake_conv0_1(fea)))
fea = self.lrelu(self.fake_bn1_0(self.fake_conv1_0(fea)))
fea = self.lrelu(self.fake_bn1_1(self.fake_conv1_1(fea)))
fea = self.lrelu(self.fake_bn2_0(self.fake_conv2_0(fea)))
fea = self.lrelu(self.fake_bn2_1(self.fake_conv2_1(fea)))
return fea
def forward(self, real, fake):
real_fea = checkpoint(self.compute_body1, real)
fake_fea = checkpoint(self.compute_body2, fake)
fea = torch.cat([real_fea, fake_fea], dim=1)
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
fea = fea.contiguous().view(fea.size(0), -1)
fea = self.lrelu(self.linear1(fea))
fea = self.lrelu(self.linear2(fea))
fea = self.lrelu(self.linear3(fea))
out = self.linear4(fea)
return out.squeeze()
class SingleImageQualityEstimator(nn.Module):
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
def __init__(self, nf, input_img_factor=1):
super(SingleImageQualityEstimator, self).__init__()
# [64, 128, 128]
self.fake_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
self.fake_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
self.fake_bn0_1 = nn.BatchNorm2d(nf, affine=True)
# [64, 64, 64]
self.fake_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
self.fake_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
self.fake_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.fake_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
# [128, 32, 32]
self.fake_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.fake_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
self.fake_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
self.fake_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
# [512, 16, 16]
self.conv3_0 = nn.Conv2d(nf * 4, nf * 4, 3, 1, 1, bias=False)
self.bn3_0 = nn.BatchNorm2d(nf * 4, affine=True)
self.conv3_1 = nn.Conv2d(nf * 4, nf * 8, 4, 2, 1, bias=False)
self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
# [512, 8, 8]
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=True)
self.conv4_1 = nn.Conv2d(nf * 8, nf * 2, 3, 1, 1, bias=True)
self.conv4_2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
self.conv4_3 = nn.Conv2d(nf, 3, 3, 1, 1, bias=True)
self.sigmoid = nn.Sigmoid()
self.lrelu = nn.LeakyReLU(negative_slope=.2, inplace=True)
def compute_body(self, fake):
fea = self.lrelu(self.fake_conv0_0(fake))
fea = self.lrelu(self.fake_bn0_1(self.fake_conv0_1(fea)))
fea = self.lrelu(self.fake_bn1_0(self.fake_conv1_0(fea)))
fea = self.lrelu(self.fake_bn1_1(self.fake_conv1_1(fea)))
fea = self.lrelu(self.fake_bn2_0(self.fake_conv2_0(fea)))
fea = self.lrelu(self.fake_bn2_1(self.fake_conv2_1(fea)))
return fea
def forward(self, fake):
fea = checkpoint(self.compute_body, fake)
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
fea = self.lrelu(self.conv4_0(fea))
fea = self.lrelu(self.conv4_1(fea))
fea = self.lrelu(self.conv4_2(fea))
fea = self.sigmoid(self.conv4_3(fea))
return fea
@register_model
def register_discriminator_vgg_128_gn(opt_net, opt):
return Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=opt_net['image_size'] / 128,
extra_conv=opt_get(opt_net, ['extra_conv'], False),
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False))
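
A hypothetical options block for the GN variant registered above; the keys are the ones the register function reads, the values are illustrative:

opt_net = {
    'which_model': 'discriminator_vgg_128_gn',
    'in_nc': 3,
    'nf': 64,
    'image_size': 256,         # input_img_factor = 256 / 128 = 2
    'extra_conv': False,       # optional, read via opt_get
    'do_checkpointing': True,  # optional; recompute activations to save memory
}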

View File

@@ -5,6 +5,9 @@ from torch import nn
import torch.nn.functional as F
import numpy as np
from trainer.networks import register_model
from utils.util import opt_get
class BlurLayer(nn.Module):
def __init__(self, kernel=None, normalize=True, flip=False, stride=1):
@@ -372,4 +375,9 @@ class StyleGanDiscriminator(nn.Module):
else:
raise KeyError("Unknown structure: ", self.structure)
return scores_out
@register_model
def register_stylegan_vgg(opt_net, opt):
return StyleGanDiscriminator(opt_get(opt_net, ['image_size'], 128))

View File

@@ -18,7 +18,7 @@ from torch.autograd import grad as torch_grad
from vector_quantize_pytorch import VectorQuantize
from trainer.networks import register_model
from utils.util import checkpoint
from utils.util import checkpoint, opt_get
try:
from apex import amp
@@ -763,7 +763,7 @@ class DiscriminatorBlock(nn.Module):
class StyleGan2Discriminator(nn.Module):
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
transparent=False, fmap_max=512, input_filters=3):
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False):
super().__init__()
num_layers = int(log2(image_size) - 1)
@@ -789,12 +789,16 @@ class StyleGan2Discriminator(nn.Module):
attn_blocks.append(attn_fn)
quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None
quantize_blocks.append(quantize_fn)
if quantize:
quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None
quantize_blocks.append(quantize_fn)
else:
quantize_blocks.append(None)
self.blocks = nn.ModuleList(blocks)
self.attn_blocks = nn.ModuleList(attn_blocks)
self.quantize_blocks = nn.ModuleList(quantize_blocks)
self.do_checkpointing = do_checkpointing
chan_last = filters[-1]
latent_dim = 2 * 2 * chan_last
@@ -811,7 +815,10 @@ class StyleGan2Discriminator(nn.Module):
quantize_loss = torch.zeros(1).to(x)
for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks):
x = block(x)
if self.do_checkpointing:
x = checkpoint(block, x)
else:
x = block(x)
if exists(attn_block):
x = attn_block(x)
@@ -862,7 +869,6 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss):
# Apply gradient penalty. TODO: migrate this elsewhere.
if self.env['step'] % self.gp_frequency == 0:
from models.stylegan.stylegan2_lucidrains import gradient_penalty
gp = gradient_penalty(real_input, real)
self.metrics.append(("gradient_penalty", gp.clone().detach()))
divergence_loss = divergence_loss + gp
@@ -877,17 +883,14 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss):
self.w_styles = opt['w_styles']
self.gen = opt['gen']
self.pl_mean = None
from models.archs.stylegan.stylegan2_lucidrains import EMA
self.pl_length_ma = EMA(.99)
def forward(self, net, state):
w_styles = state[self.w_styles]
gen = state[self.gen]
from models.stylegan.stylegan2_lucidrains import calc_pl_lengths
pl_lengths = calc_pl_lengths(w_styles, gen)
avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
from models.stylegan.stylegan2_lucidrains import is_empty
if not is_empty(self.pl_mean):
pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
if not torch.isnan(pl_loss):
@@ -906,3 +909,12 @@ def register_stylegan2_lucidrains(opt_net, opt):
return StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
style_depth=opt_net['style_depth'], structure_input=is_structured,
attn_layers=attn)
@register_model
def register_stylegan2_discriminator(opt_net, opt):
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
quantize=opt_get(opt_net, ['quantize'], False))
return StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
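
The two new constructor flags surface as optional config keys on the registered discriminator. A hypothetical options block; the augmentation types and values are illustrative, not prescribed by the diff:

opt_net = {
    'which_model': 'stylegan2_discriminator',
    'image_size': 128,
    'in_nc': 3,
    'attn_layers': [],          # optional; defaults to [] above
    'do_checkpointing': False,  # new: run each block through checkpoint()
    'quantize': False,          # new: enable VectorQuantize blocks on fq_layers
    'augmentation_types': ['translation', 'cutout'],
    'augmentation_probability': 0.5,
}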

View File

@@ -11,7 +11,7 @@ munch
tqdm
scp
tensorboard
pytorch_fid
pytorch_fid==0.1.1
kornia
linear_attention_transformer
vector_quantize_pytorch

View File

@@ -1,104 +0,0 @@
import sys
import os.path as osp
import math
import torchvision.utils
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
from data import create_dataloader, create_dataset # noqa: E402
from utils import util # noqa: E402
def main():
dataset = 'DIV2K800_sub' # REDS | Vimeo90K | DIV2K800_sub
opt = {}
opt['dist'] = False
opt['gpu_ids'] = [0]
if dataset == 'REDS':
opt['name'] = 'test_REDS'
opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb'
opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
opt['mode'] = 'REDS'
opt['N_frames'] = 5
opt['phase'] = 'train'
opt['use_shuffle'] = True
opt['n_workers'] = 8
opt['batch_size'] = 16
opt['target_size'] = 256
opt['LQ_size'] = 64
opt['scale'] = 4
opt['use_flip'] = True
opt['use_rot'] = True
opt['interval_list'] = [1]
opt['random_reverse'] = False
opt['border_mode'] = False
opt['cache_keys'] = None
opt['data_type'] = 'lmdb' # img | lmdb | mc
elif dataset == 'Vimeo90K':
opt['name'] = 'test_Vimeo90K'
opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
opt['mode'] = 'Vimeo90K'
opt['N_frames'] = 7
opt['phase'] = 'train'
opt['use_shuffle'] = True
opt['n_workers'] = 8
opt['batch_size'] = 16
opt['target_size'] = 256
opt['LQ_size'] = 64
opt['scale'] = 4
opt['use_flip'] = True
opt['use_rot'] = True
opt['interval_list'] = [1]
opt['random_reverse'] = False
opt['border_mode'] = False
opt['cache_keys'] = None
opt['data_type'] = 'lmdb' # img | lmdb | mc
elif dataset == 'DIV2K800_sub':
opt['name'] = 'DIV2K800'
opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
opt['mode'] = 'LQGT'
opt['phase'] = 'train'
opt['use_shuffle'] = True
opt['n_workers'] = 8
opt['batch_size'] = 16
opt['target_size'] = 128
opt['scale'] = 4
opt['use_flip'] = True
opt['use_rot'] = True
opt['color'] = 'RGB'
opt['data_type'] = 'lmdb' # img | lmdb
else:
raise ValueError('Please implement by yourself.')
util.mkdir('tmp')
train_set = create_dataset(opt)
train_loader = create_dataloader(train_set, opt, opt, None)
nrow = int(math.sqrt(opt['batch_size']))
padding = 2 if opt['phase'] == 'train' else 0
print('start...')
for i, data in enumerate(train_loader):
if i > 5:
break
print(i)
if dataset == 'REDS' or dataset == 'Vimeo90K':
LQs = data['LQs']
else:
LQ = data['lq']
GT = data['hq']
if dataset == 'REDS' or dataset == 'Vimeo90K':
for j in range(LQs.size(1)):
torchvision.utils.save_image(LQs[:, j, :, :, :],
'tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow,
padding=padding, normalize=False)
else:
torchvision.utils.save_image(LQ, 'tmp/LQ_{:03d}.png'.format(i), nrow=nrow,
padding=padding, normalize=False)
torchvision.utils.save_image(GT, 'tmp/GT_{:03d}.png'.format(i), nrow=nrow, padding=padding,
normalize=False)
if __name__ == "__main__":
main()

View File

@@ -293,7 +293,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_faces_styled_sr.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_discriminator_diffimage.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@@ -57,11 +57,11 @@ class ExtensibleTrainer(BaseModel):
new_net = None
if net['type'] == 'generator':
if new_net is None:
new_net = networks.create_model(opt, net, opt['scale']).to(self.device)
new_net = networks.create_model(opt, net).to(self.device)
self.netsG[name] = new_net
elif net['type'] == 'discriminator':
if new_net is None:
new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device)
new_net = networks.create_model(opt, net).to(self.device)
self.netsD[name] = new_net
else:
raise NotImplementedError("Can only handle generators and discriminators")

View File

@@ -1,18 +1,11 @@ import functools
import functools
import importlib
import logging
import pkgutil
import sys
from collections import OrderedDict
from inspect import isfunction, getmembers
import torch
import torchvision
import models.discriminator_vgg_arch as SRGAN_arch
import models.feature_arch as feature_arch
import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
from models.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
logger = logging.getLogger('base')
@@ -63,7 +56,7 @@ class CreateModelError(Exception):
f'{available}')
def create_model(opt, opt_net, scale=None):
def create_model(opt, opt_net):
which_model = opt_net['which_model']
# For backwards compatibility.
if not which_model:
@@ -76,96 +69,6 @@ def create_model(opt, opt_net):
return registered_fns[which_model](opt_net, opt)
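
A minimal sketch of the registry mechanics implied by register_model and the registered_fns lookup above. The real module discovers these functions by scanning packages; the core contract is plausibly this:

_registered_fns = {}

def register_model(fn):
    # Register under the function name minus the 'register_' prefix,
    # so configs can say which_model: discriminator_vgg_128.
    assert fn.__name__.startswith('register_')
    _registered_fns[fn.__name__[len('register_'):]] = fn
    return fn

def create_model(opt, opt_net):
    which_model = opt_net['which_model']
    return _registered_fns[which_model](opt_net, opt)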
class GradDiscWrapper(torch.nn.Module):
def __init__(self, m):
super(GradDiscWrapper, self).__init__()
logger.info("Wrapping a discriminator..")
self.m = m
def forward(self, x):
return self.m(x)
def define_D_net(opt_net, img_sz=None, wrap=False):
which_model = opt_net['which_model_D']
if 'image_size' in opt_net.keys():
img_sz = opt_net['image_size']
if which_model == 'discriminator_vgg_128':
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, extra_conv=opt_net['extra_conv'])
elif which_model == 'discriminator_vgg_128_gn':
extra_conv = opt_net['extra_conv'] if 'extra_conv' in opt_net.keys() else False
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
input_img_factor=img_sz / 128, extra_conv=extra_conv)
if wrap:
netD = GradDiscWrapper(netD)
elif which_model == 'discriminator_vgg_128_gn_checkpointed':
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, do_checkpointing=True)
elif which_model == 'stylegan_vgg':
netD = StyleGanDiscriminator(128)
elif which_model == 'discriminator_resnet':
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
elif which_model == 'discriminator_resnet_50':
netD = DiscriminatorResnet_arch.fixup_resnet50(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
elif which_model == 'resnext':
netD = torchvision.models.resnext50_32x4d(norm_layer=functools.partial(torch.nn.GroupNorm, 8))
#state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True)
#netD.load_state_dict(state_dict, strict=False)
netD.fc = torch.nn.Linear(512 * 4, 1)
elif which_model == 'discriminator_pix':
netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
elif which_model == "discriminator_unet":
netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
elif which_model == "discriminator_unet_fea":
netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'], feature_mode=opt_net['feature_mode'])
elif which_model == "discriminator_switched":
netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'],
final_temperature_step=opt_net['final_temperature_step'])
elif which_model == "cross_compare_vgg128":
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'] if 'ref_channels' in opt_net.keys() else 3, nf=opt_net['nf'], scale=opt_net['scale'])
elif which_model == "discriminator_refvgg":
netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
elif which_model == "psnr_approximator":
netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
elif which_model == "stylegan2_discriminator":
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
from models.stylegan.stylegan2_lucidrains import StyleGan2Discriminator
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
from models.stylegan.stylegan2_lucidrains import StyleGan2Augmentor
netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD
# Discriminator
def define_D(opt, wrap=False):
img_sz = opt['datasets']['train']['target_size']
opt_net = opt['network_D']
return define_D_net(opt_net, img_sz, wrap=wrap)
def define_fixed_D(opt):
# Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_size parameter is added.
net = define_D_net(opt)
# Load the model parameters:
load_net = torch.load(opt['pretrained_path'])
load_net_clean = OrderedDict() # remove unnecessary 'module.'
for k, v in load_net.items():
if k.startswith('module.'):
load_net_clean[k[7:]] = v
else:
load_net_clean[k] = v
net.load_state_dict(load_net_clean)
# Put into eval mode, freeze the parameters and set the 'weight' field.
net.eval()
for k, v in net.named_parameters():
v.requires_grad = False
net.fdisc_weight = opt['weight']
return net
# Define network used for perceptual loss
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None):
if which_model == 'vgg':

View File

@@ -1,177 +0,0 @@
import argparse
import functools
import torch
from utils import options as option
from trainer.networks import create_model
class TracedModule:
def __init__(self, idname):
self.idname = idname
self.traced_outputs = []
self.traced_inputs = []
class TorchCustomTrace:
def __init__(self):
self.module_name_counter = {}
self.modules = {}
self.graph = {}
self.module_map_by_inputs = {}
self.module_map_by_outputs = {}
self.inputs_to_func_output_tuple = {}
def add_tracked_module(self, mod: torch.nn.Module):
modname = type(mod).__name__
if modname not in self.module_name_counter.keys():
self.module_name_counter[modname] = 0
self.module_name_counter[modname] += 1
idname = "%s(%03d)" % (modname, self.module_name_counter[modname])
self.modules[idname] = TracedModule(idname)
return idname
# Only called for nn.Modules since those are the only things we can access. Filling in the gaps will be done in
# the backwards pass.
def mem_forward_hook(self, module: torch.nn.Module, inputs, outputs, trace: str, mod_id: str):
mod = self.modules[mod_id]
'''
for li in inputs:
if type(li) == torch.Tensor:
li = [li]
if type(li) == list:
for i in li:
if i.data_ptr() in self.module_map_by_inputs.keys():
self.module_map_by_inputs[i.data_ptr()].append(mod)
else:
self.module_map_by_inputs[i.data_ptr()] = [mod]
for o in outputs:
if o.data_ptr() in self.module_map_by_inputs.keys():
self.module_map_by_inputs[o.data_ptr()].append(mod)
else:
self.module_map_by_inputs[o.data_ptr()] = [mod]
'''
print(trace)
def mem_backward_hook(self, inputs, outputs, op):
if len(inputs) == 0:
print("No inputs.. %s" % (op,))
outs = [o.data_ptr() for o in outputs]
tup = (outs, op)
#print(tup)
for li in inputs:
if type(li) == torch.Tensor:
li = [li]
if type(li) == list:
for i in li:
if i.data_ptr() in self.module_map_by_inputs.keys():
print("%i: [%s] {%s}" % (i.data_ptr(), op, [n.idname for n in self.module_map_by_inputs[i.data_ptr()]]))
if i.data_ptr() in self.inputs_to_func_output_tuple.keys():
self.inputs_to_func_output_tuple[i.data_ptr()].append(tup)
else:
self.inputs_to_func_output_tuple[i.data_ptr()] = [tup]
def install_hooks(self, mod: torch.nn.Module, trace=""):
mod_id = self.add_tracked_module(mod)
my_trace = trace + "->" + mod_id
# If this module has parameters, it also has a state worth tracking.
#if next(mod.parameters(recurse=False), None) is not None:
mod.register_forward_hook(functools.partial(self.mem_forward_hook, trace=my_trace, mod_id=mod_id))
for m in mod.children():
self.install_hooks(m, my_trace)
def install_backward_hooks(self, grad_fn):
# AccumulateGrad simply pushes a gradient into the specified variable, and isn't useful for the purposes of
# tracing the graph.
if grad_fn is None or "AccumulateGrad" in str(grad_fn):
return
grad_fn.register_hook(functools.partial(self.mem_backward_hook, op=str(grad_fn)))
for g, _ in grad_fn.next_functions:
self.install_backward_hooks(g)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_div2k_pixgan_srg2.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
netG = create_model(opt)
dummyInput = torch.rand(1,3,32,32)
mode = 'onnx'
if mode == 'torchscript':
print("Tracing generator network..")
traced_netG = torch.jit.trace(netG, dummyInput)
traced_netG.save('../results/ts_generator.zip')
print(traced_netG.code)
for i, module in enumerate(traced_netG.RRDB_trunk.modules()):
print(i, str(module))
elif mode == 'onnx':
print("Performing onnx trace")
input_names = ["lr_input"]
output_names = ["hr_image"]
dynamic_axes = {'lr_input': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}, 'hr_image': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}}
torch.onnx.export(netG, dummyInput, "../results/gen.onnx", verbose=True, input_names=input_names,
output_names=output_names, dynamic_axes=dynamic_axes, opset_version=12)
elif mode == 'memtrace':
criterion = torch.nn.MSELoss()
tracer = TorchCustomTrace()
tracer.install_hooks(netG)
out, = netG(dummyInput)
tracer.install_backward_hooks(out.grad_fn)
target = torch.zeros_like(out)
loss = criterion(out, target)
loss.backward()
elif mode == 'trace':
out = netG.forward(dummyInput)[0]
print(out.shape)
# Build the graph backwards.
graph = build_graph(out, 'output')
def get_unique_id_for_fn(fn):
return (str(fn).split(" object at ")[1])[:-1]
class GraphNode:
def __init__(self, fn):
self.name = (str(fn).split(" object at ")[0])[1:]
self.fn = fn
self.children = {}
self.parents = {}
def add_parent(self, parent):
self.parents[get_unique_id_for_fn(parent)] = parent
def add_child(self, child):
self.children[get_unique_id_for_fn(child)] = child
class TorchGraph:
def __init__(self):
self.tensor_map = {}
def get_node_for_tensor(self, t):
return self.tensor_map[get_unique_id_for_fn(t)]
def init(self, output_tensor):
self.build_graph_backwards(output_tensor.grad_fn, None)
# Find inputs
self.inputs = []
for v in self.tensor_map.values():
# Is an input if the parents dict is empty.
if not v.parents:
self.inputs.append(v)
def build_graph_backwards(self, fn, previous_fn):
id = get_unique_id_for_fn(fn)
if id in self.tensor_map:
node = self.tensor_map[id]
node.add_child(previous_fn)
else:
node = GraphNode(fn)
self.tensor_map[id] = node
# Propagate to children
for child_fn in fn.next_functions:
node.add_parent(self.build_graph_backwards(child_fn, fn))
return node