Move discriminators to the create_model paradigm
Also cleans up a lot of old discriminator models that I have no intention of using again.
This commit is contained in:
parent
7976a5825d
commit
193cdc6636
|
@ -37,13 +37,15 @@ class ByolDatasetWrapper(Dataset):
|
|||
self.cropped_img_size = opt['crop_size']
|
||||
self.key1 = opt_get(opt, ['key1'], 'hq')
|
||||
self.key2 = opt_get(opt, ['key2'], 'lq')
|
||||
for_sr = opt_get(opt, ['for_sr'], False) # When set, color alterations and blurs are disabled.
|
||||
|
||||
augmentations = [ \
|
||||
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
|
||||
augs.RandomGrayscale(p=0.2),
|
||||
augs.RandomHorizontalFlip(),
|
||||
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
|
||||
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
|
||||
if not for_sr:
|
||||
augmentations.extend([RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
|
||||
augs.RandomGrayscale(p=0.2),
|
||||
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)])
|
||||
if opt['normalize']:
|
||||
# The paper calls for normalization. Most datasets/models in this repo don't use this.
|
||||
# Recommend setting true if you want to train exactly like the paper.
|
||||
|
|
|
@ -3,7 +3,9 @@ import torch.nn as nn
|
|||
|
||||
from models.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
|
||||
import torch.nn.functional as F
|
||||
from utils.util import checkpoint
|
||||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, opt_get
|
||||
|
||||
|
||||
class Discriminator_VGG_128(nn.Module):
|
||||
|
@ -79,6 +81,12 @@ class Discriminator_VGG_128(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
@register_model
|
||||
def register_discriminator_vgg_128(opt_net, opt):
|
||||
return Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=opt_net['image_size'] / 128,
|
||||
extra_conv=opt_net['extra_conv'])
|
||||
|
||||
|
||||
class Discriminator_VGG_128_GN(nn.Module):
|
||||
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
|
||||
def __init__(self, in_nc, nf, input_img_factor=1, do_checkpointing=False, extra_conv=False):
|
||||
|
@ -159,514 +167,8 @@ class Discriminator_VGG_128_GN(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
class CrossCompareBlock(nn.Module):
|
||||
def __init__(self, nf_in, nf_out):
|
||||
super(CrossCompareBlock, self).__init__()
|
||||
self.conv_hr_merge = ConvGnLelu(nf_in * 2, nf_in, kernel_size=1, bias=False, activation=False, norm=True)
|
||||
self.proc_hr = ConvGnLelu(nf_in, nf_out, kernel_size=3, bias=False, activation=True, norm=True)
|
||||
self.proc_lr = ConvGnLelu(nf_in, nf_out, kernel_size=3, bias=False, activation=True, norm=True)
|
||||
self.reduce_hr = ConvGnLelu(nf_out, nf_out, kernel_size=3, stride=2, bias=False, activation=True, norm=True)
|
||||
self.reduce_lr = ConvGnLelu(nf_out, nf_out, kernel_size=3, stride=2, bias=False, activation=True, norm=True)
|
||||
|
||||
def forward(self, hr, lr):
|
||||
hr = self.conv_hr_merge(torch.cat([hr, lr], dim=1))
|
||||
hr = self.proc_hr(hr)
|
||||
hr = self.reduce_hr(hr)
|
||||
|
||||
lr = self.proc_lr(lr)
|
||||
lr = self.reduce_lr(lr)
|
||||
|
||||
return hr, lr
|
||||
|
||||
|
||||
class CrossCompareDiscriminator(nn.Module):
|
||||
def __init__(self, in_nc, ref_channels, nf, scale=4):
|
||||
super(CrossCompareDiscriminator, self).__init__()
|
||||
assert scale == 2 or scale == 4
|
||||
|
||||
self.init_conv_hr = ConvGnLelu(in_nc, nf, stride=2, norm=False, bias=True, activation=True)
|
||||
self.init_conv_lr = ConvGnLelu(ref_channels, nf, stride=1, norm=False, bias=True, activation=True)
|
||||
if scale == 4:
|
||||
strd_2 = 2
|
||||
else:
|
||||
strd_2 = 1
|
||||
self.second_conv = ConvGnLelu(nf, nf, stride=strd_2, norm=True, bias=False, activation=True)
|
||||
|
||||
self.cross1 = CrossCompareBlock(nf, nf * 2)
|
||||
self.cross2 = CrossCompareBlock(nf * 2, nf * 4)
|
||||
self.cross3 = CrossCompareBlock(nf * 4, nf * 8)
|
||||
self.cross4 = CrossCompareBlock(nf * 8, nf * 8)
|
||||
self.fproc_conv = ConvGnLelu(nf * 8, nf, norm=True, bias=True, activation=True)
|
||||
self.out_conv = ConvGnLelu(nf, 1, norm=False, bias=False, activation=False)
|
||||
|
||||
self.scale = scale * 16
|
||||
|
||||
def forward(self, hr, lr):
|
||||
hr = self.init_conv_hr(hr)
|
||||
hr = self.second_conv(hr)
|
||||
lr = self.init_conv_lr(lr)
|
||||
|
||||
hr, lr = self.cross1(hr, lr)
|
||||
hr, lr = self.cross2(hr, lr)
|
||||
hr, lr = self.cross3(hr, lr)
|
||||
hr, _ = self.cross4(hr, lr)
|
||||
|
||||
return self.out_conv(self.fproc_conv(hr)).view(-1, 1)
|
||||
|
||||
# Returns tuple of (number_output_channels, scale_of_output_reduction (1/n))
|
||||
def pixgan_parameters(self):
|
||||
return 3, self.scale
|
||||
|
||||
|
||||
class Discriminator_VGG_PixLoss(nn.Module):
|
||||
def __init__(self, in_nc, nf):
|
||||
super(Discriminator_VGG_PixLoss, self).__init__()
|
||||
# [64, 128, 128]
|
||||
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
|
||||
self.bn0_1 = nn.GroupNorm(8, nf, affine=True)
|
||||
# [64, 64, 64]
|
||||
self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
|
||||
self.bn1_0 = nn.GroupNorm(8, nf * 2, affine=True)
|
||||
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
|
||||
self.bn1_1 = nn.GroupNorm(8, nf * 2, affine=True)
|
||||
# [128, 32, 32]
|
||||
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
|
||||
self.bn2_0 = nn.GroupNorm(8, nf * 4, affine=True)
|
||||
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
|
||||
self.bn2_1 = nn.GroupNorm(8, nf * 4, affine=True)
|
||||
# [256, 16, 16]
|
||||
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
|
||||
self.bn3_0 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||
self.bn3_1 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||
# [512, 8, 8]
|
||||
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
|
||||
self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||
self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||
|
||||
self.reduce_1 = ConvGnLelu(nf * 8, nf * 4, bias=False)
|
||||
self.pix_loss_collapse = ConvGnLelu(nf * 4, 1, bias=False, norm=False, activation=False)
|
||||
|
||||
# Pyramid network: upsample with residuals and produce losses at multiple resolutions.
|
||||
self.up3_decimate = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=True, activation=False)
|
||||
self.up3_converge = ConvGnLelu(nf * 16, nf * 8, kernel_size=3, bias=False)
|
||||
self.up3_proc = ConvGnLelu(nf * 8, nf * 8, bias=False)
|
||||
self.up3_reduce = ConvGnLelu(nf * 8, nf * 4, bias=False)
|
||||
self.up3_pix = ConvGnLelu(nf * 4, 1, bias=False, norm=False, activation=False)
|
||||
|
||||
self.up2_decimate = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, bias=True, activation=False)
|
||||
self.up2_converge = ConvGnLelu(nf * 8, nf * 4, kernel_size=3, bias=False)
|
||||
self.up2_proc = ConvGnLelu(nf * 4, nf * 4, bias=False)
|
||||
self.up2_reduce = ConvGnLelu(nf * 4, nf * 2, bias=False)
|
||||
self.up2_pix = ConvGnLelu(nf * 2, 1, bias=False, norm=False, activation=False)
|
||||
|
||||
# activation function
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x, flatten=True):
|
||||
fea0 = self.lrelu(self.conv0_0(x))
|
||||
fea0 = self.lrelu(self.bn0_1(self.conv0_1(fea0)))
|
||||
|
||||
fea1 = self.lrelu(self.bn1_0(self.conv1_0(fea0)))
|
||||
fea1 = self.lrelu(self.bn1_1(self.conv1_1(fea1)))
|
||||
|
||||
fea2 = self.lrelu(self.bn2_0(self.conv2_0(fea1)))
|
||||
fea2 = self.lrelu(self.bn2_1(self.conv2_1(fea2)))
|
||||
|
||||
fea3 = self.lrelu(self.bn3_0(self.conv3_0(fea2)))
|
||||
fea3 = self.lrelu(self.bn3_1(self.conv3_1(fea3)))
|
||||
|
||||
fea4 = self.lrelu(self.bn4_0(self.conv4_0(fea3)))
|
||||
fea4 = self.lrelu(self.bn4_1(self.conv4_1(fea4)))
|
||||
|
||||
loss = self.reduce_1(fea4)
|
||||
# "Weight" all losses the same by interpolating them to the highest dimension.
|
||||
loss = self.pix_loss_collapse(loss)
|
||||
loss = F.interpolate(loss, scale_factor=4, mode="nearest")
|
||||
|
||||
# And the pyramid network!
|
||||
dec3 = self.up3_decimate(F.interpolate(fea4, scale_factor=2, mode="nearest"))
|
||||
dec3 = torch.cat([dec3, fea3], dim=1)
|
||||
dec3 = self.up3_converge(dec3)
|
||||
dec3 = self.up3_proc(dec3)
|
||||
loss3 = self.up3_reduce(dec3)
|
||||
loss3 = self.up3_pix(loss3)
|
||||
loss3 = F.interpolate(loss3, scale_factor=2, mode="nearest")
|
||||
|
||||
dec2 = self.up2_decimate(F.interpolate(dec3, scale_factor=2, mode="nearest"))
|
||||
dec2 = torch.cat([dec2, fea2], dim=1)
|
||||
dec2 = self.up2_converge(dec2)
|
||||
dec2 = self.up2_proc(dec2)
|
||||
dec2 = self.up2_reduce(dec2)
|
||||
loss2 = self.up2_pix(dec2)
|
||||
|
||||
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will
|
||||
# then know how to handle them.
|
||||
combined_losses = torch.cat([loss, loss3, loss2], dim=1)
|
||||
return combined_losses.view(-1, 1)
|
||||
|
||||
def pixgan_parameters(self):
|
||||
return 3, 8
|
||||
|
||||
|
||||
class Discriminator_UNet(nn.Module):
|
||||
def __init__(self, in_nc, nf):
|
||||
super(Discriminator_UNet, self).__init__()
|
||||
# [64, 128, 128]
|
||||
self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
|
||||
self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
|
||||
# [64, 64, 64]
|
||||
self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
|
||||
self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
|
||||
# [128, 32, 32]
|
||||
self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
|
||||
self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
|
||||
# [256, 16, 16]
|
||||
self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
|
||||
self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
|
||||
# [512, 8, 8]
|
||||
self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
|
||||
self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
|
||||
|
||||
self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
|
||||
self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False)
|
||||
self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False)
|
||||
|
||||
self.up2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu)
|
||||
self.proc2 = ConvGnLelu(nf * 4, nf * 4, bias=False)
|
||||
self.collapse2 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
|
||||
|
||||
self.up3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu)
|
||||
self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
|
||||
self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
|
||||
|
||||
def forward(self, x, flatten=True):
|
||||
fea0 = self.conv0_0(x)
|
||||
fea0 = self.conv0_1(fea0)
|
||||
|
||||
fea1 = self.conv1_0(fea0)
|
||||
fea1 = self.conv1_1(fea1)
|
||||
|
||||
fea2 = self.conv2_0(fea1)
|
||||
fea2 = self.conv2_1(fea2)
|
||||
|
||||
fea3 = self.conv3_0(fea2)
|
||||
fea3 = self.conv3_1(fea3)
|
||||
|
||||
fea4 = self.conv4_0(fea3)
|
||||
fea4 = self.conv4_1(fea4)
|
||||
|
||||
# And the pyramid network!
|
||||
u1 = self.up1(fea4, fea3)
|
||||
loss1 = self.collapse1(self.proc1(u1))
|
||||
u2 = self.up2(u1, fea2)
|
||||
loss2 = self.collapse2(self.proc2(u2))
|
||||
u3 = self.up3(u2, fea1)
|
||||
loss3 = self.collapse3(self.proc3(u3))
|
||||
res = loss3.shape[2:]
|
||||
|
||||
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will
|
||||
# then know how to handle them.
|
||||
combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
|
||||
F.interpolate(loss2, scale_factor=2),
|
||||
F.interpolate(loss3, scale_factor=1)], dim=1)
|
||||
return combined_losses.view(-1, 1)
|
||||
|
||||
def pixgan_parameters(self):
|
||||
return 3, 4
|
||||
|
||||
|
||||
class Discriminator_UNet_FeaOut(nn.Module):
|
||||
def __init__(self, in_nc, nf, feature_mode=False):
|
||||
super(Discriminator_UNet_FeaOut, self).__init__()
|
||||
# [64, 128, 128]
|
||||
self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
|
||||
self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
|
||||
# [64, 64, 64]
|
||||
self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
|
||||
self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
|
||||
# [128, 32, 32]
|
||||
self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
|
||||
self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
|
||||
# [256, 16, 16]
|
||||
self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
|
||||
self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
|
||||
# [512, 8, 8]
|
||||
self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
|
||||
self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
|
||||
|
||||
self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
|
||||
self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False)
|
||||
self.fea_proc = ConvGnLelu(nf * 8, nf * 8, bias=True, norm=False, activation=False)
|
||||
self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False)
|
||||
|
||||
self.feature_mode = feature_mode
|
||||
|
||||
def forward(self, x, output_feature_vector=False):
|
||||
fea0 = self.conv0_0(x)
|
||||
fea0 = self.conv0_1(fea0)
|
||||
|
||||
fea1 = self.conv1_0(fea0)
|
||||
fea1 = self.conv1_1(fea1)
|
||||
|
||||
fea2 = self.conv2_0(fea1)
|
||||
fea2 = self.conv2_1(fea2)
|
||||
|
||||
fea3 = self.conv3_0(fea2)
|
||||
fea3 = self.conv3_1(fea3)
|
||||
|
||||
fea4 = self.conv4_0(fea3)
|
||||
fea4 = self.conv4_1(fea4)
|
||||
|
||||
# And the pyramid network!
|
||||
u1 = self.up1(fea4, fea3)
|
||||
loss1 = self.collapse1(self.proc1(u1))
|
||||
fea_out = self.fea_proc(u1)
|
||||
|
||||
combined_losses = F.interpolate(loss1, scale_factor=4)
|
||||
if output_feature_vector:
|
||||
return combined_losses.view(-1, 1), fea_out
|
||||
else:
|
||||
return combined_losses.view(-1, 1)
|
||||
|
||||
def pixgan_parameters(self):
|
||||
return 1, 4
|
||||
|
||||
|
||||
class Vgg128GnHead(nn.Module):
|
||||
def __init__(self, in_nc, nf, depth=5):
|
||||
super(Vgg128GnHead, self).__init__()
|
||||
assert depth == 4 or depth == 5 # Nothing stopping others from being implemented, just not done yet.
|
||||
self.depth = depth
|
||||
|
||||
# [64, 128, 128]
|
||||
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
|
||||
self.bn0_1 = nn.GroupNorm(8, nf, affine=True)
|
||||
# [64, 64, 64]
|
||||
self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
|
||||
self.bn1_0 = nn.GroupNorm(8, nf * 2, affine=True)
|
||||
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
|
||||
self.bn1_1 = nn.GroupNorm(8, nf * 2, affine=True)
|
||||
# [128, 32, 32]
|
||||
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
|
||||
self.bn2_0 = nn.GroupNorm(8, nf * 4, affine=True)
|
||||
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
|
||||
self.bn2_1 = nn.GroupNorm(8, nf * 4, affine=True)
|
||||
# [256, 16, 16]
|
||||
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
|
||||
self.bn3_0 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||
self.bn3_1 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||
if depth > 4:
|
||||
# [512, 8, 8]
|
||||
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
|
||||
self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||
self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||
|
||||
# activation function
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.lrelu(self.conv0_0(x))
|
||||
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
|
||||
|
||||
fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
|
||||
fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
|
||||
|
||||
fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
|
||||
fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
|
||||
|
||||
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
|
||||
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
|
||||
|
||||
if self.depth > 4:
|
||||
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
|
||||
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
|
||||
return fea
|
||||
|
||||
|
||||
class RefDiscriminatorVgg128(nn.Module):
|
||||
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
|
||||
def __init__(self, in_nc, nf, input_img_factor=1):
|
||||
super(RefDiscriminatorVgg128, self).__init__()
|
||||
|
||||
# activation function
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
self.feature_head = Vgg128GnHead(in_nc, nf)
|
||||
self.ref_head = Vgg128GnHead(in_nc+1, nf, depth=4)
|
||||
final_nf = nf * 8
|
||||
|
||||
self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 512)
|
||||
self.ref_linear = nn.Linear(nf * 8, 128)
|
||||
|
||||
self.output_linears = nn.Sequential(
|
||||
nn.Linear(128+512, 512),
|
||||
self.lrelu,
|
||||
nn.Linear(512, 256),
|
||||
self.lrelu,
|
||||
nn.Linear(256, 128),
|
||||
self.lrelu,
|
||||
nn.Linear(128, 1)
|
||||
)
|
||||
|
||||
def forward(self, x, ref, ref_center_point):
|
||||
ref = self.ref_head(ref)
|
||||
ref_center_point = ref_center_point // 16
|
||||
from models.SwitchedResidualGenerator_arch import gather_2d
|
||||
ref_vector = gather_2d(ref, ref_center_point)
|
||||
ref_vector = self.ref_linear(ref_vector)
|
||||
|
||||
fea = self.feature_head(x)
|
||||
fea = fea.contiguous().view(fea.size(0), -1)
|
||||
fea = self.lrelu(self.linear1(fea))
|
||||
|
||||
out = self.output_linears(torch.cat([fea, ref_vector], dim=1))
|
||||
return out
|
||||
|
||||
|
||||
class PsnrApproximator(nn.Module):
|
||||
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
|
||||
def __init__(self, nf, input_img_factor=1):
|
||||
super(PsnrApproximator, self).__init__()
|
||||
|
||||
# [64, 128, 128]
|
||||
self.fake_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
|
||||
self.fake_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
|
||||
self.fake_bn0_1 = nn.BatchNorm2d(nf, affine=True)
|
||||
# [64, 64, 64]
|
||||
self.fake_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
|
||||
self.fake_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||
self.fake_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
|
||||
self.fake_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||
# [128, 32, 32]
|
||||
self.fake_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
|
||||
self.fake_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||
self.fake_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
|
||||
self.fake_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||
|
||||
# [64, 128, 128]
|
||||
self.real_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
|
||||
self.real_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
|
||||
self.real_bn0_1 = nn.BatchNorm2d(nf, affine=True)
|
||||
# [64, 64, 64]
|
||||
self.real_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
|
||||
self.real_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||
self.real_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
|
||||
self.real_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||
# [128, 32, 32]
|
||||
self.real_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
|
||||
self.real_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||
self.real_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
|
||||
self.real_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||
|
||||
# [512, 16, 16]
|
||||
self.conv3_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
|
||||
self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||
self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||
# [512, 8, 8]
|
||||
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
|
||||
self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||
self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||
final_nf = nf * 8
|
||||
|
||||
# activation function
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 1024)
|
||||
self.linear2 = nn.Linear(1024, 512)
|
||||
self.linear3 = nn.Linear(512, 128)
|
||||
self.linear4 = nn.Linear(128, 1)
|
||||
|
||||
def compute_body1(self, real):
|
||||
fea = self.lrelu(self.real_conv0_0(real))
|
||||
fea = self.lrelu(self.real_bn0_1(self.real_conv0_1(fea)))
|
||||
fea = self.lrelu(self.real_bn1_0(self.real_conv1_0(fea)))
|
||||
fea = self.lrelu(self.real_bn1_1(self.real_conv1_1(fea)))
|
||||
fea = self.lrelu(self.real_bn2_0(self.real_conv2_0(fea)))
|
||||
fea = self.lrelu(self.real_bn2_1(self.real_conv2_1(fea)))
|
||||
return fea
|
||||
|
||||
def compute_body2(self, fake):
|
||||
fea = self.lrelu(self.fake_conv0_0(fake))
|
||||
fea = self.lrelu(self.fake_bn0_1(self.fake_conv0_1(fea)))
|
||||
fea = self.lrelu(self.fake_bn1_0(self.fake_conv1_0(fea)))
|
||||
fea = self.lrelu(self.fake_bn1_1(self.fake_conv1_1(fea)))
|
||||
fea = self.lrelu(self.fake_bn2_0(self.fake_conv2_0(fea)))
|
||||
fea = self.lrelu(self.fake_bn2_1(self.fake_conv2_1(fea)))
|
||||
return fea
|
||||
|
||||
def forward(self, real, fake):
|
||||
real_fea = checkpoint(self.compute_body1, real)
|
||||
fake_fea = checkpoint(self.compute_body2, fake)
|
||||
fea = torch.cat([real_fea, fake_fea], dim=1)
|
||||
|
||||
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
|
||||
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
|
||||
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
|
||||
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
|
||||
|
||||
fea = fea.contiguous().view(fea.size(0), -1)
|
||||
fea = self.lrelu(self.linear1(fea))
|
||||
fea = self.lrelu(self.linear2(fea))
|
||||
fea = self.lrelu(self.linear3(fea))
|
||||
out = self.linear4(fea)
|
||||
return out.squeeze()
|
||||
|
||||
|
||||
class SingleImageQualityEstimator(nn.Module):
|
||||
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
|
||||
def __init__(self, nf, input_img_factor=1):
|
||||
super(SingleImageQualityEstimator, self).__init__()
|
||||
|
||||
# [64, 128, 128]
|
||||
self.fake_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
|
||||
self.fake_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
|
||||
self.fake_bn0_1 = nn.BatchNorm2d(nf, affine=True)
|
||||
# [64, 64, 64]
|
||||
self.fake_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
|
||||
self.fake_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||
self.fake_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
|
||||
self.fake_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||
# [128, 32, 32]
|
||||
self.fake_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
|
||||
self.fake_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||
self.fake_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
|
||||
self.fake_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||
|
||||
# [512, 16, 16]
|
||||
self.conv3_0 = nn.Conv2d(nf * 4, nf * 4, 3, 1, 1, bias=False)
|
||||
self.bn3_0 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||
self.conv3_1 = nn.Conv2d(nf * 4, nf * 8, 4, 2, 1, bias=False)
|
||||
self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||
# [512, 8, 8]
|
||||
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=True)
|
||||
self.conv4_1 = nn.Conv2d(nf * 8, nf * 2, 3, 1, 1, bias=True)
|
||||
self.conv4_2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
|
||||
self.conv4_3 = nn.Conv2d(nf, 3, 3, 1, 1, bias=True)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=.2, inplace=True)
|
||||
|
||||
def compute_body(self, fake):
|
||||
fea = self.lrelu(self.fake_conv0_0(fake))
|
||||
fea = self.lrelu(self.fake_bn0_1(self.fake_conv0_1(fea)))
|
||||
fea = self.lrelu(self.fake_bn1_0(self.fake_conv1_0(fea)))
|
||||
fea = self.lrelu(self.fake_bn1_1(self.fake_conv1_1(fea)))
|
||||
fea = self.lrelu(self.fake_bn2_0(self.fake_conv2_0(fea)))
|
||||
fea = self.lrelu(self.fake_bn2_1(self.fake_conv2_1(fea)))
|
||||
return fea
|
||||
|
||||
def forward(self, fake):
|
||||
fea = checkpoint(self.compute_body, fake)
|
||||
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
|
||||
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
|
||||
fea = self.lrelu(self.conv4_0(fea))
|
||||
fea = self.lrelu(self.conv4_1(fea))
|
||||
fea = self.lrelu(self.conv4_2(fea))
|
||||
fea = self.sigmoid(self.conv4_3(fea))
|
||||
return fea
|
||||
@register_model
|
||||
def register_discriminator_vgg_128(opt_net, opt):
|
||||
return Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=opt_net['image_size'],
|
||||
extra_conv=opt_get(opt_net, ['extra_conv'], False),
|
||||
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False))
|
||||
|
|
|
@ -5,6 +5,9 @@ from torch import nn
|
|||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
class BlurLayer(nn.Module):
|
||||
def __init__(self, kernel=None, normalize=True, flip=False, stride=1):
|
||||
|
@ -372,4 +375,9 @@ class StyleGanDiscriminator(nn.Module):
|
|||
else:
|
||||
raise KeyError("Unknown structure: ", self.structure)
|
||||
|
||||
return scores_out
|
||||
return scores_out
|
||||
|
||||
|
||||
@register_model
|
||||
def register_stylegan_vgg(opt_net, opt):
|
||||
return StyleGanDiscriminator(opt_get(opt_net, ['image_size'], 128))
|
|
@ -18,7 +18,7 @@ from torch.autograd import grad as torch_grad
|
|||
from vector_quantize_pytorch import VectorQuantize
|
||||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
from utils.util import checkpoint, opt_get
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
|
@ -763,7 +763,7 @@ class DiscriminatorBlock(nn.Module):
|
|||
|
||||
class StyleGan2Discriminator(nn.Module):
|
||||
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
|
||||
transparent=False, fmap_max=512, input_filters=3):
|
||||
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False):
|
||||
super().__init__()
|
||||
num_layers = int(log2(image_size) - 1)
|
||||
|
||||
|
@ -789,12 +789,16 @@ class StyleGan2Discriminator(nn.Module):
|
|||
|
||||
attn_blocks.append(attn_fn)
|
||||
|
||||
quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None
|
||||
quantize_blocks.append(quantize_fn)
|
||||
if quantize:
|
||||
quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None
|
||||
quantize_blocks.append(quantize_fn)
|
||||
else:
|
||||
quantize_blocks.append(None)
|
||||
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
self.attn_blocks = nn.ModuleList(attn_blocks)
|
||||
self.quantize_blocks = nn.ModuleList(quantize_blocks)
|
||||
self.do_checkpointing = do_checkpointing
|
||||
|
||||
chan_last = filters[-1]
|
||||
latent_dim = 2 * 2 * chan_last
|
||||
|
@ -811,7 +815,10 @@ class StyleGan2Discriminator(nn.Module):
|
|||
quantize_loss = torch.zeros(1).to(x)
|
||||
|
||||
for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks):
|
||||
x = block(x)
|
||||
if self.do_checkpointing:
|
||||
x = checkpoint(block, x)
|
||||
else:
|
||||
x = block(x)
|
||||
|
||||
if exists(attn_block):
|
||||
x = attn_block(x)
|
||||
|
@ -862,7 +869,6 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss):
|
|||
|
||||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||
if self.env['step'] % self.gp_frequency == 0:
|
||||
from models.stylegan.stylegan2_lucidrains import gradient_penalty
|
||||
gp = gradient_penalty(real_input, real)
|
||||
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
||||
divergence_loss = divergence_loss + gp
|
||||
|
@ -877,17 +883,14 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss):
|
|||
self.w_styles = opt['w_styles']
|
||||
self.gen = opt['gen']
|
||||
self.pl_mean = None
|
||||
from models.archs.stylegan.stylegan2_lucidrains import EMA
|
||||
self.pl_length_ma = EMA(.99)
|
||||
|
||||
def forward(self, net, state):
|
||||
w_styles = state[self.w_styles]
|
||||
gen = state[self.gen]
|
||||
from models.stylegan.stylegan2_lucidrains import calc_pl_lengths
|
||||
pl_lengths = calc_pl_lengths(w_styles, gen)
|
||||
avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
|
||||
|
||||
from models.stylegan.stylegan2_lucidrains import is_empty
|
||||
if not is_empty(self.pl_mean):
|
||||
pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
|
||||
if not torch.isnan(pl_loss):
|
||||
|
@ -906,3 +909,12 @@ def register_stylegan2_lucidrains(opt_net, opt):
|
|||
return StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
|
||||
style_depth=opt_net['style_depth'], structure_input=is_structured,
|
||||
attn_layers=attn)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_stylegan2_discriminator(opt_net, opt):
|
||||
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
||||
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
|
||||
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
|
||||
quantize=opt_get(opt_net, ['quantize'], False))
|
||||
return StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
||||
|
|
|
@ -11,7 +11,7 @@ munch
|
|||
tqdm
|
||||
scp
|
||||
tensorboard
|
||||
pytorch_fid
|
||||
pytorch_fid==0.1.1
|
||||
kornia
|
||||
linear_attention_transformer
|
||||
vector_quantize_pytorch
|
||||
|
|
|
@ -1,104 +0,0 @@
|
|||
import sys
|
||||
import os.path as osp
|
||||
import math
|
||||
import torchvision.utils
|
||||
|
||||
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
||||
from data import create_dataloader, create_dataset # noqa: E402
|
||||
from utils import util # noqa: E402
|
||||
|
||||
|
||||
def main():
|
||||
dataset = 'DIV2K800_sub' # REDS | Vimeo90K | DIV2K800_sub
|
||||
opt = {}
|
||||
opt['dist'] = False
|
||||
opt['gpu_ids'] = [0]
|
||||
if dataset == 'REDS':
|
||||
opt['name'] = 'test_REDS'
|
||||
opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb'
|
||||
opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
|
||||
opt['mode'] = 'REDS'
|
||||
opt['N_frames'] = 5
|
||||
opt['phase'] = 'train'
|
||||
opt['use_shuffle'] = True
|
||||
opt['n_workers'] = 8
|
||||
opt['batch_size'] = 16
|
||||
opt['target_size'] = 256
|
||||
opt['LQ_size'] = 64
|
||||
opt['scale'] = 4
|
||||
opt['use_flip'] = True
|
||||
opt['use_rot'] = True
|
||||
opt['interval_list'] = [1]
|
||||
opt['random_reverse'] = False
|
||||
opt['border_mode'] = False
|
||||
opt['cache_keys'] = None
|
||||
opt['data_type'] = 'lmdb' # img | lmdb | mc
|
||||
elif dataset == 'Vimeo90K':
|
||||
opt['name'] = 'test_Vimeo90K'
|
||||
opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
|
||||
opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
|
||||
opt['mode'] = 'Vimeo90K'
|
||||
opt['N_frames'] = 7
|
||||
opt['phase'] = 'train'
|
||||
opt['use_shuffle'] = True
|
||||
opt['n_workers'] = 8
|
||||
opt['batch_size'] = 16
|
||||
opt['target_size'] = 256
|
||||
opt['LQ_size'] = 64
|
||||
opt['scale'] = 4
|
||||
opt['use_flip'] = True
|
||||
opt['use_rot'] = True
|
||||
opt['interval_list'] = [1]
|
||||
opt['random_reverse'] = False
|
||||
opt['border_mode'] = False
|
||||
opt['cache_keys'] = None
|
||||
opt['data_type'] = 'lmdb' # img | lmdb | mc
|
||||
elif dataset == 'DIV2K800_sub':
|
||||
opt['name'] = 'DIV2K800'
|
||||
opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
|
||||
opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
|
||||
opt['mode'] = 'LQGT'
|
||||
opt['phase'] = 'train'
|
||||
opt['use_shuffle'] = True
|
||||
opt['n_workers'] = 8
|
||||
opt['batch_size'] = 16
|
||||
opt['target_size'] = 128
|
||||
opt['scale'] = 4
|
||||
opt['use_flip'] = True
|
||||
opt['use_rot'] = True
|
||||
opt['color'] = 'RGB'
|
||||
opt['data_type'] = 'lmdb' # img | lmdb
|
||||
else:
|
||||
raise ValueError('Please implement by yourself.')
|
||||
|
||||
util.mkdir('tmp')
|
||||
train_set = create_dataset(opt)
|
||||
train_loader = create_dataloader(train_set, opt, opt, None)
|
||||
nrow = int(math.sqrt(opt['batch_size']))
|
||||
padding = 2 if opt['phase'] == 'train' else 0
|
||||
|
||||
print('start...')
|
||||
for i, data in enumerate(train_loader):
|
||||
if i > 5:
|
||||
break
|
||||
print(i)
|
||||
if dataset == 'REDS' or dataset == 'Vimeo90K':
|
||||
LQs = data['LQs']
|
||||
else:
|
||||
LQ = data['lq']
|
||||
GT = data['hq']
|
||||
|
||||
if dataset == 'REDS' or dataset == 'Vimeo90K':
|
||||
for j in range(LQs.size(1)):
|
||||
torchvision.utils.save_image(LQs[:, j, :, :, :],
|
||||
'tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow,
|
||||
padding=padding, normalize=False)
|
||||
else:
|
||||
torchvision.utils.save_image(LQ, 'tmp/LQ_{:03d}.png'.format(i), nrow=nrow,
|
||||
padding=padding, normalize=False)
|
||||
torchvision.utils.save_image(GT, 'tmp/GT_{:03d}.png'.format(i), nrow=nrow, padding=padding,
|
||||
normalize=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -293,7 +293,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_faces_styled_sr.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_discriminator_diffimage.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -57,11 +57,11 @@ class ExtensibleTrainer(BaseModel):
|
|||
new_net = None
|
||||
if net['type'] == 'generator':
|
||||
if new_net is None:
|
||||
new_net = networks.create_model(opt, net, opt['scale']).to(self.device)
|
||||
new_net = networks.create_model(opt, net).to(self.device)
|
||||
self.netsG[name] = new_net
|
||||
elif net['type'] == 'discriminator':
|
||||
if new_net is None:
|
||||
new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device)
|
||||
new_net = networks.create_model(opt, net).to(self.device)
|
||||
self.netsD[name] = new_net
|
||||
else:
|
||||
raise NotImplementedError("Can only handle generators and discriminators")
|
||||
|
|
|
@ -1,18 +1,11 @@
|
|||
import functools
|
||||
import importlib
|
||||
import logging
|
||||
import pkgutil
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from inspect import isfunction, getmembers
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
import models.discriminator_vgg_arch as SRGAN_arch
|
||||
import models.feature_arch as feature_arch
|
||||
import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
||||
from models.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
@ -63,7 +56,7 @@ class CreateModelError(Exception):
|
|||
f'{available}')
|
||||
|
||||
|
||||
def create_model(opt, opt_net, scale=None):
|
||||
def create_model(opt, opt_net):
|
||||
which_model = opt_net['which_model']
|
||||
# For backwards compatibility.
|
||||
if not which_model:
|
||||
|
@ -76,96 +69,6 @@ def create_model(opt, opt_net, scale=None):
|
|||
return registered_fns[which_model](opt_net, opt)
|
||||
|
||||
|
||||
class GradDiscWrapper(torch.nn.Module):
|
||||
def __init__(self, m):
|
||||
super(GradDiscWrapper, self).__init__()
|
||||
logger.info("Wrapping a discriminator..")
|
||||
self.m = m
|
||||
|
||||
def forward(self, x):
|
||||
return self.m(x)
|
||||
|
||||
def define_D_net(opt_net, img_sz=None, wrap=False):
|
||||
which_model = opt_net['which_model_D']
|
||||
|
||||
if 'image_size' in opt_net.keys():
|
||||
img_sz = opt_net['image_size']
|
||||
|
||||
if which_model == 'discriminator_vgg_128':
|
||||
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, extra_conv=opt_net['extra_conv'])
|
||||
elif which_model == 'discriminator_vgg_128_gn':
|
||||
extra_conv = opt_net['extra_conv'] if 'extra_conv' in opt_net.keys() else False
|
||||
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
|
||||
input_img_factor=img_sz / 128, extra_conv=extra_conv)
|
||||
if wrap:
|
||||
netD = GradDiscWrapper(netD)
|
||||
elif which_model == 'discriminator_vgg_128_gn_checkpointed':
|
||||
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, do_checkpointing=True)
|
||||
elif which_model == 'stylegan_vgg':
|
||||
netD = StyleGanDiscriminator(128)
|
||||
elif which_model == 'discriminator_resnet':
|
||||
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
|
||||
elif which_model == 'discriminator_resnet_50':
|
||||
netD = DiscriminatorResnet_arch.fixup_resnet50(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
|
||||
elif which_model == 'resnext':
|
||||
netD = torchvision.models.resnext50_32x4d(norm_layer=functools.partial(torch.nn.GroupNorm, 8))
|
||||
#state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True)
|
||||
#netD.load_state_dict(state_dict, strict=False)
|
||||
netD.fc = torch.nn.Linear(512 * 4, 1)
|
||||
elif which_model == 'discriminator_pix':
|
||||
netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
|
||||
elif which_model == "discriminator_unet":
|
||||
netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
|
||||
elif which_model == "discriminator_unet_fea":
|
||||
netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'], feature_mode=opt_net['feature_mode'])
|
||||
elif which_model == "discriminator_switched":
|
||||
netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'],
|
||||
final_temperature_step=opt_net['final_temperature_step'])
|
||||
elif which_model == "cross_compare_vgg128":
|
||||
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'] if 'ref_channels' in opt_net.keys() else 3, nf=opt_net['nf'], scale=opt_net['scale'])
|
||||
elif which_model == "discriminator_refvgg":
|
||||
netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
||||
elif which_model == "psnr_approximator":
|
||||
netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
||||
elif which_model == "stylegan2_discriminator":
|
||||
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
||||
from models.stylegan.stylegan2_lucidrains import StyleGan2Discriminator
|
||||
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
|
||||
from models.stylegan.stylegan2_lucidrains import StyleGan2Augmentor
|
||||
netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
||||
else:
|
||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||
return netD
|
||||
|
||||
# Discriminator
|
||||
def define_D(opt, wrap=False):
|
||||
img_sz = opt['datasets']['train']['target_size']
|
||||
opt_net = opt['network_D']
|
||||
return define_D_net(opt_net, img_sz, wrap=wrap)
|
||||
|
||||
def define_fixed_D(opt):
|
||||
# Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_size parameter is added.
|
||||
net = define_D_net(opt)
|
||||
|
||||
# Load the model parameters:
|
||||
load_net = torch.load(opt['pretrained_path'])
|
||||
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
||||
for k, v in load_net.items():
|
||||
if k.startswith('module.'):
|
||||
load_net_clean[k[7:]] = v
|
||||
else:
|
||||
load_net_clean[k] = v
|
||||
net.load_state_dict(load_net_clean)
|
||||
|
||||
# Put into eval mode, freeze the parameters and set the 'weight' field.
|
||||
net.eval()
|
||||
for k, v in net.named_parameters():
|
||||
v.requires_grad = False
|
||||
net.fdisc_weight = opt['weight']
|
||||
|
||||
return net
|
||||
|
||||
|
||||
# Define network used for perceptual loss
|
||||
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None):
|
||||
if which_model == 'vgg':
|
||||
|
|
|
@ -1,177 +0,0 @@
|
|||
import argparse
|
||||
import functools
|
||||
import torch
|
||||
from utils import options as option
|
||||
from trainer.networks import create_model
|
||||
|
||||
|
||||
class TracedModule:
|
||||
def __init__(self, idname):
|
||||
self.idname = idname
|
||||
self.traced_outputs = []
|
||||
self.traced_inputs = []
|
||||
|
||||
|
||||
class TorchCustomTrace:
|
||||
def __init__(self):
|
||||
self.module_name_counter = {}
|
||||
self.modules = {}
|
||||
self.graph = {}
|
||||
self.module_map_by_inputs = {}
|
||||
self.module_map_by_outputs = {}
|
||||
self.inputs_to_func_output_tuple = {}
|
||||
|
||||
def add_tracked_module(self, mod: torch.nn.Module):
|
||||
modname = type(mod).__name__
|
||||
if modname not in self.module_name_counter.keys():
|
||||
self.module_name_counter[modname] = 0
|
||||
self.module_name_counter[modname] += 1
|
||||
idname = "%s(%03d)" % (modname, self.module_name_counter[modname])
|
||||
self.modules[idname] = TracedModule(idname)
|
||||
return idname
|
||||
|
||||
# Only called for nn.Modules since those are the only things we can access. Filling in the gaps will be done in
|
||||
# the backwards pass.
|
||||
def mem_forward_hook(self, module: torch.nn.Module, inputs, outputs, trace: str, mod_id: str):
|
||||
mod = self.modules[mod_id]
|
||||
'''
|
||||
for li in inputs:
|
||||
if type(li) == torch.Tensor:
|
||||
li = [li]
|
||||
if type(li) == list:
|
||||
for i in li:
|
||||
if i.data_ptr() in self.module_map_by_inputs.keys():
|
||||
self.module_map_by_inputs[i.data_ptr()].append(mod)
|
||||
else:
|
||||
self.module_map_by_inputs[i.data_ptr()] = [mod]
|
||||
for o in outputs:
|
||||
if o.data_ptr() in self.module_map_by_inputs.keys():
|
||||
self.module_map_by_inputs[o.data_ptr()].append(mod)
|
||||
else:
|
||||
self.module_map_by_inputs[o.data_ptr()] = [mod]
|
||||
'''
|
||||
print(trace)
|
||||
|
||||
def mem_backward_hook(self, inputs, outputs, op):
|
||||
if len(inputs) == 0:
|
||||
print("No inputs.. %s" % (op,))
|
||||
outs = [o.data_ptr() for o in outputs]
|
||||
tup = (outs, op)
|
||||
#print(tup)
|
||||
for li in inputs:
|
||||
if type(li) == torch.Tensor:
|
||||
li = [li]
|
||||
if type(li) == list:
|
||||
for i in li:
|
||||
if i.data_ptr() in self.module_map_by_inputs.keys():
|
||||
print("%i: [%s] {%s}" % (i.data_ptr(), op, [n.idname for n in self.module_map_by_inputs[i.data_ptr()]]))
|
||||
if i.data_ptr() in self.inputs_to_func_output_tuple.keys():
|
||||
self.inputs_to_func_output_tuple[i.data_ptr()].append(tup)
|
||||
else:
|
||||
self.inputs_to_func_output_tuple[i.data_ptr()] = [tup]
|
||||
|
||||
def install_hooks(self, mod: torch.nn.Module, trace=""):
|
||||
mod_id = self.add_tracked_module(mod)
|
||||
my_trace = trace + "->" + mod_id
|
||||
# If this module has parameters, it also has a state worth tracking.
|
||||
#if next(mod.parameters(recurse=False), None) is not None:
|
||||
mod.register_forward_hook(functools.partial(self.mem_forward_hook, trace=my_trace, mod_id=mod_id))
|
||||
|
||||
for m in mod.children():
|
||||
self.install_hooks(m, my_trace)
|
||||
|
||||
def install_backward_hooks(self, grad_fn):
|
||||
# AccumulateGrad simply pushes a gradient into the specified variable, and isn't useful for the purposes of
|
||||
# tracing the graph.
|
||||
if grad_fn is None or "AccumulateGrad" in str(grad_fn):
|
||||
return
|
||||
grad_fn.register_hook(functools.partial(self.mem_backward_hook, op=str(grad_fn)))
|
||||
for g, _ in grad_fn.next_functions:
|
||||
self.install_backward_hooks(g)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_div2k_pixgan_srg2.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
|
||||
netG = create_model(opt)
|
||||
dummyInput = torch.rand(1,3,32,32)
|
||||
|
||||
mode = 'onnx'
|
||||
if mode == 'torchscript':
|
||||
print("Tracing generator network..")
|
||||
traced_netG = torch.jit.trace(netG, dummyInput)
|
||||
traced_netG.save('../results/ts_generator.zip')
|
||||
|
||||
print(traced_netG.code)
|
||||
for i, module in enumerate(traced_netG.RRDB_trunk.modules()):
|
||||
print(i, str(module))
|
||||
elif mode == 'onnx':
|
||||
print("Performing onnx trace")
|
||||
input_names = ["lr_input"]
|
||||
output_names = ["hr_image"]
|
||||
dynamic_axes = {'lr_input': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}, 'hr_image': {0: 'batch', 1: 'filters', 2: 'h', 3: 'w'}}
|
||||
|
||||
torch.onnx.export(netG, dummyInput, "../results/gen.onnx", verbose=True, input_names=input_names,
|
||||
output_names=output_names, dynamic_axes=dynamic_axes, opset_version=12)
|
||||
elif mode == 'memtrace':
|
||||
criterion = torch.nn.MSELoss()
|
||||
tracer = TorchCustomTrace()
|
||||
tracer.install_hooks(netG)
|
||||
out, = netG(dummyInput)
|
||||
tracer.install_backward_hooks(out.grad_fn)
|
||||
target = torch.zeros_like(out)
|
||||
loss = criterion(out, target)
|
||||
loss.backward()
|
||||
elif mode == 'trace':
|
||||
out = netG.forward(dummyInput)[0]
|
||||
print(out.shape)
|
||||
# Build the graph backwards.
|
||||
graph = build_graph(out, 'output')
|
||||
|
||||
def get_unique_id_for_fn(fn):
|
||||
return (str(fn).split(" object at ")[1])[:-1]
|
||||
|
||||
class GraphNode:
|
||||
def __init__(self, fn):
|
||||
self.name = (str(fn).split(" object at ")[0])[1:]
|
||||
self.fn = fn
|
||||
self.children = {}
|
||||
self.parents = {}
|
||||
|
||||
def add_parent(self, parent):
|
||||
self.parents[get_unique_id_for_fn(parent)] = parent
|
||||
|
||||
def add_child(self, child):
|
||||
self.children[get_unique_id_for_fn(child)] = child
|
||||
|
||||
class TorchGraph:
|
||||
def __init__(self):
|
||||
self.tensor_map = {}
|
||||
|
||||
def get_node_for_tensor(self, t):
|
||||
return self.tensor_map[get_unique_id_for_fn(t)]
|
||||
|
||||
def init(self, output_tensor):
|
||||
self.build_graph_backwards(output_tensor.grad_fn, None)
|
||||
# Find inputs
|
||||
self.inputs = []
|
||||
for v in self.tensor_map.values():
|
||||
# Is an input if the parents dict is empty.
|
||||
if bool(v.parents):
|
||||
self.inputs.append(v)
|
||||
|
||||
def build_graph_backwards(self, fn, previous_fn):
|
||||
id = get_unique_id_for_fn(fn)
|
||||
if id in self.tensor_map:
|
||||
node = self.tensor_map[id]
|
||||
node.add_child(previous_fn)
|
||||
else:
|
||||
node = GraphNode(fn)
|
||||
self.tensor_map[id] = node
|
||||
# Propagate to children
|
||||
for child_fn in fn.next_functions:
|
||||
node.add_parent(self.build_graph_backwards(child_fn, fn))
|
||||
return node
|
Loading…
Reference in New Issue
Block a user