Move discriminators to the create_model paradigm

Also cleans up a lot of old discriminator models that I have no intention
of using again.
James Betker 2021-01-01 15:56:09 -07:00
parent 7976a5825d
commit 193cdc6636
10 changed files with 54 additions and 908 deletions
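For context, the paradigm this commit moves to is sketched below: a discriminator is exposed through a small @register_model function and then built by networks.create_model() from its options dict. The TinyDiscriminator class and the exact option keys are illustrative assumptions; only register_model, create_model, opt_get and the which_model lookup appear in the diff itself.

# Hedged sketch of the create_model paradigm; the discriminator is a stand-in, not a model from this repo.
import torch.nn as nn

from trainer.networks import register_model, create_model
from utils.util import opt_get


class TinyDiscriminator(nn.Module):
    def __init__(self, in_nc, nf):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(in_nc, nf, 3, stride=2, padding=1),
                                 nn.LeakyReLU(0.2),
                                 nn.Conv2d(nf, 1, 3, stride=2, padding=1))

    def forward(self, x):
        # Collapse to one score per image, as the VGG-style discriminators do.
        return self.net(x).mean(dim=(1, 2, 3))


@register_model
def register_tiny_discriminator(opt_net, opt):
    # opt_net is this network's options dict; opt is the full training config.
    return TinyDiscriminator(in_nc=opt_net['in_nc'], nf=opt_get(opt_net, ['nf'], 64))


# Judging by the renamed functions below, the registry key appears to be the function
# name with the register_ prefix stripped, so a config entry would look like
#   {'which_model': 'tiny_discriminator', 'in_nc': 3, 'nf': 64}
# and the trainer builds it with networks.create_model(opt, opt_net).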

View File

@ -37,13 +37,15 @@ class ByolDatasetWrapper(Dataset):
self.cropped_img_size = opt['crop_size']
self.key1 = opt_get(opt, ['key1'], 'hq')
self.key2 = opt_get(opt, ['key2'], 'lq')
+for_sr = opt_get(opt, ['for_sr'], False)  # When set, color alterations and blurs are disabled.
augmentations = [ \
-    RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
-    augs.RandomGrayscale(p=0.2),
     augs.RandomHorizontalFlip(),
-    RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
     augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
+if not for_sr:
+    augmentations.extend([RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
+                          augs.RandomGrayscale(p=0.2),
+                          RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)])
if opt['normalize']:
# The paper calls for normalization. Most datasets/models in this repo don't use this.
# Recommend setting true if you want to train exactly like the paper.
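For reference, a minimal options dict exercising the new flag might look like the sketch below; only for_sr, crop_size, key1, key2 and normalize come from the hunk above, and whatever keys the wrapped dataset itself needs are omitted.

# Illustrative ByolDatasetWrapper options (assumed values, keys taken from the diff above).
byol_dataset_opt = {
    'crop_size': 256,
    'key1': 'hq',
    'key2': 'lq',
    'normalize': False,
    'for_sr': True,   # new flag: keep flips/crops but skip color jitter, grayscale and blur
}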

View File

@ -3,7 +3,9 @@ import torch.nn as nn
from models.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
import torch.nn.functional as F
-from utils.util import checkpoint
+from trainer.networks import register_model
+from utils.util import checkpoint, opt_get

class Discriminator_VGG_128(nn.Module):
@ -79,6 +81,12 @@ class Discriminator_VGG_128(nn.Module):
return out

+@register_model
+def register_discriminator_vgg_128(opt_net, opt):
+    return Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=opt_net['image_size'] / 128,
+                                 extra_conv=opt_net['extra_conv'])

class Discriminator_VGG_128_GN(nn.Module):
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
def __init__(self, in_nc, nf, input_img_factor=1, do_checkpointing=False, extra_conv=False):
@ -159,514 +167,8 @@ class Discriminator_VGG_128_GN(nn.Module):
return out

+@register_model
+def register_discriminator_vgg_128(opt_net, opt):
+    return Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=opt_net['image_size'],
+                                    extra_conv=opt_get(opt_net, ['extra_conv'], False),
+                                    do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False))

class CrossCompareBlock(nn.Module):
def __init__(self, nf_in, nf_out):
super(CrossCompareBlock, self).__init__()
self.conv_hr_merge = ConvGnLelu(nf_in * 2, nf_in, kernel_size=1, bias=False, activation=False, norm=True)
self.proc_hr = ConvGnLelu(nf_in, nf_out, kernel_size=3, bias=False, activation=True, norm=True)
self.proc_lr = ConvGnLelu(nf_in, nf_out, kernel_size=3, bias=False, activation=True, norm=True)
self.reduce_hr = ConvGnLelu(nf_out, nf_out, kernel_size=3, stride=2, bias=False, activation=True, norm=True)
self.reduce_lr = ConvGnLelu(nf_out, nf_out, kernel_size=3, stride=2, bias=False, activation=True, norm=True)
def forward(self, hr, lr):
hr = self.conv_hr_merge(torch.cat([hr, lr], dim=1))
hr = self.proc_hr(hr)
hr = self.reduce_hr(hr)
lr = self.proc_lr(lr)
lr = self.reduce_lr(lr)
return hr, lr
class CrossCompareDiscriminator(nn.Module):
def __init__(self, in_nc, ref_channels, nf, scale=4):
super(CrossCompareDiscriminator, self).__init__()
assert scale == 2 or scale == 4
self.init_conv_hr = ConvGnLelu(in_nc, nf, stride=2, norm=False, bias=True, activation=True)
self.init_conv_lr = ConvGnLelu(ref_channels, nf, stride=1, norm=False, bias=True, activation=True)
if scale == 4:
strd_2 = 2
else:
strd_2 = 1
self.second_conv = ConvGnLelu(nf, nf, stride=strd_2, norm=True, bias=False, activation=True)
self.cross1 = CrossCompareBlock(nf, nf * 2)
self.cross2 = CrossCompareBlock(nf * 2, nf * 4)
self.cross3 = CrossCompareBlock(nf * 4, nf * 8)
self.cross4 = CrossCompareBlock(nf * 8, nf * 8)
self.fproc_conv = ConvGnLelu(nf * 8, nf, norm=True, bias=True, activation=True)
self.out_conv = ConvGnLelu(nf, 1, norm=False, bias=False, activation=False)
self.scale = scale * 16
def forward(self, hr, lr):
hr = self.init_conv_hr(hr)
hr = self.second_conv(hr)
lr = self.init_conv_lr(lr)
hr, lr = self.cross1(hr, lr)
hr, lr = self.cross2(hr, lr)
hr, lr = self.cross3(hr, lr)
hr, _ = self.cross4(hr, lr)
return self.out_conv(self.fproc_conv(hr)).view(-1, 1)
# Returns tuple of (number_output_channels, scale_of_output_reduction (1/n))
def pixgan_parameters(self):
return 3, self.scale
class Discriminator_VGG_PixLoss(nn.Module):
def __init__(self, in_nc, nf):
super(Discriminator_VGG_PixLoss, self).__init__()
# [64, 128, 128]
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
self.bn0_1 = nn.GroupNorm(8, nf, affine=True)
# [64, 64, 64]
self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
self.bn1_0 = nn.GroupNorm(8, nf * 2, affine=True)
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.bn1_1 = nn.GroupNorm(8, nf * 2, affine=True)
# [128, 32, 32]
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.bn2_0 = nn.GroupNorm(8, nf * 4, affine=True)
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
self.bn2_1 = nn.GroupNorm(8, nf * 4, affine=True)
# [256, 16, 16]
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
self.bn3_0 = nn.GroupNorm(8, nf * 8, affine=True)
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn3_1 = nn.GroupNorm(8, nf * 8, affine=True)
# [512, 8, 8]
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True)
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
self.reduce_1 = ConvGnLelu(nf * 8, nf * 4, bias=False)
self.pix_loss_collapse = ConvGnLelu(nf * 4, 1, bias=False, norm=False, activation=False)
# Pyramid network: upsample with residuals and produce losses at multiple resolutions.
self.up3_decimate = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=True, activation=False)
self.up3_converge = ConvGnLelu(nf * 16, nf * 8, kernel_size=3, bias=False)
self.up3_proc = ConvGnLelu(nf * 8, nf * 8, bias=False)
self.up3_reduce = ConvGnLelu(nf * 8, nf * 4, bias=False)
self.up3_pix = ConvGnLelu(nf * 4, 1, bias=False, norm=False, activation=False)
self.up2_decimate = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, bias=True, activation=False)
self.up2_converge = ConvGnLelu(nf * 8, nf * 4, kernel_size=3, bias=False)
self.up2_proc = ConvGnLelu(nf * 4, nf * 4, bias=False)
self.up2_reduce = ConvGnLelu(nf * 4, nf * 2, bias=False)
self.up2_pix = ConvGnLelu(nf * 2, 1, bias=False, norm=False, activation=False)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x, flatten=True):
fea0 = self.lrelu(self.conv0_0(x))
fea0 = self.lrelu(self.bn0_1(self.conv0_1(fea0)))
fea1 = self.lrelu(self.bn1_0(self.conv1_0(fea0)))
fea1 = self.lrelu(self.bn1_1(self.conv1_1(fea1)))
fea2 = self.lrelu(self.bn2_0(self.conv2_0(fea1)))
fea2 = self.lrelu(self.bn2_1(self.conv2_1(fea2)))
fea3 = self.lrelu(self.bn3_0(self.conv3_0(fea2)))
fea3 = self.lrelu(self.bn3_1(self.conv3_1(fea3)))
fea4 = self.lrelu(self.bn4_0(self.conv4_0(fea3)))
fea4 = self.lrelu(self.bn4_1(self.conv4_1(fea4)))
loss = self.reduce_1(fea4)
# "Weight" all losses the same by interpolating them to the highest dimension.
loss = self.pix_loss_collapse(loss)
loss = F.interpolate(loss, scale_factor=4, mode="nearest")
# And the pyramid network!
dec3 = self.up3_decimate(F.interpolate(fea4, scale_factor=2, mode="nearest"))
dec3 = torch.cat([dec3, fea3], dim=1)
dec3 = self.up3_converge(dec3)
dec3 = self.up3_proc(dec3)
loss3 = self.up3_reduce(dec3)
loss3 = self.up3_pix(loss3)
loss3 = F.interpolate(loss3, scale_factor=2, mode="nearest")
dec2 = self.up2_decimate(F.interpolate(dec3, scale_factor=2, mode="nearest"))
dec2 = torch.cat([dec2, fea2], dim=1)
dec2 = self.up2_converge(dec2)
dec2 = self.up2_proc(dec2)
dec2 = self.up2_reduce(dec2)
loss2 = self.up2_pix(dec2)
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will
# then know how to handle them.
combined_losses = torch.cat([loss, loss3, loss2], dim=1)
return combined_losses.view(-1, 1)
def pixgan_parameters(self):
return 3, 8
class Discriminator_UNet(nn.Module):
def __init__(self, in_nc, nf):
super(Discriminator_UNet, self).__init__()
# [64, 128, 128]
self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
# [64, 64, 64]
self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
# [128, 32, 32]
self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
# [256, 16, 16]
self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
# [512, 8, 8]
self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False)
self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False)
self.up2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu)
self.proc2 = ConvGnLelu(nf * 4, nf * 4, bias=False)
self.collapse2 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
self.up3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu)
self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
def forward(self, x, flatten=True):
fea0 = self.conv0_0(x)
fea0 = self.conv0_1(fea0)
fea1 = self.conv1_0(fea0)
fea1 = self.conv1_1(fea1)
fea2 = self.conv2_0(fea1)
fea2 = self.conv2_1(fea2)
fea3 = self.conv3_0(fea2)
fea3 = self.conv3_1(fea3)
fea4 = self.conv4_0(fea3)
fea4 = self.conv4_1(fea4)
# And the pyramid network!
u1 = self.up1(fea4, fea3)
loss1 = self.collapse1(self.proc1(u1))
u2 = self.up2(u1, fea2)
loss2 = self.collapse2(self.proc2(u2))
u3 = self.up3(u2, fea1)
loss3 = self.collapse3(self.proc3(u3))
res = loss3.shape[2:]
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will
# then know how to handle them.
combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
F.interpolate(loss2, scale_factor=2),
F.interpolate(loss3, scale_factor=1)], dim=1)
return combined_losses.view(-1, 1)
def pixgan_parameters(self):
return 3, 4
class Discriminator_UNet_FeaOut(nn.Module):
def __init__(self, in_nc, nf, feature_mode=False):
super(Discriminator_UNet_FeaOut, self).__init__()
# [64, 128, 128]
self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
# [64, 64, 64]
self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
# [128, 32, 32]
self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
# [256, 16, 16]
self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
# [512, 8, 8]
self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False)
self.fea_proc = ConvGnLelu(nf * 8, nf * 8, bias=True, norm=False, activation=False)
self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False)
self.feature_mode = feature_mode
def forward(self, x, output_feature_vector=False):
fea0 = self.conv0_0(x)
fea0 = self.conv0_1(fea0)
fea1 = self.conv1_0(fea0)
fea1 = self.conv1_1(fea1)
fea2 = self.conv2_0(fea1)
fea2 = self.conv2_1(fea2)
fea3 = self.conv3_0(fea2)
fea3 = self.conv3_1(fea3)
fea4 = self.conv4_0(fea3)
fea4 = self.conv4_1(fea4)
# And the pyramid network!
u1 = self.up1(fea4, fea3)
loss1 = self.collapse1(self.proc1(u1))
fea_out = self.fea_proc(u1)
combined_losses = F.interpolate(loss1, scale_factor=4)
if output_feature_vector:
return combined_losses.view(-1, 1), fea_out
else:
return combined_losses.view(-1, 1)
def pixgan_parameters(self):
return 1, 4
class Vgg128GnHead(nn.Module):
def __init__(self, in_nc, nf, depth=5):
super(Vgg128GnHead, self).__init__()
assert depth == 4 or depth == 5 # Nothing stopping others from being implemented, just not done yet.
self.depth = depth
# [64, 128, 128]
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
self.bn0_1 = nn.GroupNorm(8, nf, affine=True)
# [64, 64, 64]
self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
self.bn1_0 = nn.GroupNorm(8, nf * 2, affine=True)
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.bn1_1 = nn.GroupNorm(8, nf * 2, affine=True)
# [128, 32, 32]
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.bn2_0 = nn.GroupNorm(8, nf * 4, affine=True)
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
self.bn2_1 = nn.GroupNorm(8, nf * 4, affine=True)
# [256, 16, 16]
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
self.bn3_0 = nn.GroupNorm(8, nf * 8, affine=True)
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn3_1 = nn.GroupNorm(8, nf * 8, affine=True)
if depth > 4:
# [512, 8, 8]
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True)
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.lrelu(self.conv0_0(x))
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
if self.depth > 4:
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
return fea
class RefDiscriminatorVgg128(nn.Module):
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
def __init__(self, in_nc, nf, input_img_factor=1):
super(RefDiscriminatorVgg128, self).__init__()
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.feature_head = Vgg128GnHead(in_nc, nf)
self.ref_head = Vgg128GnHead(in_nc+1, nf, depth=4)
final_nf = nf * 8
self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 512)
self.ref_linear = nn.Linear(nf * 8, 128)
self.output_linears = nn.Sequential(
nn.Linear(128+512, 512),
self.lrelu,
nn.Linear(512, 256),
self.lrelu,
nn.Linear(256, 128),
self.lrelu,
nn.Linear(128, 1)
)
def forward(self, x, ref, ref_center_point):
ref = self.ref_head(ref)
ref_center_point = ref_center_point // 16
from models.SwitchedResidualGenerator_arch import gather_2d
ref_vector = gather_2d(ref, ref_center_point)
ref_vector = self.ref_linear(ref_vector)
fea = self.feature_head(x)
fea = fea.contiguous().view(fea.size(0), -1)
fea = self.lrelu(self.linear1(fea))
out = self.output_linears(torch.cat([fea, ref_vector], dim=1))
return out
class PsnrApproximator(nn.Module):
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
def __init__(self, nf, input_img_factor=1):
super(PsnrApproximator, self).__init__()
# [64, 128, 128]
self.fake_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
self.fake_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
self.fake_bn0_1 = nn.BatchNorm2d(nf, affine=True)
# [64, 64, 64]
self.fake_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
self.fake_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
self.fake_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.fake_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
# [128, 32, 32]
self.fake_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.fake_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
self.fake_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
self.fake_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
# [64, 128, 128]
self.real_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
self.real_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
self.real_bn0_1 = nn.BatchNorm2d(nf, affine=True)
# [64, 64, 64]
self.real_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
self.real_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
self.real_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.real_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
# [128, 32, 32]
self.real_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.real_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
self.real_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
self.real_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
# [512, 16, 16]
self.conv3_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
# [512, 8, 8]
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
final_nf = nf * 8
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 1024)
self.linear2 = nn.Linear(1024, 512)
self.linear3 = nn.Linear(512, 128)
self.linear4 = nn.Linear(128, 1)
def compute_body1(self, real):
fea = self.lrelu(self.real_conv0_0(real))
fea = self.lrelu(self.real_bn0_1(self.real_conv0_1(fea)))
fea = self.lrelu(self.real_bn1_0(self.real_conv1_0(fea)))
fea = self.lrelu(self.real_bn1_1(self.real_conv1_1(fea)))
fea = self.lrelu(self.real_bn2_0(self.real_conv2_0(fea)))
fea = self.lrelu(self.real_bn2_1(self.real_conv2_1(fea)))
return fea
def compute_body2(self, fake):
fea = self.lrelu(self.fake_conv0_0(fake))
fea = self.lrelu(self.fake_bn0_1(self.fake_conv0_1(fea)))
fea = self.lrelu(self.fake_bn1_0(self.fake_conv1_0(fea)))
fea = self.lrelu(self.fake_bn1_1(self.fake_conv1_1(fea)))
fea = self.lrelu(self.fake_bn2_0(self.fake_conv2_0(fea)))
fea = self.lrelu(self.fake_bn2_1(self.fake_conv2_1(fea)))
return fea
def forward(self, real, fake):
real_fea = checkpoint(self.compute_body1, real)
fake_fea = checkpoint(self.compute_body2, fake)
fea = torch.cat([real_fea, fake_fea], dim=1)
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
fea = fea.contiguous().view(fea.size(0), -1)
fea = self.lrelu(self.linear1(fea))
fea = self.lrelu(self.linear2(fea))
fea = self.lrelu(self.linear3(fea))
out = self.linear4(fea)
return out.squeeze()
class SingleImageQualityEstimator(nn.Module):
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
def __init__(self, nf, input_img_factor=1):
super(SingleImageQualityEstimator, self).__init__()
# [64, 128, 128]
self.fake_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
self.fake_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
self.fake_bn0_1 = nn.BatchNorm2d(nf, affine=True)
# [64, 64, 64]
self.fake_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
self.fake_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
self.fake_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.fake_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
# [128, 32, 32]
self.fake_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.fake_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
self.fake_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
self.fake_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
# [512, 16, 16]
self.conv3_0 = nn.Conv2d(nf * 4, nf * 4, 3, 1, 1, bias=False)
self.bn3_0 = nn.BatchNorm2d(nf * 4, affine=True)
self.conv3_1 = nn.Conv2d(nf * 4, nf * 8, 4, 2, 1, bias=False)
self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
# [512, 8, 8]
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=True)
self.conv4_1 = nn.Conv2d(nf * 8, nf * 2, 3, 1, 1, bias=True)
self.conv4_2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
self.conv4_3 = nn.Conv2d(nf, 3, 3, 1, 1, bias=True)
self.sigmoid = nn.Sigmoid()
self.lrelu = nn.LeakyReLU(negative_slope=.2, inplace=True)
def compute_body(self, fake):
fea = self.lrelu(self.fake_conv0_0(fake))
fea = self.lrelu(self.fake_bn0_1(self.fake_conv0_1(fea)))
fea = self.lrelu(self.fake_bn1_0(self.fake_conv1_0(fea)))
fea = self.lrelu(self.fake_bn1_1(self.fake_conv1_1(fea)))
fea = self.lrelu(self.fake_bn2_0(self.fake_conv2_0(fea)))
fea = self.lrelu(self.fake_bn2_1(self.fake_conv2_1(fea)))
return fea
def forward(self, fake):
fea = checkpoint(self.compute_body, fake)
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
fea = self.lrelu(self.conv4_0(fea))
fea = self.lrelu(self.conv4_1(fea))
fea = self.lrelu(self.conv4_2(fea))
fea = self.sigmoid(self.conv4_3(fea))
return fea

View File

@ -5,6 +5,9 @@ from torch import nn
import torch.nn.functional as F
import numpy as np
+from trainer.networks import register_model
+from utils.util import opt_get

class BlurLayer(nn.Module):
def __init__(self, kernel=None, normalize=True, flip=False, stride=1):
@ -372,4 +375,9 @@ class StyleGanDiscriminator(nn.Module):
else:
raise KeyError("Unknown structure: ", self.structure)
return scores_out

+@register_model
+def register_stylegan_vgg(opt_net, opt):
+    return StyleGanDiscriminator(opt_get(opt_net, ['image_size'], 128))

View File

@ -18,7 +18,7 @@ from torch.autograd import grad as torch_grad
from vector_quantize_pytorch import VectorQuantize
from trainer.networks import register_model
-from utils.util import checkpoint
+from utils.util import checkpoint, opt_get

try:
from apex import amp
@ -763,7 +763,7 @@ class DiscriminatorBlock(nn.Module):
class StyleGan2Discriminator(nn.Module):
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
-             transparent=False, fmap_max=512, input_filters=3):
+             transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False):
super().__init__()
num_layers = int(log2(image_size) - 1)
@ -789,12 +789,16 @@ class StyleGan2Discriminator(nn.Module):
attn_blocks.append(attn_fn)

-quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None
-quantize_blocks.append(quantize_fn)
+if quantize:
+    quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None
+    quantize_blocks.append(quantize_fn)
+else:
+    quantize_blocks.append(None)

self.blocks = nn.ModuleList(blocks)
self.attn_blocks = nn.ModuleList(attn_blocks)
self.quantize_blocks = nn.ModuleList(quantize_blocks)
+self.do_checkpointing = do_checkpointing

chan_last = filters[-1]
latent_dim = 2 * 2 * chan_last
@ -811,7 +815,10 @@ class StyleGan2Discriminator(nn.Module):
quantize_loss = torch.zeros(1).to(x)
for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks):
-    x = block(x)
+    if self.do_checkpointing:
+        x = checkpoint(block, x)
+    else:
+        x = block(x)
if exists(attn_block):
x = attn_block(x)
@ -862,7 +869,6 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss):
# Apply gradient penalty. TODO: migrate this elsewhere.
if self.env['step'] % self.gp_frequency == 0:
-    from models.stylegan.stylegan2_lucidrains import gradient_penalty
gp = gradient_penalty(real_input, real)
self.metrics.append(("gradient_penalty", gp.clone().detach()))
divergence_loss = divergence_loss + gp
@ -877,17 +883,14 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss):
self.w_styles = opt['w_styles']
self.gen = opt['gen']
self.pl_mean = None
-from models.archs.stylegan.stylegan2_lucidrains import EMA
self.pl_length_ma = EMA(.99)

def forward(self, net, state):
w_styles = state[self.w_styles]
gen = state[self.gen]
-from models.stylegan.stylegan2_lucidrains import calc_pl_lengths
pl_lengths = calc_pl_lengths(w_styles, gen)
avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
-from models.stylegan.stylegan2_lucidrains import is_empty
if not is_empty(self.pl_mean):
pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
if not torch.isnan(pl_loss):
@ -906,3 +909,12 @@ def register_stylegan2_lucidrains(opt_net, opt):
return StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
                                    style_depth=opt_net['style_depth'], structure_input=is_structured,
                                    attn_layers=attn)

+@register_model
+def register_stylegan2_discriminator(opt_net, opt):
+    attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
+    disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
+                                  do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
+                                  quantize=opt_get(opt_net, ['quantize'], False))
+    return StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
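A hedged example of the per-network options this new registration reads: the key names come from the registration function above, while the values (including the augmentation types) are only illustrative.

# Illustrative options for the registered StyleGAN2 discriminator (values are assumptions).
stylegan2_disc_opt = {
    'which_model': 'stylegan2_discriminator',
    'image_size': 128,
    'in_nc': 3,
    'attn_layers': [],                # optional; defaults to no attention blocks
    'do_checkpointing': True,         # new: checkpoint each discriminator block to save memory
    'quantize': False,                # new: only build VectorQuantize blocks when requested
    'augmentation_types': ['translation', 'cutout'],   # passed to StyleGan2Augmentor
    'augmentation_probability': 0.5,
}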

View File

@ -11,7 +11,7 @@ munch
tqdm
scp
tensorboard
-pytorch_fid
+pytorch_fid==0.1.1
kornia
linear_attention_transformer
vector_quantize_pytorch

View File

@ -1,104 +0,0 @@
import sys
import os.path as osp
import math
import torchvision.utils
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
from data import create_dataloader, create_dataset # noqa: E402
from utils import util # noqa: E402
def main():
dataset = 'DIV2K800_sub' # REDS | Vimeo90K | DIV2K800_sub
opt = {}
opt['dist'] = False
opt['gpu_ids'] = [0]
if dataset == 'REDS':
opt['name'] = 'test_REDS'
opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb'
opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
opt['mode'] = 'REDS'
opt['N_frames'] = 5
opt['phase'] = 'train'
opt['use_shuffle'] = True
opt['n_workers'] = 8
opt['batch_size'] = 16
opt['target_size'] = 256
opt['LQ_size'] = 64
opt['scale'] = 4
opt['use_flip'] = True
opt['use_rot'] = True
opt['interval_list'] = [1]
opt['random_reverse'] = False
opt['border_mode'] = False
opt['cache_keys'] = None
opt['data_type'] = 'lmdb' # img | lmdb | mc
elif dataset == 'Vimeo90K':
opt['name'] = 'test_Vimeo90K'
opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
opt['mode'] = 'Vimeo90K'
opt['N_frames'] = 7
opt['phase'] = 'train'
opt['use_shuffle'] = True
opt['n_workers'] = 8
opt['batch_size'] = 16
opt['target_size'] = 256
opt['LQ_size'] = 64
opt['scale'] = 4
opt['use_flip'] = True
opt['use_rot'] = True
opt['interval_list'] = [1]
opt['random_reverse'] = False
opt['border_mode'] = False
opt['cache_keys'] = None
opt['data_type'] = 'lmdb' # img | lmdb | mc
elif dataset == 'DIV2K800_sub':
opt['name'] = 'DIV2K800'
opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
opt['mode'] = 'LQGT'
opt['phase'] = 'train'
opt['use_shuffle'] = True
opt['n_workers'] = 8
opt['batch_size'] = 16
opt['target_size'] = 128
opt['scale'] = 4
opt['use_flip'] = True
opt['use_rot'] = True
opt['color'] = 'RGB'
opt['data_type'] = 'lmdb' # img | lmdb
else:
raise ValueError('Please implement by yourself.')
util.mkdir('tmp')
train_set = create_dataset(opt)
train_loader = create_dataloader(train_set, opt, opt, None)
nrow = int(math.sqrt(opt['batch_size']))
padding = 2 if opt['phase'] == 'train' else 0
print('start...')
for i, data in enumerate(train_loader):
if i > 5:
break
print(i)
if dataset == 'REDS' or dataset == 'Vimeo90K':
LQs = data['LQs']
else:
LQ = data['lq']
GT = data['hq']
if dataset == 'REDS' or dataset == 'Vimeo90K':
for j in range(LQs.size(1)):
torchvision.utils.save_image(LQs[:, j, :, :, :],
'tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow,
padding=padding, normalize=False)
else:
torchvision.utils.save_image(LQ, 'tmp/LQ_{:03d}.png'.format(i), nrow=nrow,
padding=padding, normalize=False)
torchvision.utils.save_image(GT, 'tmp/GT_{:03d}.png'.format(i), nrow=nrow, padding=padding,
normalize=False)
if __name__ == "__main__":
main()

View File

@ -293,7 +293,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
-parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_faces_styled_sr.yml')
+parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_discriminator_diffimage.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -57,11 +57,11 @@ class ExtensibleTrainer(BaseModel):
new_net = None
if net['type'] == 'generator':
if new_net is None:
-    new_net = networks.create_model(opt, net, opt['scale']).to(self.device)
+    new_net = networks.create_model(opt, net).to(self.device)
self.netsG[name] = new_net
elif net['type'] == 'discriminator':
if new_net is None:
-    new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device)
+    new_net = networks.create_model(opt, net).to(self.device)
self.netsD[name] = new_net
else:
raise NotImplementedError("Can only handle generators and discriminators")
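To illustrate why the two branches can now share the same construction call: every entry under opt['networks'] names its own which_model, so the trainer no longer needs discriminator-specific plumbing. The sketch below is illustrative rather than a tested config; only the 'type' and 'which_model' keys and the registered names come from this diff.

# Hypothetical opt['networks'] section; both entries go through networks.create_model(opt, net).
networks_sketch = {
    'generator': {
        'type': 'generator',
        'which_model': 'stylegan2_lucidrains',   # registered in stylegan2_lucidrains.py
        'image_size': 128,                       # remaining keys depend on the chosen model
    },
    'disc': {
        'type': 'discriminator',
        'which_model': 'stylegan_vgg',           # registered in Discriminator_StyleGAN.py
        'image_size': 128,
    },
}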

View File

@ -1,18 +1,11 @@
-import functools
import importlib
import logging
import pkgutil
import sys
from collections import OrderedDict
from inspect import isfunction, getmembers

import torch
-import torchvision

-import models.discriminator_vgg_arch as SRGAN_arch
import models.feature_arch as feature_arch
-import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
-from models.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator

logger = logging.getLogger('base')
@ -63,7 +56,7 @@ class CreateModelError(Exception):
f'{available}')

-def create_model(opt, opt_net, scale=None):
+def create_model(opt, opt_net):
which_model = opt_net['which_model']
# For backwards compatibility.
if not which_model:
@ -76,96 +69,6 @@ def create_model(opt, opt_net, scale=None):
return registered_fns[which_model](opt_net, opt)
class GradDiscWrapper(torch.nn.Module):
def __init__(self, m):
super(GradDiscWrapper, self).__init__()
logger.info("Wrapping a discriminator..")
self.m = m
def forward(self, x):
return self.m(x)
def define_D_net(opt_net, img_sz=None, wrap=False):
which_model = opt_net['which_model_D']
if 'image_size' in opt_net.keys():
img_sz = opt_net['image_size']
if which_model == 'discriminator_vgg_128':
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, extra_conv=opt_net['extra_conv'])
elif which_model == 'discriminator_vgg_128_gn':
extra_conv = opt_net['extra_conv'] if 'extra_conv' in opt_net.keys() else False
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
input_img_factor=img_sz / 128, extra_conv=extra_conv)
if wrap:
netD = GradDiscWrapper(netD)
elif which_model == 'discriminator_vgg_128_gn_checkpointed':
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, do_checkpointing=True)
elif which_model == 'stylegan_vgg':
netD = StyleGanDiscriminator(128)
elif which_model == 'discriminator_resnet':
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
elif which_model == 'discriminator_resnet_50':
netD = DiscriminatorResnet_arch.fixup_resnet50(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
elif which_model == 'resnext':
netD = torchvision.models.resnext50_32x4d(norm_layer=functools.partial(torch.nn.GroupNorm, 8))
#state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True)
#netD.load_state_dict(state_dict, strict=False)
netD.fc = torch.nn.Linear(512 * 4, 1)
elif which_model == 'discriminator_pix':
netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
elif which_model == "discriminator_unet":
netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
elif which_model == "discriminator_unet_fea":
netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'], feature_mode=opt_net['feature_mode'])
elif which_model == "discriminator_switched":
netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'],
final_temperature_step=opt_net['final_temperature_step'])
elif which_model == "cross_compare_vgg128":
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'] if 'ref_channels' in opt_net.keys() else 3, nf=opt_net['nf'], scale=opt_net['scale'])
elif which_model == "discriminator_refvgg":
netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
elif which_model == "psnr_approximator":
netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
elif which_model == "stylegan2_discriminator":
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
from models.stylegan.stylegan2_lucidrains import StyleGan2Discriminator
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
from models.stylegan.stylegan2_lucidrains import StyleGan2Augmentor
netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD
# Discriminator
def define_D(opt, wrap=False):
img_sz = opt['datasets']['train']['target_size']
opt_net = opt['network_D']
return define_D_net(opt_net, img_sz, wrap=wrap)
def define_fixed_D(opt):
# Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_size parameter is added.
net = define_D_net(opt)
# Load the model parameters:
load_net = torch.load(opt['pretrained_path'])
load_net_clean = OrderedDict() # remove unnecessary 'module.'
for k, v in load_net.items():
if k.startswith('module.'):
load_net_clean[k[7:]] = v
else:
load_net_clean[k] = v
net.load_state_dict(load_net_clean)
# Put into eval mode, freeze the parameters and set the 'weight' field.
net.eval()
for k, v in net.named_parameters():
v.requires_grad = False
net.fdisc_weight = opt['weight']
return net
# Define network used for perceptual loss
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None):
if which_model == 'vgg':

View File

@ -1,177 +0,0 @@
import argparse
import functools
import torch
from utils import options as option
from trainer.networks import create_model
class TracedModule:
def __init__(self, idname):
self.idname = idname
self.traced_outputs = []
self.traced_inputs = []
class TorchCustomTrace:
def __init__(self):
self.module_name_counter = {}
self.modules = {}
self.graph = {}
self.module_map_by_inputs = {}
self.module_map_by_outputs = {}
self.inputs_to_func_output_tuple = {}
def add_tracked_module(self, mod: torch.nn.Module):
modname = type(mod).__name__
if modname not in self.module_name_counter.keys():
self.module_name_counter[modname] = 0
self.module_name_counter[modname] += 1
idname = "%s(%03d)" % (modname, self.module_name_counter[modname])
self.modules[idname] = TracedModule(idname)
return idname
# Only called for nn.Modules since those are the only things we can access. Filling in the gaps will be done in
# the backwards pass.
def mem_forward_hook(self, module: torch.nn.Module, inputs, outputs, trace: str, mod_id: str):
mod = self.modules[mod_id]
'''
for li in inputs:
if type(li) == torch.Tensor:
li = [li]
if type(li) == list:
for i in li:
if i.data_ptr() in self.module_map_by_inputs.keys():
self.module_map_by_inputs[i.data_ptr()].append(mod)
else:
self.module_map_by_inputs[i.data_ptr()] = [mod]
for o in outputs:
if o.data_ptr() in self.module_map_by_inputs.keys():
self.module_map_by_inputs[o.data_ptr()].append(mod)
else:
self.module_map_by_inputs[o.data_ptr()] = [mod]
'''
print(trace)
def mem_backward_hook(self, inputs, outputs, op):
if len(inputs) == 0:
print("No inputs.. %s" % (op,))
outs = [o.data_ptr() for o in outputs]
tup = (outs, op)
#print(tup)
for li in inputs:
if type(li) == torch.Tensor:
li = [li]
if type(li) == list:
for i in li:
if i.data_ptr() in self.module_map_by_inputs.keys():
print("%i: [%s] {%s}" % (i.data_ptr(), op, [n.idname for n in self.module_map_by_inputs[i.data_ptr()]]))
if i.data_ptr() in self.inputs_to_func_output_tuple.keys():
self.inputs_to_func_output_tuple[i.data_ptr()].append(tup)
else:
self.inputs_to_func_output_tuple[i.data_ptr()] = [tup]
def install_hooks(self, mod: torch.nn.Module, trace=""):
mod_id = self.add_tracked_module(mod)
my_trace = trace + "->" + mod_id
# If this module has parameters, it also has a state worth tracking.
#if next(mod.parameters(recurse=False), None) is not None:
mod.register_forward_hook(functools.partial(self.mem_forward_hook, trace=my_trace, mod_id=mod_id))
for m in mod.children():
self.install_hooks(m, my_trace)
def install_backward_hooks(self, grad_fn):
# AccumulateGrad simply pushes a gradient into the specified variable, and isn't useful for the purposes of
# tracing the graph.
if grad_fn is None or "AccumulateGrad" in str(grad_fn):
return
grad_fn.register_hook(functools.partial(self.mem_backward_hook, op=str(grad_fn)))
for g, _ in grad_fn.next_functions:
self.install_backward_hooks(g)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_div2k_pixgan_srg2.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
netG = create_model(opt)
dummyInput = torch.rand(1,3,32,32)
mode = 'onnx'
if mode == 'torchscript':
print("Tracing generator network..")
traced_netG = torch.jit.trace(netG, dummyInput)
traced_netG.save('../results/ts_generator.zip')
print(traced_netG.code)
for i, module in enumerate(traced_netG.RRDB_trunk.modules()):
print(i, str(module))
elif mode == 'onnx':
print("Performing onnx trace")
input_names = ["lr_input"]
output_names = ["hr_image"]
dynamic_axes = {'lr_input': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}, 'hr_image': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}}
torch.onnx.export(netG, dummyInput, "../results/gen.onnx", verbose=True, input_names=input_names,
output_names=output_names, dynamic_axes=dynamic_axes, opset_version=12)
elif mode == 'memtrace':
criterion = torch.nn.MSELoss()
tracer = TorchCustomTrace()
tracer.install_hooks(netG)
out, = netG(dummyInput)
tracer.install_backward_hooks(out.grad_fn)
target = torch.zeros_like(out)
loss = criterion(out, target)
loss.backward()
elif mode == 'trace':
out = netG.forward(dummyInput)[0]
print(out.shape)
# Build the graph backwards.
graph = build_graph(out, 'output')
def get_unique_id_for_fn(fn):
return (str(fn).split(" object at ")[1])[:-1]
class GraphNode:
def __init__(self, fn):
self.name = (str(fn).split(" object at ")[0])[1:]
self.fn = fn
self.children = {}
self.parents = {}
def add_parent(self, parent):
self.parents[get_unique_id_for_fn(parent)] = parent
def add_child(self, child):
self.children[get_unique_id_for_fn(child)] = child
class TorchGraph:
def __init__(self):
self.tensor_map = {}
def get_node_for_tensor(self, t):
return self.tensor_map[get_unique_id_for_fn(t)]
def init(self, output_tensor):
self.build_graph_backwards(output_tensor.grad_fn, None)
# Find inputs
self.inputs = []
for v in self.tensor_map.values():
# Is an input if the parents dict is empty.
if bool(v.parents):
self.inputs.append(v)
def build_graph_backwards(self, fn, previous_fn):
id = get_unique_id_for_fn(fn)
if id in self.tensor_map:
node = self.tensor_map[id]
node.add_child(previous_fn)
else:
node = GraphNode(fn)
self.tensor_map[id] = node
# Propagate to children
for child_fn in fn.next_functions:
node.add_parent(self.build_graph_backwards(child_fn, fn))
return node