More spring cleaning

pull/9/head
James Betker 2022-03-16 12:04:00 +07:00
parent 735f6e4640
commit d186414566
51 changed files with 60 additions and 1028 deletions

@ -9,7 +9,7 @@ from torchvision import transforms
import torch.nn as nn
from pathlib import Path
import models.stylegan.stylegan2_lucidrains as sg2
import models.image_generation.stylegan.stylegan2_lucidrains as sg2
def convert_transparent_to_rgb(image):

@ -1016,3 +1016,17 @@ class FinalUpsampleBlock2x(nn.Module):
def forward(self, x):
return self.chain(x)
# torch.gather() which operates as it always fucking should have: pulling indexes from the input.
def gather_2d(input, index):
b, c, h, w = input.shape
nodim = input.view(b, c, h * w)
ind_nd = index[:, 0]*w + index[:, 1]
ind_nd = ind_nd.unsqueeze(1)
ind_nd = ind_nd.repeat((1, c))
ind_nd = ind_nd.unsqueeze(2)
result = torch.gather(nodim, dim=2, index=ind_nd)
result = result.squeeze()
if b == 1:
result = result.unsqueeze(0)
return result

@ -2,4 +2,5 @@ from models.audio.tts.tacotron2.taco_utils import *
from models.audio.tts.tacotron2.text import *
from models.audio.tts.tacotron2.tacotron2 import *
from models.audio.tts.tacotron2.stft import *
from models.audio.tts.tacotron2.layers import *
from models.audio.tts.tacotron2.layers import *
from models.audio.tts.tacotron2.loss import *

@ -1,96 +0,0 @@
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
# Utilizes pretrained torchvision modules for feature extraction
class VGGFeatureExtractor(nn.Module):
def __init__(self, feature_layers=[34], use_bn=False, use_input_norm=True,
device=torch.device('cpu')):
super(VGGFeatureExtractor, self).__init__()
self.use_input_norm = use_input_norm
if use_bn:
model = torchvision.models.vgg19_bn(pretrained=True)
else:
model = torchvision.models.vgg19(pretrained=True)
if self.use_input_norm:
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
# [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
# [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
self.register_buffer('mean', mean)
self.register_buffer('std', std)
self.feature_layers = feature_layers
self.features = nn.Sequential(*list(model.features.children())[:(max(feature_layers) + 1)])
# No need to BP to variable
for k, v in self.features.named_parameters():
v.requires_grad = False
def forward(self, x, interpolate_factor=1):
if interpolate_factor > 1:
x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic')
if self.use_input_norm:
x = (x - self.mean) / self.std
output = self.features(x)
return output
class TrainableVGGFeatureExtractor(nn.Module):
def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
device=torch.device('cpu')):
super(TrainableVGGFeatureExtractor, self).__init__()
self.use_input_norm = use_input_norm
if use_bn:
model = torchvision.models.vgg19_bn(pretrained=False)
else:
model = torchvision.models.vgg19(pretrained=False)
if self.use_input_norm:
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
# [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
# [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
self.register_buffer('mean', mean)
self.register_buffer('std', std)
self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
def forward(self, x, interpolate_factor=1):
if interpolate_factor > 1:
x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic')
# Assume input range is [0, 1]
if self.use_input_norm:
x = (x - self.mean) / self.std
output = self.features(x)
return output
class WideResnetFeatureExtractor(nn.Module):
def __init__(self, use_input_norm=True, device=torch.device('cpu')):
print("Using wide resnet extractor.")
super(WideResnetFeatureExtractor, self).__init__()
self.use_input_norm = use_input_norm
self.model = torchvision.models.wide_resnet50_2(pretrained=True)
if self.use_input_norm:
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
# [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
# [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
self.register_buffer('mean', mean)
self.register_buffer('std', std)
# No need to BP to variable
for p in self.model.parameters():
p.requires_grad = False
def forward(self, x):
# Assume input range is [0, 1]
if self.use_input_norm:
x = (x - self.mean) / self.std
x = self.model.conv1(x)
x = self.model.bn1(x)
x = self.model.relu(x)
x = self.model.maxpool(x)
x = self.model.layer1(x)
x = self.model.layer2(x)
x = self.model.layer3(x)
return x

@ -3,13 +3,13 @@ import math
import torch.nn as nn
import torch
from models.RRDBNet_arch import RRDB
from models.image_generation.RRDBNet_arch import RRDB
from models.arch_util import ConvGnLelu
# Produces a convolutional feature (`f`) and a reduced feature map with double the filters.
from models.glean.stylegan2_latent_bank import Stylegan2LatentBank
from models.stylegan.stylegan2_rosinality import EqualLinear
from models.image_generation.glean.stylegan2_latent_bank import Stylegan2LatentBank
from models.image_generation.stylegan.stylegan2_rosinality import EqualLinear
from trainer.networks import register_model
from utils.util import checkpoint, sequential_checkpoint

@ -2,7 +2,7 @@ import torch
import torch.nn as nn
from models.arch_util import ConvGnLelu
from models.stylegan.stylegan2_rosinality import Generator
from models.image_generation.stylegan.stylegan2_rosinality import Generator
class Stylegan2LatentBank(nn.Module):

@ -20,7 +20,7 @@ from torch import nn, einsum
from torch.utils.data import Dataset
from torchvision import transforms
from models.stylegan.stylegan2_lucidrains import gradient_penalty
from models.image_generation.stylegan.stylegan2_lucidrains import gradient_penalty
from trainer.networks import register_model
from utils.util import opt_get

@ -1,7 +1,7 @@
import torch
from torch import nn as nn
from models.srflow import thops
from models.image_generation.srflow import thops
class _ActNorm(nn.Module):

@ -1,8 +1,8 @@
import torch
from torch import nn as nn
from models.srflow import thops
from models.srflow.flow import Conv2d, Conv2dZeros
from models.image_generation.srflow import thops
from models.image_generation.srflow.flow import Conv2d, Conv2dZeros
from utils.util import opt_get

@ -1,9 +1,9 @@
import torch
from torch import nn as nn
import models.srflow.Permutations
import models.srflow.FlowAffineCouplingsAblation
import models.srflow.FlowActNorms
import models.image_generation.srflow.Permutations
import models.image_generation.srflow.FlowAffineCouplingsAblation
import models.image_generation.srflow.FlowActNorms
def getConditional(rrdbResults, position):
@ -46,17 +46,17 @@ class FlowStep(nn.Module):
self.acOpt = acOpt
# 1. actnorm
self.actnorm = models.srflow.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
self.actnorm = models.image_generation.srflow.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
# 2. permute
if flow_permutation == "invconv":
self.invconv = models.srflow.Permutations.InvertibleConv1x1(
self.invconv = models.image_generation.srflow.Permutations.InvertibleConv1x1(
in_channels, LU_decomposed=LU_decomposed)
# 3. coupling
if flow_coupling == "CondAffineSeparatedAndCond":
self.affine = models.srflow.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
opt=opt)
self.affine = models.image_generation.srflow.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
opt=opt)
elif flow_coupling == "noCoupling":
pass
else:

@ -2,12 +2,11 @@ import numpy as np
import torch
from torch import nn as nn
import models.srflow.Split
from models.srflow import flow
from models.srflow import thops
from models.srflow.Split import Split2d
from models.srflow.glow_arch import f_conv2d_bias
from models.srflow.FlowStep import FlowStep
import models.image_generation.srflow.Split
from models.image_generation.srflow import flow
from models.image_generation.srflow.Split import Split2d
from models.image_generation.srflow.glow_arch import f_conv2d_bias
from models.image_generation.srflow.FlowStep import FlowStep
from utils.util import opt_get, checkpoint
@ -146,8 +145,8 @@ class FlowUpsamplerNet(nn.Module):
t = opt_get(opt, ['networks', 'generator','flow', 'split', 'type'], 'Split2d')
if t == 'Split2d':
split = models.srflow.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
split = models.image_generation.srflow.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
self.layers.append(split)
self.output_shapes.append([-1, split.num_channels_pass, H, W])
self.C = split.num_channels_pass

@ -3,7 +3,7 @@ import torch
from torch import nn as nn
from torch.nn import functional as F
from models.srflow import thops
from models.image_generation.srflow import thops
class InvertibleConv1x1(nn.Module):

@ -2,7 +2,7 @@ import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.srflow.module_util as mutil
import models.image_generation.srflow.module_util as mutil
from models.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu
from trainer.networks import register_model
from utils.util import opt_get

@ -4,10 +4,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.srflow.RRDBNet_arch import RRDBNet
from models.srflow.FlowUpsamplerNet import FlowUpsamplerNet
import models.srflow.thops as thops
import models.srflow.flow as flow
from models.image_generation.srflow.RRDBNet_arch import RRDBNet
from models.image_generation.srflow.FlowUpsamplerNet import FlowUpsamplerNet
import models.image_generation.srflow.thops as thops
import models.image_generation.srflow.flow as flow
from trainer.networks import register_model
from utils.util import opt_get

@ -1,8 +1,8 @@
import torch
from torch import nn as nn
from models.srflow import thops
from models.srflow.flow import Conv2dZeros, GaussianDiag
from models.image_generation.srflow import thops
from models.image_generation.srflow.flow import Conv2dZeros, GaussianDiag
from utils.util import opt_get

@ -1,9 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.srflow.FlowActNorms import ActNorm2d
from models.image_generation.srflow.FlowActNorms import ActNorm2d
from . import thops

@ -2,10 +2,10 @@
def create_stylegan2_loss(opt_loss, env):
type = opt_loss['type']
if type == 'stylegan2_divergence':
import models.stylegan.stylegan2_lucidrains as stylegan2
import models.image_generation.stylegan.stylegan2_lucidrains as stylegan2
return stylegan2.StyleGan2DivergenceLoss(opt_loss, env)
elif type == 'stylegan2_pathlen':
import models.stylegan.stylegan2_lucidrains as stylegan2
import models.image_generation.stylegan.stylegan2_lucidrains as stylegan2
return stylegan2.StyleGan2PathLengthLoss(opt_loss, env)
else:
raise NotImplementedError

@ -5,7 +5,7 @@ import torch.nn.functional as F
from torch import nn
from data.byol_attachment import reconstructed_shared_regions
from models.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \
from models.image_latents.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \
update_moving_average
from trainer.networks import create_model, register_model
from utils.util import checkpoint

@ -1,123 +0,0 @@
# A direct copy of torchvision's resnet.py modified to support gradient checkpointing.
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck
import torchvision
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']
from trainer.networks import register_model
from utils.util import checkpoint
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
class Backbone(torchvision.models.resnet.ResNet):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
replace_stride_with_dilation, norm_layer)
del self.fc
del self.avgpool
def _forward_impl(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
l1 = checkpoint(self.layer1, x)
l2 = checkpoint(self.layer2, l1)
l3 = checkpoint(self.layer3, l2)
l4 = checkpoint(self.layer4, l3)
return l1, l2, l3, l4
def forward(self, x):
return self._forward_impl(x)
def _backbone(arch, block, layers, pretrained, progress, **kwargs):
model = Backbone(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
def backbone18(pretrained=False, progress=True, **kwargs):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _backbone('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)
def backbone34(pretrained=False, progress=True, **kwargs):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _backbone('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def backbone50(pretrained=False, progress=True, **kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _backbone('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def backbone101(pretrained=False, progress=True, **kwargs):
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _backbone('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
**kwargs)
def backbone152(pretrained=False, progress=True, **kwargs):
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _backbone('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs)

@ -1,131 +0,0 @@
import math
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm
from models.segformer.backbone import backbone50
from trainer.networks import register_model
# torch.gather() which operates as it always fucking should have: pulling indexes from the input.
def gather_2d(input, index):
b, c, h, w = input.shape
nodim = input.view(b, c, h * w)
ind_nd = index[:, 0]*w + index[:, 1]
ind_nd = ind_nd.unsqueeze(1)
ind_nd = ind_nd.repeat((1, c))
ind_nd = ind_nd.unsqueeze(2)
result = torch.gather(nodim, dim=2, index=ind_nd)
result = result.squeeze()
if b == 1:
result = result.unsqueeze(0)
return result
class DilatorModule(nn.Module):
def __init__(self, input_channels, output_channels, max_dilation):
super().__init__()
self.max_dilation = max_dilation
self.conv1 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, dilation=1, bias=True)
if max_dilation > 1:
self.bn = nn.BatchNorm2d(input_channels)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=max_dilation, dilation=max_dilation, bias=True)
self.dense = nn.Linear(input_channels, output_channels, bias=True)
def forward(self, inp, loc):
x = self.conv1(inp)
if self.max_dilation > 1:
x = self.bn(self.relu(x))
x = self.conv2(x)
# This can be made more efficient by only computing these convolutions across a subset of the image. Possibly.
x = gather_2d(x, loc).contiguous()
return self.dense(x)
# Grabbed from torch examples: https://github.com/pytorch/examples/tree/master/https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65:7
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:x.size(0), :]
return x
# Simple mean() layer encoded into a class so that BYOL can grab it.
class Tail(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.mean(dim=0)
class Segformer(nn.Module):
def __init__(self, latent_channels=1024, layers=8):
super().__init__()
self.backbone = backbone50()
backbone_channels = [256, 512, 1024, 2048]
dilations = [[1,2,3,4],[1,2,3],[1,2],[1]]
final_latent_channels = latent_channels
dilators = []
for ic, dis in zip(backbone_channels, dilations):
layer_dilators = []
for di in dis:
layer_dilators.append(DilatorModule(ic, final_latent_channels, di))
dilators.append(nn.ModuleList(layer_dilators))
self.dilators = nn.ModuleList(dilators)
self.token_position_encoder = PositionalEncoding(final_latent_channels, max_len=10)
self.transformer_layers = nn.Sequential(*[nn.TransformerEncoderLayer(final_latent_channels, nhead=4) for _ in range(layers)])
self.tail = Tail()
def forward(self, img=None, layers=None, pos=None, return_layers=False):
assert img is not None or layers is not None
if img is not None:
bs = img.shape[0]
layers = self.backbone(img)
else:
bs = layers[0].shape[0]
if return_layers:
return layers
# A single position can be optionally given, in which case we need to expand it to represent the entire input.
if pos.shape == (2,):
pos = pos.unsqueeze(0).repeat(bs, 1)
set = []
pos = pos // 4
for layer_out, dilator in zip(layers, self.dilators):
for subdilator in dilator:
set.append(subdilator(layer_out, pos))
pos = pos // 2
# The torch transformer expects the set dimension to be 0.
set = torch.stack(set, dim=0)
set = self.token_position_encoder(set)
set = self.transformer_layers(set)
return self.tail(set)
@register_model
def register_segformer(opt_net, opt):
return Segformer()
if __name__ == '__main__':
model = Segformer().to('cuda')
for j in tqdm(range(1000)):
test_tensor = torch.randn(64,3,224,224).cuda()
print(model(img=test_tensor, pos=torch.randint(0,224,(64,2)).cuda()).shape)

@ -1,128 +0,0 @@
# Contains implementations from the Mixture of Experts paper and Switch Transformers
# Implements KeepTopK where k=1 from mixture of experts paper.
import torch
import torch.nn as nn
from models.switched_conv.switched_conv_hard_routing import RouteTop1
from trainer.losses import ConfigurableLoss
class KeepTop1(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2)
input[mask != 1] = -float('inf')
ctx.save_for_backward(mask)
return input
@staticmethod
def backward(ctx, grad):
import pydevd
pydevd.settrace(suspend=False, trace_only_current_thread=True)
mask = ctx.saved_tensors
grad_input = grad.clone()
grad_input[mask != 1] = 0
return grad_input
class MixtureOfExperts2dRouter(nn.Module):
def __init__(self, num_experts):
super().__init__()
self.num_experts = num_experts
self.wnoise = nn.Parameter(torch.zeros(1, num_experts, 1, 1))
self.wg = nn.Parameter(torch.zeros(1, num_experts, 1, 1))
def forward(self, x):
wg = x * self.wg
wnoise = nn.functional.softplus(x * self.wnoise)
H = wg + torch.randn_like(x) * wnoise
# Produce the load-balancing loss.
eye = torch.eye(self.num_experts, device=x.device).view(1, self.num_experts, self.num_experts, 1, 1)
mask = torch.abs(1 - eye)
b, c, h, w = H.shape
ninf = torch.zeros_like(eye)
ninf[eye == 1] = -float('inf')
H_masked = H.view(b, c, 1, h,
w) * mask + ninf # ninf is necessary because otherwise torch.max() will not pick up negative numbered maxes.
max_excluding = torch.max(H_masked, dim=2)[0]
# load_loss and G are stored as local members to facilitate their use by hard routing regularization losses.
# this is a risky op - it can easily result in memory leakage. Clients *must* use self.reset() below.
self.load_loss = torch.erf((wg - max_excluding) / wnoise)
# self.G = nn.functional.softmax(KeepTop1.apply(H), dim=1) The paper proposes this equation, but performing a softmax on a Top-1 per the paper results in zero gradients into H, so:
self.G = RouteTop1.apply(nn.functional.softmax(H, dim=1)) # This variant can route gradients downstream.
return self.G
# Retrieve the locally stored loss values and delete them from membership (so as to not waste memory)
def reset(self):
G, load = self.G, self.load_loss
del self.G
del self.load_loss
return G, load
# Loss that finds instances of MixtureOfExperts2dRouter in the given network and extracts their custom losses.
class MixtureOfExpertsLoss(ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.routers = [] # This is filled in during the first forward() pass and cached from there.
self.first_forward_encountered = False
self.load_weight = opt['load_weight']
self.importance_weight = opt['importance_weight']
def forward(self, net, state):
if not self.first_forward_encountered:
for m in net.modules():
if isinstance(m, MixtureOfExperts2dRouter):
self.routers.append(m)
self.first_forward_encountered = True
l_importance = 0
l_load = 0
for r in self.routers:
G, L = r.reset()
l_importance += G.var().square()
l_load += L.var().square()
return l_importance * self.importance_weight + l_load * self.load_weight
class SwitchTransformersLoadBalancer(nn.Module):
def __init__(self):
super().__init__()
self.norm = SwitchNorm(8, accumulator_size=256)
def forward(self, x):
self.soft = self.norm(nn.functional.softmax(x, dim=1))
self.hard = RouteTop1.apply(self.soft) # This variant can route gradients downstream.
return self.hard
def reset(self):
soft, hard = self.soft, self.hard
del self.soft, self.hard
return soft, hard
class SwitchTransformersLoadBalancingLoss(ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.routers = [] # This is filled in during the first forward() pass and cached from there.
self.first_forward_encountered = False
def forward(self, net, state):
if not self.first_forward_encountered:
for m in net.modules():
if isinstance(m, SwitchTransformersLoadBalancer):
self.routers.append(m)
self.first_forward_encountered = True
loss = 0
for r in self.routers:
soft, hard = r.reset()
N = hard.shape[1]
h_mean = hard.mean(dim=[0,2,3])
s_mean = soft.mean(dim=[0,2,3])
loss += torch.dot(h_mean, s_mean) * N
return loss

@ -1,135 +0,0 @@
import functools
import math
from collections import OrderedDict
import torch
import torch.nn as nn
from lambda_networks import LambdaLayer
from torch.nn import init, Conv2d
import torch.nn.functional as F
class SwitchedConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int,
switch_breadth: int,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
include_coupler: bool = False, # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate.
coupler_mode: str = 'standard',
coupler_dim_in: int = 0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.padding_mode = padding_mode
self.groups = groups
if include_coupler:
if coupler_mode == 'standard':
self.coupler = Conv2d(coupler_dim_in, switch_breadth, kernel_size=1)
elif coupler_mode == 'lambda':
self.coupler = LambdaLayer(dim=coupler_dim_in, dim_out=switch_breadth, r=23, dim_k=16, heads=2, dim_u=1)
else:
self.coupler = None
self.weights = nn.ParameterList([nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size)) for _ in range(switch_breadth)])
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self) -> None:
for w in self.weights:
init.kaiming_uniform_(w, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weights[0])
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def forward(self, inp, selector=None):
if self.coupler:
if selector is None: # A coupler can convert from any input to a selector, so 'None' is allowed.
selector = inp
selector = F.softmax(self.coupler(selector), dim=1)
self.last_select = selector.detach().clone()
out_shape = [s // self.stride for s in inp.shape[2:]]
if selector.shape[2] != out_shape[0] or selector.shape[3] != out_shape[1]:
selector = F.interpolate(selector, size=out_shape, mode="nearest")
assert selector is not None
conv_results = []
for i, w in enumerate(self.weights):
conv_results.append(F.conv2d(inp, w, self.bias, self.stride, self.padding, self.dilation, self.groups) * selector[:, i].unsqueeze(1))
return torch.stack(conv_results, dim=-1).sum(dim=-1)
# Given a state_dict and the module that that sd belongs to, strips out all Conv2d.weight parameters and replaces them
# with the equivalent SwitchedConv.weight parameters. Does not create coupler params.
def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_list=[]):
state_dict = module.state_dict()
for name, m in module.named_modules():
ignored = False
for smod in ignore_list:
if smod in name:
ignored = True
continue
if ignored:
continue
if isinstance(m, nn.Conv2d):
if name == '':
basename = 'weight'
modname = 'weights'
else:
basename = f'{name}.weight'
modname = f'{name}.weights'
cnv_weights = state_dict[basename]
del state_dict[basename]
for j in range(switch_breadth):
state_dict[f'{modname}.{j}'] = cnv_weights.clone()
return state_dict
def test_net():
base_conv = Conv2d(32, 64, 3, stride=2, padding=1, bias=True).to('cuda')
mod_conv = SwitchedConv(32, 64, 3, switch_breadth=8, stride=2, padding=1, bias=True, include_coupler=True, coupler_dim_in=128).to('cuda')
mod_sd = convert_conv_net_state_dict_to_switched_conv(base_conv, 8)
mod_conv.load_state_dict(mod_sd, strict=False)
inp = torch.randn((8,32,128,128), device='cuda')
sel = torch.randn((8,128,32,32), device='cuda')
out1 = base_conv(inp)
out2 = mod_conv(inp, sel)
assert(torch.max(torch.abs(out1-out2)) < 1e-6)
def perform_conversion():
sd = torch.load("../experiments/rrdb_imgset_226500_generator.pth")
load_net_clean = OrderedDict() # remove unnecessary 'module.'
for k, v in sd.items():
if k.startswith('module.'):
load_net_clean[k.replace('module.', '')] = v
else:
load_net_clean[k] = v
sd = load_net_clean
import models.RRDBNet_arch as rrdb
block = functools.partial(rrdb.RRDBWithBypass)
mod = rrdb.RRDBNet(in_channels=3, out_channels=3,
mid_channels=64, num_blocks=23, body_block=block, scale=2, initial_stride=2)
mod.load_state_dict(sd)
converted = convert_conv_net_state_dict_to_switched_conv(mod, 8, ['body.','conv_first','resnet_encoder'])
torch.save(converted, "../experiments/rrdb_imgset_226500_generator_converted.pth")
if __name__ == '__main__':
perform_conversion()

@ -1,360 +0,0 @@
import math
import torch
import torch.nn as nn
from lambda_networks import LambdaLayer
from torch.nn import init, Conv2d, MSELoss, ZeroPad2d
import torch.nn.functional as F
from tqdm import tqdm
import torch.distributed as dist
from trainer.losses import ConfigurableLoss
def SwitchedConvRoutingNormal(input, selector, weight, bias, stride=1):
convs = []
b, s, h, w = selector.shape
for sel in range(s):
convs.append(F.conv2d(input, weight[:, :, sel, :, :], bias, stride=stride, padding=weight.shape[-1] // 2))
output = torch.stack(convs, dim=1) * selector.unsqueeze(dim=2)
return output.sum(dim=1)
class SwitchedConvHardRoutingFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, selector, weight, bias, stride=1):
# Pre-pad the input.
input = ZeroPad2d(weight.shape[-1]//2)(input)
# Build hard attention mask from selector input
b, s, h, w = selector.shape
mask = selector.argmax(dim=1).int()
import switched_conv_cuda_naive
output = switched_conv_cuda_naive.forward(input, mask, weight, bias, stride)
ctx.stride = stride
ctx.breadth = s
ctx.save_for_backward(*[input, output.detach().clone(), mask, weight, bias])
return output
@staticmethod
def backward(ctx, gradIn):
#import pydevd
#pydevd.settrace(suspend=False, trace_only_current_thread=True)
input, output, mask, weight, bias = ctx.saved_tensors
gradIn = gradIn
# Selector grad is simply the element-wise product of grad with the output of the layer, summed across the channel dimension
# and repeated along the breadth of the switch. (Think of the forward operation using the selector as a simple matrix of 1s
# and zeros that is multiplied by the output.)
grad_sel = (gradIn * output).sum(dim=1, keepdim=True).repeat(1,ctx.breadth,1,1)
import switched_conv_cuda_naive
grad, grad_w, grad_b = switched_conv_cuda_naive.backward(input, gradIn.contiguous(), mask, weight, bias, ctx.stride)
# Remove input padding from grad
padding = weight.shape[-1] // 2
if padding > 0:
grad = grad[:,:,padding:-padding,padding:-padding]
return grad, grad_sel, grad_w, grad_b, None
class RouteTop1(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1])
if len(input.shape) > 2:
mask = mask.permute(0, 3, 1, 2) # TODO: Make this more extensible.
out = torch.ones_like(input)
out[mask != 1] = 0
ctx.save_for_backward(mask, input.clone())
return out
@staticmethod
def backward(ctx, grad):
# Enable breakpoints in this function: (Comment out if not debugging)
#import pydevd
#pydevd.settrace(suspend=False, trace_only_current_thread=True)
mask, input = ctx.saved_tensors
input[mask != 1] = 1
grad_input = grad.clone()
grad_input[mask != 1] = 0
grad_input_n = grad_input / input # Above, we made everything either a zero or a one. Unscale the ones by dividing by the unmasked inputs.
return grad_input_n
"""
SwitchNorm is meant to be applied against the Softmax output of a switching function across a large set of
switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude
of switch weights that are over-used and increasing the magnitude of under-used weights.
The return value has the exact same format as a normal Softmax output and can be used directly into the input of an
switch equation.
Since the whole point of convolutional switch is to enable training extra-wide networks to operate on a large number
of image categories, it makes almost no sense to perform this type of norm against a single mini-batch of images: some
of the switches will not be used in such a small context - and that's good! This is solved by accumulating. Every
forward pass computes a norm across the current minibatch. That norm is added into a rotating buffer of size
<accumulator_size>. The actual normalization occurs across the entire rotating buffer.
You should set accumulator size according to two factors:
- Your batch size. Smaller batch size should mean greater accumulator size.
- Your image diversity. More diverse images have less need for the accumulator.
- How wide your switch/switching group size is. More groups mean you're going to want more accumulation.
Note: This norm makes the (potentially flawed) assumption that each forward() pass has unique data. For maximum
effectiveness, avoid doing this - or make alterations to work around it.
Note: This norm does nothing for the first <accumulator_size> iterations.
"""
class SwitchNorm(nn.Module):
def __init__(self, group_size, accumulator_size=128):
super().__init__()
self.accumulator_desired_size = accumulator_size
self.group_size = group_size
self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device='cpu'))
self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device='cpu'))
self.register_buffer("accumulator", torch.zeros(accumulator_size, group_size))
def add_norm_to_buffer(self, x):
flatten_dims = [0] + [k+2 for k in range(len(x.shape)-2)]
flat = x.sum(dim=flatten_dims)
norm = flat / torch.mean(flat)
self.accumulator[self.accumulator_index] = norm.detach().clone()
self.accumulator_index += 1
if self.accumulator_index >= self.accumulator_desired_size:
self.accumulator_index *= 0
if self.accumulator_filled <= 0:
self.accumulator_filled += 1
# Input into forward is a switching tensor of shape (batch,groups,<misc>)
def forward(self, x: torch.Tensor, update_attention_norm=True):
assert len(x.shape) >= 2
# Push the accumulator to the right device on the first iteration.
if self.accumulator.device != x.device:
self.accumulator = self.accumulator.to(x.device)
# In eval, don't change the norm buffer.
if self.training and update_attention_norm:
self.add_norm_to_buffer(x)
# Reduce across all distributed entities, if needed
if dist.is_available() and dist.is_initialized():
dist.all_reduce(self.accumulator, op=dist.ReduceOp.SUM)
self.accumulator /= dist.get_world_size()
# Compute the norm factor.
if self.accumulator_filled > 0:
norm = torch.mean(self.accumulator, dim=0)
norm = norm * x.shape[1] / norm.sum() # The resulting norm should sum up to the total breadth: we are just re-weighting here.
else:
norm = torch.ones(self.group_size, device=self.accumulator.device)
norm = norm.view(1,-1)
while len(x.shape) > len(norm.shape):
norm = norm.unsqueeze(-1)
x = x / norm
return x
class HardRoutingGate(nn.Module):
def __init__(self, breadth, hard_en=True):
super().__init__()
self.norm = SwitchNorm(breadth, accumulator_size=256)
self.hard_en = hard_en
def forward(self, x):
soft = self.norm(nn.functional.softmax(x, dim=1))
if self.hard_en:
return RouteTop1.apply(soft)
return soft
class SwitchedConvHardRouting(nn.Module):
def __init__(self,
in_c,
out_c,
kernel_sz,
breadth,
stride=1,
bias=True,
dropout_rate=0.0,
include_coupler: bool = False, # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate.
coupler_mode: str = 'standard',
coupler_dim_in: int = 0,
hard_en=True, # A test switch that, when used in 'emulation mode' (where all convs are calculated using torch functions) computes soft-attention instead of hard-attention.
emulate_swconv=True, # When set, performs a nn.Conv2d operation for each breadth. When false, uses the native cuda implementation which computes all switches concurrently.
):
super().__init__()
self.in_channels = in_c
self.out_channels = out_c
self.kernel_size = kernel_sz
self.stride = stride
self.has_bias = bias
self.breadth = breadth
self.dropout_rate = dropout_rate
if include_coupler:
if coupler_mode == 'standard':
self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
elif coupler_mode == 'lambda':
self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
nn.BatchNorm2d(coupler_dim_in),
nn.ReLU(),
LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
nn.BatchNorm2d(breadth),
nn.ReLU(),
Conv2d(breadth, breadth, 1, stride=self.stride))
elif coupler_mode == 'lambda2':
self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in),
nn.ReLU(),
LambdaLayer(dim=coupler_dim_in, dim_out=coupler_dim_in, r=23, dim_k=16, heads=2, dim_u=1),
nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in),
nn.ReLU(),
LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
nn.GroupNorm(num_groups=1, num_channels=breadth),
nn.ReLU(),
Conv2d(breadth, breadth, 1, stride=self.stride))
else:
self.coupler = None
self.gate = HardRoutingGate(breadth, hard_en=True)
self.hard_en = hard_en
self.weight = nn.Parameter(torch.empty(out_c, in_c, breadth, kernel_sz, kernel_sz))
if bias:
self.bias = nn.Parameter(torch.empty(out_c))
else:
self.bias = torch.zeros(out_c)
self.reset_parameters()
def reset_parameters(self) -> None:
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight[:,:,0,:,:])
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def load_weights_from_conv(self, cnv):
sd = cnv.state_dict()
sd['weight'] = sd['weight'].unsqueeze(2).repeat(1,1,self.breadth,1,1)
self.load_state_dict(sd)
def forward(self, input, selector=None):
if self.bias.device != input.device:
self.bias = self.bias.to(input.device) # Because this bias can be a tensor that is not moved with the rest of the module.
# If a coupler was specified, run that to convert selector into a softmax distribution.
if self.coupler:
if selector is None: # A coupler can convert from any input to a selector, so 'None' is allowed.
selector = input
selector = self.coupler(selector)
assert selector is not None
# Apply dropout at the batch level per kernel.
if self.training and self.dropout_rate > 0:
b, c, h, w = selector.shape
drop = torch.rand((b, c, 1, 1), device=input.device) > self.dropout_rate
# Ensure that there is always at least one switch left un-dropped out
fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, c, 1, 1)
drop = drop.logical_or(fix_blank)
selector = drop * selector
selector = self.gate(selector)
# Debugging variables
self.last_select = selector.detach().clone()
self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,self.breadth,1,1) == selector).float().argmax(dim=1)
if self.hard_en:
# This is a custom CUDA implementation which should be faster and less memory intensive (once completed).
return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
else:
# This composes the switching functionality using raw Torch, which basically consists of computing each of <breadth> convs separately and combining them.
return SwitchedConvRoutingNormal(input, selector, self.weight, self.bias, self.stride)
# Given a state_dict and the module that that sd belongs to, strips out all Conv2d.weight parameters and replaces them
# with the equivalent SwitchedConv.weight parameters. Does not create coupler params.
def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_list=[]):
state_dict = module.state_dict()
for name, m in module.named_modules():
if not isinstance(m, nn.Conv2d):
continue
ignored = False
for smod in ignore_list:
if smod in name:
ignored = True
continue
if ignored:
continue
if name == '':
key = 'weight'
else:
key = f'{name}.weight'
state_dict[key] = state_dict[key].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
return state_dict
# Given a state_dict and the module that that sd belongs to, strips out the specified Conv2d modules and replaces them
# with equivalent switched_conv modules.
def convert_net_to_switched_conv(module, switch_breadth, allow_list, dropout_rate=0.4, coupler_mode='lambda'):
print("CONVERTING MODEL TO SWITCHED_CONV MODE")
# Next, convert the model itself:
full_paths = [n.split('.') for n in allow_list]
for modpath in full_paths:
mod = module
for sub in modpath[:-1]:
pmod = mod
mod = getattr(mod, sub)
old_conv = getattr(mod, modpath[-1])
new_conv = SwitchedConvHardRouting('.'.join(modpath), old_conv.in_channels, old_conv.out_channels, old_conv.kernel_size[0], switch_breadth, old_conv.stride[0], old_conv.bias,
include_coupler=True, dropout_rate=dropout_rate, coupler_mode=coupler_mode)
new_conv = new_conv.to(old_conv.weight.device)
assert old_conv.dilation == 1 or old_conv.dilation == (1,1) or old_conv.dilation is None
if isinstance(mod, nn.Sequential):
# If we use the standard logic (in the else case) here, it reorders the sequential.
# Instead, extract the OrderedDict from the current sequential, replace the Conv inside that dict, then replace the entire sequential to keep the order.
emods = mod._modules
emods[modpath[-1]] = new_conv
delattr(pmod, modpath[-2])
pmod.add_module(modpath[-2], nn.Sequential(emods))
else:
delattr(mod, modpath[-1])
mod.add_module(modpath[-1], new_conv)
def convert_state_dict_to_switched_conv(sd_file, switch_breadth, allow_list):
save = torch.load(sd_file)
sd = save['state_dict']
converted = 0
for cname in allow_list:
for sn in sd.keys():
if cname in sn and sn.endswith('weight'):
sd[sn] = sd[sn].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
converted += 1
print(f"Converted {converted} parameters.")
torch.save(save, sd_file.replace('.pt', "_converted.pt"))
def test_net():
for j in tqdm(range(100)):
base_conv = Conv2d(32, 64, 3, stride=2, padding=1, bias=True).to('cuda')
mod_conv = SwitchedConvHardRouting(32, 64, 3, breadth=8, stride=2, bias=True, include_coupler=True, coupler_dim_in=32, dropout_rate=.2).to('cuda')
mod_sd = convert_conv_net_state_dict_to_switched_conv(base_conv, 8)
mod_conv.load_state_dict(mod_sd, strict=False)
inp = torch.randn((128, 32, 128, 128), device='cuda')
out1 = base_conv(inp)
out2 = mod_conv(inp, None)
compare = (out2+torch.rand_like(out2)*1e-6).detach()
MSELoss()(out2, compare).backward()
assert(torch.max(torch.abs(out1-out2)) < 1e-5)
if __name__ == '__main__':
test_net()

@ -4,7 +4,7 @@ import torch
import torch.nn.functional as F
from data.util import is_wav_file, find_files_of_type
from models.audio_resnet import resnet50
from models.audio.audio_resnet import resnet50
from models.audio.tts.tacotron2.taco_utils import load_wav_to_torch
from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict

@ -1,6 +1,5 @@
import torch
from models.spinenet_arch import SpineNet
def extract_byol_model_from_state_dict(sd):
wrap_key = 'online_encoder.net.'

@ -1,5 +1,4 @@
import os
import shutil
import torch
import torch.nn as nn
@ -11,14 +10,12 @@ from torchvision.transforms import ToTensor, Resize
from tqdm import tqdm
import numpy as np
import utils
from data.image_folder_dataset import ImageFolderDataset
from models.spinenet_arch import SpineNet
from models.image_latents.spinenet_arch import SpineNet
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
# and the distance is computed across the channel dimension.
from utils import util
from utils.options import dict_to_nonedict

@ -13,7 +13,7 @@ import torch
import numpy as np
from torchvision import utils
from models.stylegan.stylegan2_rosinality import Generator, Discriminator
from models.image_generation.stylegan.stylegan2_rosinality import Generator, Discriminator
# Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter.
@ -237,7 +237,6 @@ if __name__ == "__main__":
args = parser.parse_args()
sys.path.append('scripts\\stylegan2')
import dnnlib
from dnnlib.tflib.network import generator, discriminator, gen_ema
with open(args.path, "rb") as f:

@ -1,6 +1,6 @@
from torch.cuda.amp import autocast
from models.stylegan.stylegan2_lucidrains import gradient_penalty
from models.image_generation.stylegan.stylegan2_lucidrains import gradient_penalty
from trainer.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
from models.flownet2.networks import Resample2d
from trainer.inject import Injector

@ -6,7 +6,7 @@ import trainer.eval.evaluator as evaluator
# Evaluate how close to true Gaussian a flow network predicts in a "normal" pass given a LQ/HQ image pair.
from data.image_folder_dataset import ImageFolderDataset
from models.srflow.flow import GaussianDiag
from models.image_generation.srflow import GaussianDiag
class FlowGaussianNll(evaluator.Evaluator):

@ -16,13 +16,13 @@ def create_loss(opt_loss, env):
from trainer.custom_training_components import create_teco_loss
return create_teco_loss(opt_loss, env)
elif 'stylegan2_' in type:
from models.stylegan import create_stylegan2_loss
from models.image_generation.stylegan import create_stylegan2_loss
return create_stylegan2_loss(opt_loss, env)
elif 'style_sr_' in type:
from models.styled_sr import create_stylesr_loss
return create_stylesr_loss(opt_loss, env)
elif 'lightweight_gan_divergence' == type:
from models.lightweight_gan import LightweightGanDivergenceLoss
from models.image_generation.lightweight_gan import LightweightGanDivergenceLoss
return LightweightGanDivergenceLoss(opt_loss, env)
elif type == 'crossentropy' or type == 'cross_entropy':
return CrossEntropy(opt_loss, env)
@ -401,7 +401,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
if self.gradient_penalty:
# Apply gradient penalty. TODO: migrate this elsewhere.
from models.stylegan.stylegan2_lucidrains import gradient_penalty
from models.image_generation.stylegan.stylegan2_lucidrains import gradient_penalty
assert len(real) == 1 # Grad penalty doesn't currently support multi-input discriminators.
gp, gp_structure = gradient_penalty(real[0], d_real, return_structured_grads=True)
self.metrics.append(("gradient_penalty", gp.clone().detach()))

@ -5,8 +5,6 @@ import pkgutil
import sys
from collections import OrderedDict
from inspect import isfunction, getmembers, signature
import torch
import models.feature_arch as feature_arch
logger = logging.getLogger('base')

@ -1,7 +1,6 @@
import torch
from torch import nn
import models.SwitchedResidualGenerator_arch as srg
import models.discriminator_vgg_arch as disc
import models.image_generation.discriminator_vgg_arch as disc
import functools
blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]