forked from mrq/DL-Art-School
More spring cleaning
parent 735f6e4640
commit d186414566
@@ -9,7 +9,7 @@ from torchvision import transforms
import torch.nn as nn
from pathlib import Path

import models.stylegan.stylegan2_lucidrains as sg2
import models.image_generation.stylegan.stylegan2_lucidrains as sg2


def convert_transparent_to_rgb(image):
@@ -1016,3 +1016,17 @@ class FinalUpsampleBlock2x(nn.Module):
    def forward(self, x):
        return self.chain(x)

# torch.gather() which operates as it always fucking should have: pulling indexes from the input.
def gather_2d(input, index):
    b, c, h, w = input.shape
    nodim = input.view(b, c, h * w)
    ind_nd = index[:, 0]*w + index[:, 1]
    ind_nd = ind_nd.unsqueeze(1)
    ind_nd = ind_nd.repeat((1, c))
    ind_nd = ind_nd.unsqueeze(2)
    result = torch.gather(nodim, dim=2, index=ind_nd)
    result = result.squeeze()
    if b == 1:
        result = result.unsqueeze(0)
    return result
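A minimal usage sketch for the gather_2d helper added above, assuming a (b, c, h, w) input and one (y, x) index pair per batch element; the tensors and shapes here are illustrative only:

import torch

x = torch.arange(2 * 3 * 4 * 4, dtype=torch.float).view(2, 3, 4, 4)
loc = torch.tensor([[1, 2], [3, 0]])  # one (y, x) location per batch element
vec = gather_2d(x, loc)               # shape (2, 3): the channel vector at each location
assert torch.equal(vec[0], x[0, :, 1, 2])
assert torch.equal(vec[1], x[1, :, 3, 0])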
@@ -2,4 +2,5 @@ from models.audio.tts.tacotron2.taco_utils import *
from models.audio.tts.tacotron2.text import *
from models.audio.tts.tacotron2.tacotron2 import *
from models.audio.tts.tacotron2.stft import *
from models.audio.tts.tacotron2.layers import *
from models.audio.tts.tacotron2.layers import *
from models.audio.tts.tacotron2.loss import *
@@ -1,96 +0,0 @@
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

# Utilizes pretrained torchvision modules for feature extraction

class VGGFeatureExtractor(nn.Module):
    def __init__(self, feature_layers=[34], use_bn=False, use_input_norm=True,
                 device=torch.device('cpu')):
        super(VGGFeatureExtractor, self).__init__()
        self.use_input_norm = use_input_norm
        if use_bn:
            model = torchvision.models.vgg19_bn(pretrained=True)
        else:
            model = torchvision.models.vgg19(pretrained=True)
        if self.use_input_norm:
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.feature_layers = feature_layers
        self.features = nn.Sequential(*list(model.features.children())[:(max(feature_layers) + 1)])
        # No need to BP to variable
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x, interpolate_factor=1):
        if interpolate_factor > 1:
            x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic')

        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output


class TrainableVGGFeatureExtractor(nn.Module):
    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
                 device=torch.device('cpu')):
        super(TrainableVGGFeatureExtractor, self).__init__()
        self.use_input_norm = use_input_norm
        if use_bn:
            model = torchvision.models.vgg19_bn(pretrained=False)
        else:
            model = torchvision.models.vgg19(pretrained=False)
        if self.use_input_norm:
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])

    def forward(self, x, interpolate_factor=1):
        if interpolate_factor > 1:
            x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic')
        # Assume input range is [0, 1]
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output


class WideResnetFeatureExtractor(nn.Module):
    def __init__(self, use_input_norm=True, device=torch.device('cpu')):
        print("Using wide resnet extractor.")
        super(WideResnetFeatureExtractor, self).__init__()
        self.use_input_norm = use_input_norm
        self.model = torchvision.models.wide_resnet50_2(pretrained=True)
        if self.use_input_norm:
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        # No need to BP to variable
        for p in self.model.parameters():
            p.requires_grad = False

    def forward(self, x):
        # Assume input range is [0, 1]
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        return x
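A minimal usage sketch of the VGGFeatureExtractor removed above as a perceptual feature comparator; the inputs are assumed to be (b, 3, h, w) tensors in [0, 1], and the pretrained weights are fetched by torchvision:

import torch
import torch.nn.functional as F

vgg = VGGFeatureExtractor(feature_layers=[34], use_input_norm=True)
img_a = torch.rand(2, 3, 128, 128)  # hypothetical inputs in [0, 1]
img_b = torch.rand(2, 3, 128, 128)
perceptual_dist = F.l1_loss(vgg(img_a), vgg(img_b))  # distance between layer-34 activations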
@@ -3,13 +3,13 @@ import math
import torch.nn as nn
import torch

from models.RRDBNet_arch import RRDB
from models.image_generation.RRDBNet_arch import RRDB
from models.arch_util import ConvGnLelu


# Produces a convolutional feature (`f`) and a reduced feature map with double the filters.
from models.glean.stylegan2_latent_bank import Stylegan2LatentBank
from models.stylegan.stylegan2_rosinality import EqualLinear
from models.image_generation.glean.stylegan2_latent_bank import Stylegan2LatentBank
from models.image_generation.stylegan.stylegan2_rosinality import EqualLinear
from trainer.networks import register_model
from utils.util import checkpoint, sequential_checkpoint
@@ -2,7 +2,7 @@ import torch
import torch.nn as nn

from models.arch_util import ConvGnLelu
from models.stylegan.stylegan2_rosinality import Generator
from models.image_generation.stylegan.stylegan2_rosinality import Generator


class Stylegan2LatentBank(nn.Module):
@@ -20,7 +20,7 @@ from torch import nn, einsum
from torch.utils.data import Dataset
from torchvision import transforms

from models.stylegan.stylegan2_lucidrains import gradient_penalty
from models.image_generation.stylegan.stylegan2_lucidrains import gradient_penalty
from trainer.networks import register_model
from utils.util import opt_get
@@ -1,7 +1,7 @@
import torch
from torch import nn as nn

from models.srflow import thops
from models.image_generation.srflow import thops


class _ActNorm(nn.Module):
@@ -1,8 +1,8 @@
import torch
from torch import nn as nn

from models.srflow import thops
from models.srflow.flow import Conv2d, Conv2dZeros
from models.image_generation.srflow import thops
from models.image_generation.srflow.flow import Conv2d, Conv2dZeros
from utils.util import opt_get
@@ -1,9 +1,9 @@
import torch
from torch import nn as nn

import models.srflow.Permutations
import models.srflow.FlowAffineCouplingsAblation
import models.srflow.FlowActNorms
import models.image_generation.srflow.Permutations
import models.image_generation.srflow.FlowAffineCouplingsAblation
import models.image_generation.srflow.FlowActNorms


def getConditional(rrdbResults, position):
@@ -46,17 +46,17 @@ class FlowStep(nn.Module):
        self.acOpt = acOpt

        # 1. actnorm
        self.actnorm = models.srflow.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
        self.actnorm = models.image_generation.srflow.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)

        # 2. permute
        if flow_permutation == "invconv":
            self.invconv = models.srflow.Permutations.InvertibleConv1x1(
            self.invconv = models.image_generation.srflow.Permutations.InvertibleConv1x1(
                in_channels, LU_decomposed=LU_decomposed)

        # 3. coupling
        if flow_coupling == "CondAffineSeparatedAndCond":
            self.affine = models.srflow.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
                                                                                               opt=opt)
            self.affine = models.image_generation.srflow.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
                                                                                                                opt=opt)
        elif flow_coupling == "noCoupling":
            pass
        else:
@@ -2,12 +2,11 @@ import numpy as np
import torch
from torch import nn as nn

import models.srflow.Split
from models.srflow import flow
from models.srflow import thops
from models.srflow.Split import Split2d
from models.srflow.glow_arch import f_conv2d_bias
from models.srflow.FlowStep import FlowStep
import models.image_generation.srflow.Split
from models.image_generation.srflow import flow
from models.image_generation.srflow.Split import Split2d
from models.image_generation.srflow.glow_arch import f_conv2d_bias
from models.image_generation.srflow.FlowStep import FlowStep
from utils.util import opt_get, checkpoint
@@ -146,8 +145,8 @@ class FlowUpsamplerNet(nn.Module):
        t = opt_get(opt, ['networks', 'generator','flow', 'split', 'type'], 'Split2d')

        if t == 'Split2d':
            split = models.srflow.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
                                                cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
            split = models.image_generation.srflow.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
                                                                  cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
            self.layers.append(split)
            self.output_shapes.append([-1, split.num_channels_pass, H, W])
            self.C = split.num_channels_pass
@@ -3,7 +3,7 @@ import torch
from torch import nn as nn
from torch.nn import functional as F

from models.srflow import thops
from models.image_generation.srflow import thops


class InvertibleConv1x1(nn.Module):
@@ -2,7 +2,7 @@ import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.srflow.module_util as mutil
import models.image_generation.srflow.module_util as mutil
from models.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu
from trainer.networks import register_model
from utils.util import opt_get
@@ -4,10 +4,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.srflow.RRDBNet_arch import RRDBNet
from models.srflow.FlowUpsamplerNet import FlowUpsamplerNet
import models.srflow.thops as thops
import models.srflow.flow as flow
from models.image_generation.srflow.RRDBNet_arch import RRDBNet
from models.image_generation.srflow.FlowUpsamplerNet import FlowUpsamplerNet
import models.image_generation.srflow.thops as thops
import models.image_generation.srflow.flow as flow
from trainer.networks import register_model
from utils.util import opt_get
@@ -1,8 +1,8 @@
import torch
from torch import nn as nn

from models.srflow import thops
from models.srflow.flow import Conv2dZeros, GaussianDiag
from models.image_generation.srflow import thops
from models.image_generation.srflow.flow import Conv2dZeros, GaussianDiag
from utils.util import opt_get
@@ -1,9 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from models.srflow.FlowActNorms import ActNorm2d
from models.image_generation.srflow.FlowActNorms import ActNorm2d
from . import thops
@@ -2,10 +2,10 @@
def create_stylegan2_loss(opt_loss, env):
    type = opt_loss['type']
    if type == 'stylegan2_divergence':
        import models.stylegan.stylegan2_lucidrains as stylegan2
        import models.image_generation.stylegan.stylegan2_lucidrains as stylegan2
        return stylegan2.StyleGan2DivergenceLoss(opt_loss, env)
    elif type == 'stylegan2_pathlen':
        import models.stylegan.stylegan2_lucidrains as stylegan2
        import models.image_generation.stylegan.stylegan2_lucidrains as stylegan2
        return stylegan2.StyleGan2PathLengthLoss(opt_loss, env)
    else:
        raise NotImplementedError
@@ -5,7 +5,7 @@ import torch.nn.functional as F
from torch import nn

from data.byol_attachment import reconstructed_shared_regions
from models.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \
from models.image_latents.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \
    update_moving_average
from trainer.networks import create_model, register_model
from utils.util import checkpoint
@@ -1,123 +0,0 @@
# A direct copy of torchvision's resnet.py modified to support gradient checkpointing.

import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck
import torchvision


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']

from trainer.networks import register_model
from utils.util import checkpoint

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


class Backbone(torchvision.models.resnet.ResNet):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
                         replace_stride_with_dilation, norm_layer)
        del self.fc
        del self.avgpool

    def _forward_impl(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        l1 = checkpoint(self.layer1, x)
        l2 = checkpoint(self.layer2, l1)
        l3 = checkpoint(self.layer3, l2)
        l4 = checkpoint(self.layer4, l3)

        return l1, l2, l3, l4

    def forward(self, x):
        return self._forward_impl(x)


def _backbone(arch, block, layers, pretrained, progress, **kwargs):
    model = Backbone(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def backbone18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _backbone('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                     **kwargs)


def backbone34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _backbone('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                     **kwargs)


def backbone50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _backbone('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                     **kwargs)


def backbone101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _backbone('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                     **kwargs)


def backbone152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _backbone('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                     **kwargs)
@@ -1,131 +0,0 @@
import math

import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm
from models.segformer.backbone import backbone50
from trainer.networks import register_model


# torch.gather() which operates as it always fucking should have: pulling indexes from the input.
def gather_2d(input, index):
    b, c, h, w = input.shape
    nodim = input.view(b, c, h * w)
    ind_nd = index[:, 0]*w + index[:, 1]
    ind_nd = ind_nd.unsqueeze(1)
    ind_nd = ind_nd.repeat((1, c))
    ind_nd = ind_nd.unsqueeze(2)
    result = torch.gather(nodim, dim=2, index=ind_nd)
    result = result.squeeze()
    if b == 1:
        result = result.unsqueeze(0)
    return result


class DilatorModule(nn.Module):
    def __init__(self, input_channels, output_channels, max_dilation):
        super().__init__()
        self.max_dilation = max_dilation
        self.conv1 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, dilation=1, bias=True)
        if max_dilation > 1:
            self.bn = nn.BatchNorm2d(input_channels)
            self.relu = nn.ReLU()
            self.conv2 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=max_dilation, dilation=max_dilation, bias=True)
        self.dense = nn.Linear(input_channels, output_channels, bias=True)

    def forward(self, inp, loc):
        x = self.conv1(inp)
        if self.max_dilation > 1:
            x = self.bn(self.relu(x))
            x = self.conv2(x)

        # This can be made more efficient by only computing these convolutions across a subset of the image. Possibly.
        x = gather_2d(x, loc).contiguous()
        return self.dense(x)


# Grabbed from torch examples: https://github.com/pytorch/examples/tree/master/https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65:7
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return x


# Simple mean() layer encoded into a class so that BYOL can grab it.
class Tail(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.mean(dim=0)


class Segformer(nn.Module):
    def __init__(self, latent_channels=1024, layers=8):
        super().__init__()
        self.backbone = backbone50()
        backbone_channels = [256, 512, 1024, 2048]
        dilations = [[1,2,3,4],[1,2,3],[1,2],[1]]
        final_latent_channels = latent_channels
        dilators = []
        for ic, dis in zip(backbone_channels, dilations):
            layer_dilators = []
            for di in dis:
                layer_dilators.append(DilatorModule(ic, final_latent_channels, di))
            dilators.append(nn.ModuleList(layer_dilators))
        self.dilators = nn.ModuleList(dilators)

        self.token_position_encoder = PositionalEncoding(final_latent_channels, max_len=10)
        self.transformer_layers = nn.Sequential(*[nn.TransformerEncoderLayer(final_latent_channels, nhead=4) for _ in range(layers)])
        self.tail = Tail()

    def forward(self, img=None, layers=None, pos=None, return_layers=False):
        assert img is not None or layers is not None
        if img is not None:
            bs = img.shape[0]
            layers = self.backbone(img)
        else:
            bs = layers[0].shape[0]
        if return_layers:
            return layers

        # A single position can be optionally given, in which case we need to expand it to represent the entire input.
        if pos.shape == (2,):
            pos = pos.unsqueeze(0).repeat(bs, 1)

        set = []
        pos = pos // 4
        for layer_out, dilator in zip(layers, self.dilators):
            for subdilator in dilator:
                set.append(subdilator(layer_out, pos))
            pos = pos // 2

        # The torch transformer expects the set dimension to be 0.
        set = torch.stack(set, dim=0)
        set = self.token_position_encoder(set)
        set = self.transformer_layers(set)
        return self.tail(set)


@register_model
def register_segformer(opt_net, opt):
    return Segformer()


if __name__ == '__main__':
    model = Segformer().to('cuda')
    for j in tqdm(range(1000)):
        test_tensor = torch.randn(64,3,224,224).cuda()
        print(model(img=test_tensor, pos=torch.randint(0,224,(64,2)).cuda()).shape)
@@ -1,128 +0,0 @@
# Contains implementations from the Mixture of Experts paper and Switch Transformers


# Implements KeepTopK where k=1 from mixture of experts paper.
import torch
import torch.nn as nn

from models.switched_conv.switched_conv_hard_routing import RouteTop1
from trainer.losses import ConfigurableLoss


class KeepTop1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2)
        input[mask != 1] = -float('inf')
        ctx.save_for_backward(mask)
        return input

    @staticmethod
    def backward(ctx, grad):
        import pydevd
        pydevd.settrace(suspend=False, trace_only_current_thread=True)
        mask = ctx.saved_tensors
        grad_input = grad.clone()
        grad_input[mask != 1] = 0
        return grad_input

class MixtureOfExperts2dRouter(nn.Module):
    def __init__(self, num_experts):
        super().__init__()
        self.num_experts = num_experts
        self.wnoise = nn.Parameter(torch.zeros(1, num_experts, 1, 1))
        self.wg = nn.Parameter(torch.zeros(1, num_experts, 1, 1))

    def forward(self, x):
        wg = x * self.wg
        wnoise = nn.functional.softplus(x * self.wnoise)
        H = wg + torch.randn_like(x) * wnoise

        # Produce the load-balancing loss.
        eye = torch.eye(self.num_experts, device=x.device).view(1, self.num_experts, self.num_experts, 1, 1)
        mask = torch.abs(1 - eye)
        b, c, h, w = H.shape
        ninf = torch.zeros_like(eye)
        ninf[eye == 1] = -float('inf')
        H_masked = H.view(b, c, 1, h,
                          w) * mask + ninf  # ninf is necessary because otherwise torch.max() will not pick up negative numbered maxes.
        max_excluding = torch.max(H_masked, dim=2)[0]

        # load_loss and G are stored as local members to facilitate their use by hard routing regularization losses.
        # this is a risky op - it can easily result in memory leakage. Clients *must* use self.reset() below.
        self.load_loss = torch.erf((wg - max_excluding) / wnoise)
        # self.G = nn.functional.softmax(KeepTop1.apply(H), dim=1) The paper proposes this equation, but performing a softmax on a Top-1 per the paper results in zero gradients into H, so:
        self.G = RouteTop1.apply(nn.functional.softmax(H, dim=1))  # This variant can route gradients downstream.

        return self.G

    # Retrieve the locally stored loss values and delete them from membership (so as to not waste memory)
    def reset(self):
        G, load = self.G, self.load_loss
        del self.G
        del self.load_loss
        return G, load


# Loss that finds instances of MixtureOfExperts2dRouter in the given network and extracts their custom losses.
class MixtureOfExpertsLoss(ConfigurableLoss):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.routers = []  # This is filled in during the first forward() pass and cached from there.
        self.first_forward_encountered = False
        self.load_weight = opt['load_weight']
        self.importance_weight = opt['importance_weight']

    def forward(self, net, state):
        if not self.first_forward_encountered:
            for m in net.modules():
                if isinstance(m, MixtureOfExperts2dRouter):
                    self.routers.append(m)
            self.first_forward_encountered = True

        l_importance = 0
        l_load = 0
        for r in self.routers:
            G, L = r.reset()
            l_importance += G.var().square()
            l_load += L.var().square()
        return l_importance * self.importance_weight + l_load * self.load_weight


class SwitchTransformersLoadBalancer(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = SwitchNorm(8, accumulator_size=256)

    def forward(self, x):
        self.soft = self.norm(nn.functional.softmax(x, dim=1))
        self.hard = RouteTop1.apply(self.soft)  # This variant can route gradients downstream.
        return self.hard

    def reset(self):
        soft, hard = self.soft, self.hard
        del self.soft, self.hard
        return soft, hard


class SwitchTransformersLoadBalancingLoss(ConfigurableLoss):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.routers = []  # This is filled in during the first forward() pass and cached from there.
        self.first_forward_encountered = False

    def forward(self, net, state):
        if not self.first_forward_encountered:
            for m in net.modules():
                if isinstance(m, SwitchTransformersLoadBalancer):
                    self.routers.append(m)
            self.first_forward_encountered = True

        loss = 0
        for r in self.routers:
            soft, hard = r.reset()
            N = hard.shape[1]
            h_mean = hard.mean(dim=[0,2,3])
            s_mean = soft.mean(dim=[0,2,3])
            loss += torch.dot(h_mean, s_mean) * N
        return loss
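A small numeric sketch (hypothetical tensors) of the load-balancing term computed by SwitchTransformersLoadBalancingLoss above; it is smallest when the router spreads assignments evenly across the N switches:

import torch

soft = torch.softmax(torch.randn(4, 8, 16, 16), dim=1)  # per-pixel routing probabilities
hard = torch.nn.functional.one_hot(soft.argmax(dim=1), num_classes=8).permute(0, 3, 1, 2).float()
N = hard.shape[1]
loss = torch.dot(hard.mean(dim=[0, 2, 3]), soft.mean(dim=[0, 2, 3])) * N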
@@ -1,135 +0,0 @@
import functools
import math
from collections import OrderedDict

import torch
import torch.nn as nn
from lambda_networks import LambdaLayer
from torch.nn import init, Conv2d
import torch.nn.functional as F


class SwitchedConv(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 switch_breadth: int,
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 include_coupler: bool = False,  # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate.
                 coupler_mode: str = 'standard',
                 coupler_dim_in: int = 0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.padding_mode = padding_mode
        self.groups = groups

        if include_coupler:
            if coupler_mode == 'standard':
                self.coupler = Conv2d(coupler_dim_in, switch_breadth, kernel_size=1)
            elif coupler_mode == 'lambda':
                self.coupler = LambdaLayer(dim=coupler_dim_in, dim_out=switch_breadth, r=23, dim_k=16, heads=2, dim_u=1)

        else:
            self.coupler = None

        self.weights = nn.ParameterList([nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size)) for _ in range(switch_breadth)])
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        for w in self.weights:
            init.kaiming_uniform_(w, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weights[0])
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, inp, selector=None):
        if self.coupler:
            if selector is None:  # A coupler can convert from any input to a selector, so 'None' is allowed.
                selector = inp
            selector = F.softmax(self.coupler(selector), dim=1)
            self.last_select = selector.detach().clone()
            out_shape = [s // self.stride for s in inp.shape[2:]]
            if selector.shape[2] != out_shape[0] or selector.shape[3] != out_shape[1]:
                selector = F.interpolate(selector, size=out_shape, mode="nearest")
        assert selector is not None

        conv_results = []
        for i, w in enumerate(self.weights):
            conv_results.append(F.conv2d(inp, w, self.bias, self.stride, self.padding, self.dilation, self.groups) * selector[:, i].unsqueeze(1))
        return torch.stack(conv_results, dim=-1).sum(dim=-1)


# Given a state_dict and the module that that sd belongs to, strips out all Conv2d.weight parameters and replaces them
# with the equivalent SwitchedConv.weight parameters. Does not create coupler params.
def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_list=[]):
    state_dict = module.state_dict()
    for name, m in module.named_modules():
        ignored = False
        for smod in ignore_list:
            if smod in name:
                ignored = True
                continue
        if ignored:
            continue
        if isinstance(m, nn.Conv2d):
            if name == '':
                basename = 'weight'
                modname = 'weights'
            else:
                basename = f'{name}.weight'
                modname = f'{name}.weights'
            cnv_weights = state_dict[basename]
            del state_dict[basename]
            for j in range(switch_breadth):
                state_dict[f'{modname}.{j}'] = cnv_weights.clone()
    return state_dict


def test_net():
    base_conv = Conv2d(32, 64, 3, stride=2, padding=1, bias=True).to('cuda')
    mod_conv = SwitchedConv(32, 64, 3, switch_breadth=8, stride=2, padding=1, bias=True, include_coupler=True, coupler_dim_in=128).to('cuda')
    mod_sd = convert_conv_net_state_dict_to_switched_conv(base_conv, 8)
    mod_conv.load_state_dict(mod_sd, strict=False)
    inp = torch.randn((8,32,128,128), device='cuda')
    sel = torch.randn((8,128,32,32), device='cuda')
    out1 = base_conv(inp)
    out2 = mod_conv(inp, sel)
    assert(torch.max(torch.abs(out1-out2)) < 1e-6)

def perform_conversion():
    sd = torch.load("../experiments/rrdb_imgset_226500_generator.pth")
    load_net_clean = OrderedDict()  # remove unnecessary 'module.'
    for k, v in sd.items():
        if k.startswith('module.'):
            load_net_clean[k.replace('module.', '')] = v
        else:
            load_net_clean[k] = v
    sd = load_net_clean
    import models.RRDBNet_arch as rrdb
    block = functools.partial(rrdb.RRDBWithBypass)
    mod = rrdb.RRDBNet(in_channels=3, out_channels=3,
                       mid_channels=64, num_blocks=23, body_block=block, scale=2, initial_stride=2)
    mod.load_state_dict(sd)
    converted = convert_conv_net_state_dict_to_switched_conv(mod, 8, ['body.','conv_first','resnet_encoder'])
    torch.save(converted, "../experiments/rrdb_imgset_226500_generator_converted.pth")


if __name__ == '__main__':
    perform_conversion()
@@ -1,360 +0,0 @@
import math

import torch
import torch.nn as nn
from lambda_networks import LambdaLayer
from torch.nn import init, Conv2d, MSELoss, ZeroPad2d
import torch.nn.functional as F
from tqdm import tqdm
import torch.distributed as dist

from trainer.losses import ConfigurableLoss


def SwitchedConvRoutingNormal(input, selector, weight, bias, stride=1):
    convs = []
    b, s, h, w = selector.shape
    for sel in range(s):
        convs.append(F.conv2d(input, weight[:, :, sel, :, :], bias, stride=stride, padding=weight.shape[-1] // 2))
    output = torch.stack(convs, dim=1) * selector.unsqueeze(dim=2)
    return output.sum(dim=1)


class SwitchedConvHardRoutingFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, selector, weight, bias, stride=1):
        # Pre-pad the input.
        input = ZeroPad2d(weight.shape[-1]//2)(input)

        # Build hard attention mask from selector input
        b, s, h, w = selector.shape

        mask = selector.argmax(dim=1).int()
        import switched_conv_cuda_naive
        output = switched_conv_cuda_naive.forward(input, mask, weight, bias, stride)

        ctx.stride = stride
        ctx.breadth = s
        ctx.save_for_backward(*[input, output.detach().clone(), mask, weight, bias])
        return output

    @staticmethod
    def backward(ctx, gradIn):
        #import pydevd
        #pydevd.settrace(suspend=False, trace_only_current_thread=True)
        input, output, mask, weight, bias = ctx.saved_tensors
        gradIn = gradIn

        # Selector grad is simply the element-wise product of grad with the output of the layer, summed across the channel dimension
        # and repeated along the breadth of the switch. (Think of the forward operation using the selector as a simple matrix of 1s
        # and zeros that is multiplied by the output.)
        grad_sel = (gradIn * output).sum(dim=1, keepdim=True).repeat(1,ctx.breadth,1,1)

        import switched_conv_cuda_naive
        grad, grad_w, grad_b = switched_conv_cuda_naive.backward(input, gradIn.contiguous(), mask, weight, bias, ctx.stride)

        # Remove input padding from grad
        padding = weight.shape[-1] // 2
        if padding > 0:
            grad = grad[:,:,padding:-padding,padding:-padding]
        return grad, grad_sel, grad_w, grad_b, None


class RouteTop1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1])
        if len(input.shape) > 2:
            mask = mask.permute(0, 3, 1, 2)  # TODO: Make this more extensible.
        out = torch.ones_like(input)
        out[mask != 1] = 0
        ctx.save_for_backward(mask, input.clone())
        return out

    @staticmethod
    def backward(ctx, grad):
        # Enable breakpoints in this function: (Comment out if not debugging)
        #import pydevd
        #pydevd.settrace(suspend=False, trace_only_current_thread=True)

        mask, input = ctx.saved_tensors
        input[mask != 1] = 1
        grad_input = grad.clone()
        grad_input[mask != 1] = 0
        grad_input_n = grad_input / input  # Above, we made everything either a zero or a one. Unscale the ones by dividing by the unmasked inputs.
        return grad_input_n


"""
SwitchNorm is meant to be applied against the Softmax output of a switching function across a large set of
switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude
of switch weights that are over-used and increasing the magnitude of under-used weights.

The return value has the exact same format as a normal Softmax output and can be used directly into the input of an
switch equation.

Since the whole point of convolutional switch is to enable training extra-wide networks to operate on a large number
of image categories, it makes almost no sense to perform this type of norm against a single mini-batch of images: some
of the switches will not be used in such a small context - and that's good! This is solved by accumulating. Every
forward pass computes a norm across the current minibatch. That norm is added into a rotating buffer of size
<accumulator_size>. The actual normalization occurs across the entire rotating buffer.

You should set accumulator size according to two factors:
- Your batch size. Smaller batch size should mean greater accumulator size.
- Your image diversity. More diverse images have less need for the accumulator.
- How wide your switch/switching group size is. More groups mean you're going to want more accumulation.

Note: This norm makes the (potentially flawed) assumption that each forward() pass has unique data. For maximum
effectiveness, avoid doing this - or make alterations to work around it.
Note: This norm does nothing for the first <accumulator_size> iterations.
"""
class SwitchNorm(nn.Module):
    def __init__(self, group_size, accumulator_size=128):
        super().__init__()
        self.accumulator_desired_size = accumulator_size
        self.group_size = group_size
        self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device='cpu'))
        self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device='cpu'))
        self.register_buffer("accumulator", torch.zeros(accumulator_size, group_size))

    def add_norm_to_buffer(self, x):
        flatten_dims = [0] + [k+2 for k in range(len(x.shape)-2)]
        flat = x.sum(dim=flatten_dims)
        norm = flat / torch.mean(flat)

        self.accumulator[self.accumulator_index] = norm.detach().clone()
        self.accumulator_index += 1
        if self.accumulator_index >= self.accumulator_desired_size:
            self.accumulator_index *= 0
            if self.accumulator_filled <= 0:
                self.accumulator_filled += 1

    # Input into forward is a switching tensor of shape (batch,groups,<misc>)
    def forward(self, x: torch.Tensor, update_attention_norm=True):
        assert len(x.shape) >= 2

        # Push the accumulator to the right device on the first iteration.
        if self.accumulator.device != x.device:
            self.accumulator = self.accumulator.to(x.device)

        # In eval, don't change the norm buffer.
        if self.training and update_attention_norm:
            self.add_norm_to_buffer(x)

        # Reduce across all distributed entities, if needed
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(self.accumulator, op=dist.ReduceOp.SUM)
            self.accumulator /= dist.get_world_size()

        # Compute the norm factor.
        if self.accumulator_filled > 0:
            norm = torch.mean(self.accumulator, dim=0)
            norm = norm * x.shape[1] / norm.sum()  # The resulting norm should sum up to the total breadth: we are just re-weighting here.
        else:
            norm = torch.ones(self.group_size, device=self.accumulator.device)

        norm = norm.view(1,-1)
        while len(x.shape) > len(norm.shape):
            norm = norm.unsqueeze(-1)
        x = x / norm

        return x


class HardRoutingGate(nn.Module):
    def __init__(self, breadth, hard_en=True):
        super().__init__()
        self.norm = SwitchNorm(breadth, accumulator_size=256)
        self.hard_en = hard_en

    def forward(self, x):
        soft = self.norm(nn.functional.softmax(x, dim=1))
        if self.hard_en:
            return RouteTop1.apply(soft)
        return soft


class SwitchedConvHardRouting(nn.Module):
    def __init__(self,
                 in_c,
                 out_c,
                 kernel_sz,
                 breadth,
                 stride=1,
                 bias=True,
                 dropout_rate=0.0,
                 include_coupler: bool = False,  # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate.
                 coupler_mode: str = 'standard',
                 coupler_dim_in: int = 0,
                 hard_en=True,  # A test switch that, when used in 'emulation mode' (where all convs are calculated using torch functions) computes soft-attention instead of hard-attention.
                 emulate_swconv=True,  # When set, performs a nn.Conv2d operation for each breadth. When false, uses the native cuda implementation which computes all switches concurrently.
                 ):
        super().__init__()
        self.in_channels = in_c
        self.out_channels = out_c
        self.kernel_size = kernel_sz
        self.stride = stride
        self.has_bias = bias
        self.breadth = breadth
        self.dropout_rate = dropout_rate

        if include_coupler:
            if coupler_mode == 'standard':
                self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
            elif coupler_mode == 'lambda':
                self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
                                             nn.BatchNorm2d(coupler_dim_in),
                                             nn.ReLU(),
                                             LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
                                             nn.BatchNorm2d(breadth),
                                             nn.ReLU(),
                                             Conv2d(breadth, breadth, 1, stride=self.stride))
            elif coupler_mode == 'lambda2':
                self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
                                             nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in),
                                             nn.ReLU(),
                                             LambdaLayer(dim=coupler_dim_in, dim_out=coupler_dim_in, r=23, dim_k=16, heads=2, dim_u=1),
                                             nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in),
                                             nn.ReLU(),
                                             LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
                                             nn.GroupNorm(num_groups=1, num_channels=breadth),
                                             nn.ReLU(),
                                             Conv2d(breadth, breadth, 1, stride=self.stride))
        else:
            self.coupler = None
        self.gate = HardRoutingGate(breadth, hard_en=True)
        self.hard_en = hard_en

        self.weight = nn.Parameter(torch.empty(out_c, in_c, breadth, kernel_sz, kernel_sz))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_c))
        else:
            self.bias = torch.zeros(out_c)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight[:,:,0,:,:])
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def load_weights_from_conv(self, cnv):
        sd = cnv.state_dict()
        sd['weight'] = sd['weight'].unsqueeze(2).repeat(1,1,self.breadth,1,1)
        self.load_state_dict(sd)

    def forward(self, input, selector=None):
        if self.bias.device != input.device:
            self.bias = self.bias.to(input.device)  # Because this bias can be a tensor that is not moved with the rest of the module.

        # If a coupler was specified, run that to convert selector into a softmax distribution.
        if self.coupler:
            if selector is None:  # A coupler can convert from any input to a selector, so 'None' is allowed.
                selector = input
            selector = self.coupler(selector)
        assert selector is not None

        # Apply dropout at the batch level per kernel.
        if self.training and self.dropout_rate > 0:
            b, c, h, w = selector.shape
            drop = torch.rand((b, c, 1, 1), device=input.device) > self.dropout_rate
            # Ensure that there is always at least one switch left un-dropped out
            fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, c, 1, 1)
            drop = drop.logical_or(fix_blank)
            selector = drop * selector

        selector = self.gate(selector)

        # Debugging variables
        self.last_select = selector.detach().clone()
        self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,self.breadth,1,1) == selector).float().argmax(dim=1)

        if self.hard_en:
            # This is a custom CUDA implementation which should be faster and less memory intensive (once completed).
            return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
        else:
            # This composes the switching functionality using raw Torch, which basically consists of computing each of <breadth> convs separately and combining them.
            return SwitchedConvRoutingNormal(input, selector, self.weight, self.bias, self.stride)


# Given a state_dict and the module that that sd belongs to, strips out all Conv2d.weight parameters and replaces them
# with the equivalent SwitchedConv.weight parameters. Does not create coupler params.
def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_list=[]):
    state_dict = module.state_dict()
    for name, m in module.named_modules():
        if not isinstance(m, nn.Conv2d):
            continue
        ignored = False
        for smod in ignore_list:
            if smod in name:
                ignored = True
                continue
        if ignored:
            continue
        if name == '':
            key = 'weight'
        else:
            key = f'{name}.weight'
        state_dict[key] = state_dict[key].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
    return state_dict


# Given a state_dict and the module that that sd belongs to, strips out the specified Conv2d modules and replaces them
# with equivalent switched_conv modules.
def convert_net_to_switched_conv(module, switch_breadth, allow_list, dropout_rate=0.4, coupler_mode='lambda'):
    print("CONVERTING MODEL TO SWITCHED_CONV MODE")

    # Next, convert the model itself:
    full_paths = [n.split('.') for n in allow_list]
    for modpath in full_paths:
        mod = module
        for sub in modpath[:-1]:
            pmod = mod
            mod = getattr(mod, sub)
        old_conv = getattr(mod, modpath[-1])
        new_conv = SwitchedConvHardRouting('.'.join(modpath), old_conv.in_channels, old_conv.out_channels, old_conv.kernel_size[0], switch_breadth, old_conv.stride[0], old_conv.bias,
                                           include_coupler=True, dropout_rate=dropout_rate, coupler_mode=coupler_mode)
        new_conv = new_conv.to(old_conv.weight.device)
        assert old_conv.dilation == 1 or old_conv.dilation == (1,1) or old_conv.dilation is None
        if isinstance(mod, nn.Sequential):
            # If we use the standard logic (in the else case) here, it reorders the sequential.
            # Instead, extract the OrderedDict from the current sequential, replace the Conv inside that dict, then replace the entire sequential to keep the order.
            emods = mod._modules
            emods[modpath[-1]] = new_conv
            delattr(pmod, modpath[-2])
            pmod.add_module(modpath[-2], nn.Sequential(emods))
        else:
            delattr(mod, modpath[-1])
            mod.add_module(modpath[-1], new_conv)


def convert_state_dict_to_switched_conv(sd_file, switch_breadth, allow_list):
    save = torch.load(sd_file)
    sd = save['state_dict']
    converted = 0
    for cname in allow_list:
        for sn in sd.keys():
            if cname in sn and sn.endswith('weight'):
                sd[sn] = sd[sn].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
                converted += 1
    print(f"Converted {converted} parameters.")
    torch.save(save, sd_file.replace('.pt', "_converted.pt"))


def test_net():
    for j in tqdm(range(100)):
        base_conv = Conv2d(32, 64, 3, stride=2, padding=1, bias=True).to('cuda')
        mod_conv = SwitchedConvHardRouting(32, 64, 3, breadth=8, stride=2, bias=True, include_coupler=True, coupler_dim_in=32, dropout_rate=.2).to('cuda')
        mod_sd = convert_conv_net_state_dict_to_switched_conv(base_conv, 8)
        mod_conv.load_state_dict(mod_sd, strict=False)
        inp = torch.randn((128, 32, 128, 128), device='cuda')
        out1 = base_conv(inp)
        out2 = mod_conv(inp, None)
        compare = (out2+torch.rand_like(out2)*1e-6).detach()
        MSELoss()(out2, compare).backward()
        assert(torch.max(torch.abs(out1-out2)) < 1e-5)


if __name__ == '__main__':
    test_net()
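A minimal sketch of how the SwitchNorm accumulator described in the docstring above behaves, using hypothetical group and accumulator sizes:

import torch

norm = SwitchNorm(group_size=8, accumulator_size=4)
norm.train()
for _ in range(5):  # after 4 steps the accumulator is full and re-weighting kicks in
    routing = torch.softmax(torch.randn(2, 8, 16, 16), dim=1)
    balanced = norm(routing)  # same shape; over-used switches are scaled down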
@@ -4,7 +4,7 @@ import torch
import torch.nn.functional as F

from data.util import is_wav_file, find_files_of_type
from models.audio_resnet import resnet50
from models.audio.audio_resnet import resnet50
from models.audio.tts.tacotron2.taco_utils import load_wav_to_torch
from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict
@@ -1,6 +1,5 @@
import torch

from models.spinenet_arch import SpineNet

def extract_byol_model_from_state_dict(sd):
    wrap_key = 'online_encoder.net.'
@@ -1,5 +1,4 @@
import os
import shutil

import torch
import torch.nn as nn
@@ -11,14 +10,12 @@ from torchvision.transforms import ToTensor, Resize
from tqdm import tqdm
import numpy as np

import utils
from data.image_folder_dataset import ImageFolderDataset
from models.spinenet_arch import SpineNet
from models.image_latents.spinenet_arch import SpineNet


# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
# and the distance is computed across the channel dimension.
from utils import util
from utils.options import dict_to_nonedict
@@ -13,7 +13,7 @@ import torch
import numpy as np
from torchvision import utils

from models.stylegan.stylegan2_rosinality import Generator, Discriminator
from models.image_generation.stylegan.stylegan2_rosinality import Generator, Discriminator


# Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter.
@@ -237,7 +237,6 @@ if __name__ == "__main__":
    args = parser.parse_args()
    sys.path.append('scripts\\stylegan2')

    import dnnlib
    from dnnlib.tflib.network import generator, discriminator, gen_ema

    with open(args.path, "rb") as f:
@@ -1,6 +1,6 @@
from torch.cuda.amp import autocast

from models.stylegan.stylegan2_lucidrains import gradient_penalty
from models.image_generation.stylegan.stylegan2_lucidrains import gradient_penalty
from trainer.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
from models.flownet2.networks import Resample2d
from trainer.inject import Injector
@@ -6,7 +6,7 @@ import trainer.eval.evaluator as evaluator

# Evaluate how close to true Gaussian a flow network predicts in a "normal" pass given a LQ/HQ image pair.
from data.image_folder_dataset import ImageFolderDataset
from models.srflow.flow import GaussianDiag
from models.image_generation.srflow import GaussianDiag


class FlowGaussianNll(evaluator.Evaluator):
@@ -16,13 +16,13 @@ def create_loss(opt_loss, env):
        from trainer.custom_training_components import create_teco_loss
        return create_teco_loss(opt_loss, env)
    elif 'stylegan2_' in type:
        from models.stylegan import create_stylegan2_loss
        from models.image_generation.stylegan import create_stylegan2_loss
        return create_stylegan2_loss(opt_loss, env)
    elif 'style_sr_' in type:
        from models.styled_sr import create_stylesr_loss
        return create_stylesr_loss(opt_loss, env)
    elif 'lightweight_gan_divergence' == type:
        from models.lightweight_gan import LightweightGanDivergenceLoss
        from models.image_generation.lightweight_gan import LightweightGanDivergenceLoss
        return LightweightGanDivergenceLoss(opt_loss, env)
    elif type == 'crossentropy' or type == 'cross_entropy':
        return CrossEntropy(opt_loss, env)
@@ -401,7 +401,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):

        if self.gradient_penalty:
            # Apply gradient penalty. TODO: migrate this elsewhere.
            from models.stylegan.stylegan2_lucidrains import gradient_penalty
            from models.image_generation.stylegan.stylegan2_lucidrains import gradient_penalty
            assert len(real) == 1  # Grad penalty doesn't currently support multi-input discriminators.
            gp, gp_structure = gradient_penalty(real[0], d_real, return_structured_grads=True)
            self.metrics.append(("gradient_penalty", gp.clone().detach()))
@@ -5,8 +5,6 @@ import pkgutil
import sys
from collections import OrderedDict
from inspect import isfunction, getmembers, signature
import torch
import models.feature_arch as feature_arch

logger = logging.getLogger('base')
@@ -1,7 +1,6 @@
import torch
from torch import nn
import models.SwitchedResidualGenerator_arch as srg
import models.discriminator_vgg_arch as disc
import models.image_generation.discriminator_vgg_arch as disc
import functools

blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]