From d186414566d52c4476ff921d6ceee7a46254dfce Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 16 Mar 2022 12:04:00 -0600 Subject: [PATCH] More spring cleaning --- codes/data/stylegan2_dataset.py | 2 +- codes/models/arch_util.py | 14 + codes/models/{ => audio}/audio_resnet.py | 0 codes/models/audio/tts/tacotron2/__init__.py | 3 +- codes/models/feature_arch.py | 96 ----- .../{ => image_generation}/RRDBNet_arch.py | 0 .../{ => image_generation}/ResGen_arch.py | 0 .../{byol => image_generation}/__init__.py | 0 .../discriminator_vgg_arch.py | 0 .../glean}/__init__.py | 0 .../{ => image_generation}/glean/glean.py | 6 +- .../glean/stylegan2_latent_bank.py | 2 +- .../{ => image_generation}/lightweight_gan.py | 2 +- .../srflow/FlowActNorms.py | 2 +- .../srflow/FlowAffineCouplingsAblation.py | 4 +- .../{ => image_generation}/srflow/FlowStep.py | 14 +- .../srflow/FlowUpsamplerNet.py | 15 +- .../srflow/Permutations.py | 2 +- .../srflow/RRDBNet_arch.py | 2 +- .../srflow/SRFlowNet_arch.py | 8 +- .../{ => image_generation}/srflow/Split.py | 4 +- .../srflow}/__init__.py | 0 .../{ => image_generation}/srflow/flow.py | 3 +- .../srflow/glow_arch.py | 0 .../srflow/module_util.py | 0 .../{ => image_generation}/srflow/thops.py | 0 .../stylegan/Discriminator_StyleGAN.py | 0 .../stylegan/__init__.py | 4 +- .../stylegan/stylegan2_lucidrains.py | 0 .../stylegan/stylegan2_rosinality.py | 0 .../{segformer => image_latents}/__init__.py | 0 .../byol}/__init__.py | 0 .../byol/byol_model_wrapper.py | 0 .../byol/byol_structural.py | 2 +- .../fixup_resnet/DiscriminatorResnet_arch.py | 0 .../fixup_resnet}/__init__.py | 0 .../{ => image_latents}/spinenet_arch.py | 0 codes/models/segformer/backbone.py | 123 ------ codes/models/segformer/segformer.py | 131 ------- .../switched_conv/mixture_of_experts.py | 128 ------- codes/models/switched_conv/switched_conv.py | 135 ------- .../switched_conv_hard_routing.py | 360 ------------------ codes/scripts/audio/test_audio_similarity.py | 2 +- .../byol/byol_extract_wrapped_model.py | 1 - .../scripts/byol/byol_spinenet_playground.py | 5 +- .../stylegan2/convert_weights_rosinality.py | 3 +- .../tecogan_losses.py | 2 +- codes/trainer/eval/flow_gaussian_nll.py | 2 +- codes/trainer/losses.py | 6 +- codes/trainer/networks.py | 2 - codes/utils/numeric_stability.py | 3 +- 51 files changed, 60 insertions(+), 1028 deletions(-) rename codes/models/{ => audio}/audio_resnet.py (100%) delete mode 100644 codes/models/feature_arch.py rename codes/models/{ => image_generation}/RRDBNet_arch.py (100%) rename codes/models/{ => image_generation}/ResGen_arch.py (100%) rename codes/models/{byol => image_generation}/__init__.py (100%) rename codes/models/{ => image_generation}/discriminator_vgg_arch.py (100%) rename codes/models/{fixup_resnet => image_generation/glean}/__init__.py (100%) rename codes/models/{ => image_generation}/glean/glean.py (96%) rename codes/models/{ => image_generation}/glean/stylegan2_latent_bank.py (97%) rename codes/models/{ => image_generation}/lightweight_gan.py (99%) rename codes/models/{ => image_generation}/srflow/FlowActNorms.py (98%) rename codes/models/{ => image_generation}/srflow/FlowAffineCouplingsAblation.py (97%) rename codes/models/{ => image_generation}/srflow/FlowStep.py (89%) rename codes/models/{ => image_generation}/srflow/FlowUpsamplerNet.py (95%) rename codes/models/{ => image_generation}/srflow/Permutations.py (96%) rename codes/models/{ => image_generation}/srflow/RRDBNet_arch.py (99%) rename codes/models/{ => image_generation}/srflow/SRFlowNet_arch.py 
(96%) rename codes/models/{ => image_generation}/srflow/Split.py (95%) rename codes/models/{glean => image_generation/srflow}/__init__.py (100%) rename codes/models/{ => image_generation}/srflow/flow.py (98%) rename codes/models/{ => image_generation}/srflow/glow_arch.py (100%) rename codes/models/{ => image_generation}/srflow/module_util.py (100%) rename codes/models/{ => image_generation}/srflow/thops.py (100%) rename codes/models/{ => image_generation}/stylegan/Discriminator_StyleGAN.py (100%) rename codes/models/{ => image_generation}/stylegan/__init__.py (66%) rename codes/models/{ => image_generation}/stylegan/stylegan2_lucidrains.py (100%) rename codes/models/{ => image_generation}/stylegan/stylegan2_rosinality.py (100%) rename codes/models/{segformer => image_latents}/__init__.py (100%) rename codes/models/{srflow => image_latents/byol}/__init__.py (100%) rename codes/models/{ => image_latents}/byol/byol_model_wrapper.py (100%) rename codes/models/{ => image_latents}/byol/byol_structural.py (98%) rename codes/models/{ => image_latents}/fixup_resnet/DiscriminatorResnet_arch.py (100%) rename codes/models/{switched_conv => image_latents/fixup_resnet}/__init__.py (100%) rename codes/models/{ => image_latents}/spinenet_arch.py (100%) delete mode 100644 codes/models/segformer/backbone.py delete mode 100644 codes/models/segformer/segformer.py delete mode 100644 codes/models/switched_conv/mixture_of_experts.py delete mode 100644 codes/models/switched_conv/switched_conv.py delete mode 100644 codes/models/switched_conv/switched_conv_hard_routing.py diff --git a/codes/data/stylegan2_dataset.py b/codes/data/stylegan2_dataset.py index 52dde7e2..7f4946a9 100644 --- a/codes/data/stylegan2_dataset.py +++ b/codes/data/stylegan2_dataset.py @@ -9,7 +9,7 @@ from torchvision import transforms import torch.nn as nn from pathlib import Path -import models.stylegan.stylegan2_lucidrains as sg2 +import models.image_generation.stylegan.stylegan2_lucidrains as sg2 def convert_transparent_to_rgb(image): diff --git a/codes/models/arch_util.py b/codes/models/arch_util.py index 159a012c..425098aa 100644 --- a/codes/models/arch_util.py +++ b/codes/models/arch_util.py @@ -1016,3 +1016,17 @@ class FinalUpsampleBlock2x(nn.Module): def forward(self, x): return self.chain(x) + +# torch.gather() which operates as it always fucking should have: pulling indexes from the input. 
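+# A usage sketch, with shapes inferred from the body below (the example values
+# are illustrative, not from this repo): `input` is (b, c, h, w) and `index` is
+# (b, 2) holding (row, col) coordinates; the result is the (b, c) feature
+# vector sampled at each coordinate.
+#   feats = torch.randn(4, 16, 32, 32)
+#   coords = torch.randint(0, 32, (4, 2))
+#   vecs = gather_2d(feats, coords)  # -> (4, 16)
+# The b == 1 guard below exists because squeeze() would otherwise collapse the
+# batch dim; note that a singleton channel dim (c == 1) would be squeezed away
+# the same way.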
+def gather_2d(input, index): + b, c, h, w = input.shape + nodim = input.view(b, c, h * w) + ind_nd = index[:, 0]*w + index[:, 1] + ind_nd = ind_nd.unsqueeze(1) + ind_nd = ind_nd.repeat((1, c)) + ind_nd = ind_nd.unsqueeze(2) + result = torch.gather(nodim, dim=2, index=ind_nd) + result = result.squeeze() + if b == 1: + result = result.unsqueeze(0) + return result \ No newline at end of file diff --git a/codes/models/audio_resnet.py b/codes/models/audio/audio_resnet.py similarity index 100% rename from codes/models/audio_resnet.py rename to codes/models/audio/audio_resnet.py diff --git a/codes/models/audio/tts/tacotron2/__init__.py b/codes/models/audio/tts/tacotron2/__init__.py index feeb8381..feca08a9 100644 --- a/codes/models/audio/tts/tacotron2/__init__.py +++ b/codes/models/audio/tts/tacotron2/__init__.py @@ -2,4 +2,5 @@ from models.audio.tts.tacotron2.taco_utils import * from models.audio.tts.tacotron2.text import * from models.audio.tts.tacotron2.tacotron2 import * from models.audio.tts.tacotron2.stft import * -from models.audio.tts.tacotron2.layers import * \ No newline at end of file +from models.audio.tts.tacotron2.layers import * +from models.audio.tts.tacotron2.loss import * \ No newline at end of file diff --git a/codes/models/feature_arch.py b/codes/models/feature_arch.py deleted file mode 100644 index bb2b4371..00000000 --- a/codes/models/feature_arch.py +++ /dev/null @@ -1,96 +0,0 @@ -import torch -import torch.nn as nn -import torchvision -import torch.nn.functional as F - -# Utilizes pretrained torchvision modules for feature extraction - -class VGGFeatureExtractor(nn.Module): - def __init__(self, feature_layers=[34], use_bn=False, use_input_norm=True, - device=torch.device('cpu')): - super(VGGFeatureExtractor, self).__init__() - self.use_input_norm = use_input_norm - if use_bn: - model = torchvision.models.vgg19_bn(pretrained=True) - else: - model = torchvision.models.vgg19(pretrained=True) - if self.use_input_norm: - mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) - # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] - std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) - # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] - self.register_buffer('mean', mean) - self.register_buffer('std', std) - self.feature_layers = feature_layers - self.features = nn.Sequential(*list(model.features.children())[:(max(feature_layers) + 1)]) - # No need to BP to variable - for k, v in self.features.named_parameters(): - v.requires_grad = False - - def forward(self, x, interpolate_factor=1): - if interpolate_factor > 1: - x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic') - - if self.use_input_norm: - x = (x - self.mean) / self.std - output = self.features(x) - return output - - -class TrainableVGGFeatureExtractor(nn.Module): - def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, - device=torch.device('cpu')): - super(TrainableVGGFeatureExtractor, self).__init__() - self.use_input_norm = use_input_norm - if use_bn: - model = torchvision.models.vgg19_bn(pretrained=False) - else: - model = torchvision.models.vgg19(pretrained=False) - if self.use_input_norm: - mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) - # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] - std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) - # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] - self.register_buffer('mean', mean) - self.register_buffer('std', std) 
- self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) - - def forward(self, x, interpolate_factor=1): - if interpolate_factor > 1: - x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic') - # Assume input range is [0, 1] - if self.use_input_norm: - x = (x - self.mean) / self.std - output = self.features(x) - return output - - -class WideResnetFeatureExtractor(nn.Module): - def __init__(self, use_input_norm=True, device=torch.device('cpu')): - print("Using wide resnet extractor.") - super(WideResnetFeatureExtractor, self).__init__() - self.use_input_norm = use_input_norm - self.model = torchvision.models.wide_resnet50_2(pretrained=True) - if self.use_input_norm: - mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) - # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] - std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) - # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] - self.register_buffer('mean', mean) - self.register_buffer('std', std) - # No need to BP to variable - for p in self.model.parameters(): - p.requires_grad = False - - def forward(self, x): - # Assume input range is [0, 1] - if self.use_input_norm: - x = (x - self.mean) / self.std - x = self.model.conv1(x) - x = self.model.bn1(x) - x = self.model.relu(x) - x = self.model.maxpool(x) - x = self.model.layer1(x) - x = self.model.layer2(x) - x = self.model.layer3(x) - return x \ No newline at end of file diff --git a/codes/models/RRDBNet_arch.py b/codes/models/image_generation/RRDBNet_arch.py similarity index 100% rename from codes/models/RRDBNet_arch.py rename to codes/models/image_generation/RRDBNet_arch.py diff --git a/codes/models/ResGen_arch.py b/codes/models/image_generation/ResGen_arch.py similarity index 100% rename from codes/models/ResGen_arch.py rename to codes/models/image_generation/ResGen_arch.py diff --git a/codes/models/byol/__init__.py b/codes/models/image_generation/__init__.py similarity index 100% rename from codes/models/byol/__init__.py rename to codes/models/image_generation/__init__.py diff --git a/codes/models/discriminator_vgg_arch.py b/codes/models/image_generation/discriminator_vgg_arch.py similarity index 100% rename from codes/models/discriminator_vgg_arch.py rename to codes/models/image_generation/discriminator_vgg_arch.py diff --git a/codes/models/fixup_resnet/__init__.py b/codes/models/image_generation/glean/__init__.py similarity index 100% rename from codes/models/fixup_resnet/__init__.py rename to codes/models/image_generation/glean/__init__.py diff --git a/codes/models/glean/glean.py b/codes/models/image_generation/glean/glean.py similarity index 96% rename from codes/models/glean/glean.py rename to codes/models/image_generation/glean/glean.py index 497f0a7a..89ed0b16 100644 --- a/codes/models/glean/glean.py +++ b/codes/models/image_generation/glean/glean.py @@ -3,13 +3,13 @@ import math import torch.nn as nn import torch -from models.RRDBNet_arch import RRDB +from models.image_generation.RRDBNet_arch import RRDB from models.arch_util import ConvGnLelu # Produces a convolutional feature (`f`) and a reduced feature map with double the filters. 
-from models.glean.stylegan2_latent_bank import Stylegan2LatentBank -from models.stylegan.stylegan2_rosinality import EqualLinear +from models.image_generation.glean.stylegan2_latent_bank import Stylegan2LatentBank +from models.image_generation.stylegan.stylegan2_rosinality import EqualLinear from trainer.networks import register_model from utils.util import checkpoint, sequential_checkpoint diff --git a/codes/models/glean/stylegan2_latent_bank.py b/codes/models/image_generation/glean/stylegan2_latent_bank.py similarity index 97% rename from codes/models/glean/stylegan2_latent_bank.py rename to codes/models/image_generation/glean/stylegan2_latent_bank.py index f7a191ae..e51d48a7 100644 --- a/codes/models/glean/stylegan2_latent_bank.py +++ b/codes/models/image_generation/glean/stylegan2_latent_bank.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn from models.arch_util import ConvGnLelu -from models.stylegan.stylegan2_rosinality import Generator +from models.image_generation.stylegan.stylegan2_rosinality import Generator class Stylegan2LatentBank(nn.Module): diff --git a/codes/models/lightweight_gan.py b/codes/models/image_generation/lightweight_gan.py similarity index 99% rename from codes/models/lightweight_gan.py rename to codes/models/image_generation/lightweight_gan.py index 8a069fc8..fc405086 100644 --- a/codes/models/lightweight_gan.py +++ b/codes/models/image_generation/lightweight_gan.py @@ -20,7 +20,7 @@ from torch import nn, einsum from torch.utils.data import Dataset from torchvision import transforms -from models.stylegan.stylegan2_lucidrains import gradient_penalty +from models.image_generation.stylegan.stylegan2_lucidrains import gradient_penalty from trainer.networks import register_model from utils.util import opt_get diff --git a/codes/models/srflow/FlowActNorms.py b/codes/models/image_generation/srflow/FlowActNorms.py similarity index 98% rename from codes/models/srflow/FlowActNorms.py rename to codes/models/image_generation/srflow/FlowActNorms.py index b37c3485..6ca7b498 100644 --- a/codes/models/srflow/FlowActNorms.py +++ b/codes/models/image_generation/srflow/FlowActNorms.py @@ -1,7 +1,7 @@ import torch from torch import nn as nn -from models.srflow import thops +from models.image_generation.srflow import thops class _ActNorm(nn.Module): diff --git a/codes/models/srflow/FlowAffineCouplingsAblation.py b/codes/models/image_generation/srflow/FlowAffineCouplingsAblation.py similarity index 97% rename from codes/models/srflow/FlowAffineCouplingsAblation.py rename to codes/models/image_generation/srflow/FlowAffineCouplingsAblation.py index 62e8d41e..f8d85d9c 100644 --- a/codes/models/srflow/FlowAffineCouplingsAblation.py +++ b/codes/models/image_generation/srflow/FlowAffineCouplingsAblation.py @@ -1,8 +1,8 @@ import torch from torch import nn as nn -from models.srflow import thops -from models.srflow.flow import Conv2d, Conv2dZeros +from models.image_generation.srflow import thops +from models.image_generation.srflow.flow import Conv2d, Conv2dZeros from utils.util import opt_get diff --git a/codes/models/srflow/FlowStep.py b/codes/models/image_generation/srflow/FlowStep.py similarity index 89% rename from codes/models/srflow/FlowStep.py rename to codes/models/image_generation/srflow/FlowStep.py index b62d4b55..7d3f0724 100644 --- a/codes/models/srflow/FlowStep.py +++ b/codes/models/image_generation/srflow/FlowStep.py @@ -1,9 +1,9 @@ import torch from torch import nn as nn -import models.srflow.Permutations -import models.srflow.FlowAffineCouplingsAblation -import 
models.srflow.FlowActNorms +import models.image_generation.srflow.Permutations +import models.image_generation.srflow.FlowAffineCouplingsAblation +import models.image_generation.srflow.FlowActNorms def getConditional(rrdbResults, position): @@ -46,17 +46,17 @@ class FlowStep(nn.Module): self.acOpt = acOpt # 1. actnorm - self.actnorm = models.srflow.FlowActNorms.ActNorm2d(in_channels, actnorm_scale) + self.actnorm = models.image_generation.srflow.FlowActNorms.ActNorm2d(in_channels, actnorm_scale) # 2. permute if flow_permutation == "invconv": - self.invconv = models.srflow.Permutations.InvertibleConv1x1( + self.invconv = models.image_generation.srflow.Permutations.InvertibleConv1x1( in_channels, LU_decomposed=LU_decomposed) # 3. coupling if flow_coupling == "CondAffineSeparatedAndCond": - self.affine = models.srflow.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels, - opt=opt) + self.affine = models.image_generation.srflow.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels, + opt=opt) elif flow_coupling == "noCoupling": pass else: diff --git a/codes/models/srflow/FlowUpsamplerNet.py b/codes/models/image_generation/srflow/FlowUpsamplerNet.py similarity index 95% rename from codes/models/srflow/FlowUpsamplerNet.py rename to codes/models/image_generation/srflow/FlowUpsamplerNet.py index 387ebbef..025ddd26 100644 --- a/codes/models/srflow/FlowUpsamplerNet.py +++ b/codes/models/image_generation/srflow/FlowUpsamplerNet.py @@ -2,12 +2,11 @@ import numpy as np import torch from torch import nn as nn -import models.srflow.Split -from models.srflow import flow -from models.srflow import thops -from models.srflow.Split import Split2d -from models.srflow.glow_arch import f_conv2d_bias -from models.srflow.FlowStep import FlowStep +import models.image_generation.srflow.Split +from models.image_generation.srflow import flow +from models.image_generation.srflow.Split import Split2d +from models.image_generation.srflow.glow_arch import f_conv2d_bias +from models.image_generation.srflow.FlowStep import FlowStep from utils.util import opt_get, checkpoint @@ -146,8 +145,8 @@ class FlowUpsamplerNet(nn.Module): t = opt_get(opt, ['networks', 'generator','flow', 'split', 'type'], 'Split2d') if t == 'Split2d': - split = models.srflow.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position, - cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt) + split = models.image_generation.srflow.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position, + cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt) self.layers.append(split) self.output_shapes.append([-1, split.num_channels_pass, H, W]) self.C = split.num_channels_pass diff --git a/codes/models/srflow/Permutations.py b/codes/models/image_generation/srflow/Permutations.py similarity index 96% rename from codes/models/srflow/Permutations.py rename to codes/models/image_generation/srflow/Permutations.py index ba8d85f4..122ab3a9 100644 --- a/codes/models/srflow/Permutations.py +++ b/codes/models/image_generation/srflow/Permutations.py @@ -3,7 +3,7 @@ import torch from torch import nn as nn from torch.nn import functional as F -from models.srflow import thops +from models.image_generation.srflow import thops class InvertibleConv1x1(nn.Module): diff --git a/codes/models/srflow/RRDBNet_arch.py b/codes/models/image_generation/srflow/RRDBNet_arch.py similarity index 99% rename from codes/models/srflow/RRDBNet_arch.py rename to 
codes/models/image_generation/srflow/RRDBNet_arch.py index e566a0f2..d34d7160 100644 --- a/codes/models/srflow/RRDBNet_arch.py +++ b/codes/models/image_generation/srflow/RRDBNet_arch.py @@ -2,7 +2,7 @@ import functools import torch import torch.nn as nn import torch.nn.functional as F -import models.srflow.module_util as mutil +import models.image_generation.srflow.module_util as mutil from models.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu from trainer.networks import register_model from utils.util import opt_get diff --git a/codes/models/srflow/SRFlowNet_arch.py b/codes/models/image_generation/srflow/SRFlowNet_arch.py similarity index 96% rename from codes/models/srflow/SRFlowNet_arch.py rename to codes/models/image_generation/srflow/SRFlowNet_arch.py index 3002dbdc..82cc5432 100644 --- a/codes/models/srflow/SRFlowNet_arch.py +++ b/codes/models/image_generation/srflow/SRFlowNet_arch.py @@ -4,10 +4,10 @@ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np -from models.srflow.RRDBNet_arch import RRDBNet -from models.srflow.FlowUpsamplerNet import FlowUpsamplerNet -import models.srflow.thops as thops -import models.srflow.flow as flow +from models.image_generation.srflow.RRDBNet_arch import RRDBNet +from models.image_generation.srflow.FlowUpsamplerNet import FlowUpsamplerNet +import models.image_generation.srflow.thops as thops +import models.image_generation.srflow.flow as flow from trainer.networks import register_model from utils.util import opt_get diff --git a/codes/models/srflow/Split.py b/codes/models/image_generation/srflow/Split.py similarity index 95% rename from codes/models/srflow/Split.py rename to codes/models/image_generation/srflow/Split.py index c3a1ffd5..304c0e6c 100644 --- a/codes/models/srflow/Split.py +++ b/codes/models/image_generation/srflow/Split.py @@ -1,8 +1,8 @@ import torch from torch import nn as nn -from models.srflow import thops -from models.srflow.flow import Conv2dZeros, GaussianDiag +from models.image_generation.srflow import thops +from models.image_generation.srflow.flow import Conv2dZeros, GaussianDiag from utils.util import opt_get diff --git a/codes/models/glean/__init__.py b/codes/models/image_generation/srflow/__init__.py similarity index 100% rename from codes/models/glean/__init__.py rename to codes/models/image_generation/srflow/__init__.py diff --git a/codes/models/srflow/flow.py b/codes/models/image_generation/srflow/flow.py similarity index 98% rename from codes/models/srflow/flow.py rename to codes/models/image_generation/srflow/flow.py index d5c06277..db9ad7c5 100644 --- a/codes/models/srflow/flow.py +++ b/codes/models/image_generation/srflow/flow.py @@ -1,9 +1,8 @@ import torch import torch.nn as nn -import torch.nn.functional as F import numpy as np -from models.srflow.FlowActNorms import ActNorm2d +from models.image_generation.srflow.FlowActNorms import ActNorm2d from . 
import thops diff --git a/codes/models/srflow/glow_arch.py b/codes/models/image_generation/srflow/glow_arch.py similarity index 100% rename from codes/models/srflow/glow_arch.py rename to codes/models/image_generation/srflow/glow_arch.py diff --git a/codes/models/srflow/module_util.py b/codes/models/image_generation/srflow/module_util.py similarity index 100% rename from codes/models/srflow/module_util.py rename to codes/models/image_generation/srflow/module_util.py diff --git a/codes/models/srflow/thops.py b/codes/models/image_generation/srflow/thops.py similarity index 100% rename from codes/models/srflow/thops.py rename to codes/models/image_generation/srflow/thops.py diff --git a/codes/models/stylegan/Discriminator_StyleGAN.py b/codes/models/image_generation/stylegan/Discriminator_StyleGAN.py similarity index 100% rename from codes/models/stylegan/Discriminator_StyleGAN.py rename to codes/models/image_generation/stylegan/Discriminator_StyleGAN.py diff --git a/codes/models/stylegan/__init__.py b/codes/models/image_generation/stylegan/__init__.py similarity index 66% rename from codes/models/stylegan/__init__.py rename to codes/models/image_generation/stylegan/__init__.py index e876d0c6..e219dc28 100644 --- a/codes/models/stylegan/__init__.py +++ b/codes/models/image_generation/stylegan/__init__.py @@ -2,10 +2,10 @@ def create_stylegan2_loss(opt_loss, env): type = opt_loss['type'] if type == 'stylegan2_divergence': - import models.stylegan.stylegan2_lucidrains as stylegan2 + import models.image_generation.stylegan.stylegan2_lucidrains as stylegan2 return stylegan2.StyleGan2DivergenceLoss(opt_loss, env) elif type == 'stylegan2_pathlen': - import models.stylegan.stylegan2_lucidrains as stylegan2 + import models.image_generation.stylegan.stylegan2_lucidrains as stylegan2 return stylegan2.StyleGan2PathLengthLoss(opt_loss, env) else: raise NotImplementedError \ No newline at end of file diff --git a/codes/models/stylegan/stylegan2_lucidrains.py b/codes/models/image_generation/stylegan/stylegan2_lucidrains.py similarity index 100% rename from codes/models/stylegan/stylegan2_lucidrains.py rename to codes/models/image_generation/stylegan/stylegan2_lucidrains.py diff --git a/codes/models/stylegan/stylegan2_rosinality.py b/codes/models/image_generation/stylegan/stylegan2_rosinality.py similarity index 100% rename from codes/models/stylegan/stylegan2_rosinality.py rename to codes/models/image_generation/stylegan/stylegan2_rosinality.py diff --git a/codes/models/segformer/__init__.py b/codes/models/image_latents/__init__.py similarity index 100% rename from codes/models/segformer/__init__.py rename to codes/models/image_latents/__init__.py diff --git a/codes/models/srflow/__init__.py b/codes/models/image_latents/byol/__init__.py similarity index 100% rename from codes/models/srflow/__init__.py rename to codes/models/image_latents/byol/__init__.py diff --git a/codes/models/byol/byol_model_wrapper.py b/codes/models/image_latents/byol/byol_model_wrapper.py similarity index 100% rename from codes/models/byol/byol_model_wrapper.py rename to codes/models/image_latents/byol/byol_model_wrapper.py diff --git a/codes/models/byol/byol_structural.py b/codes/models/image_latents/byol/byol_structural.py similarity index 98% rename from codes/models/byol/byol_structural.py rename to codes/models/image_latents/byol/byol_structural.py index 3aeb2602..778d8eeb 100644 --- a/codes/models/byol/byol_structural.py +++ b/codes/models/image_latents/byol/byol_structural.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from 
torch import nn from data.byol_attachment import reconstructed_shared_regions -from models.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \ +from models.image_latents.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \ update_moving_average from trainer.networks import create_model, register_model from utils.util import checkpoint diff --git a/codes/models/fixup_resnet/DiscriminatorResnet_arch.py b/codes/models/image_latents/fixup_resnet/DiscriminatorResnet_arch.py similarity index 100% rename from codes/models/fixup_resnet/DiscriminatorResnet_arch.py rename to codes/models/image_latents/fixup_resnet/DiscriminatorResnet_arch.py diff --git a/codes/models/switched_conv/__init__.py b/codes/models/image_latents/fixup_resnet/__init__.py similarity index 100% rename from codes/models/switched_conv/__init__.py rename to codes/models/image_latents/fixup_resnet/__init__.py diff --git a/codes/models/spinenet_arch.py b/codes/models/image_latents/spinenet_arch.py similarity index 100% rename from codes/models/spinenet_arch.py rename to codes/models/image_latents/spinenet_arch.py diff --git a/codes/models/segformer/backbone.py b/codes/models/segformer/backbone.py deleted file mode 100644 index a8c4e55e..00000000 --- a/codes/models/segformer/backbone.py +++ /dev/null @@ -1,123 +0,0 @@ -# A direct copy of torchvision's resnet.py modified to support gradient checkpointing. - -import torch -import torch.nn as nn -from torchvision.models.resnet import BasicBlock, Bottleneck -import torchvision - - -__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', - 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', - 'wide_resnet50_2', 'wide_resnet101_2'] - -from trainer.networks import register_model -from utils.util import checkpoint - -model_urls = { - 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', - 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', - 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', - 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', - 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', - 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', - 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', - 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', - 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', -} - - -class Backbone(torchvision.models.resnet.ResNet): - - def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, - groups=1, width_per_group=64, replace_stride_with_dilation=None, - norm_layer=None): - super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group, - replace_stride_with_dilation, norm_layer) - del self.fc - del self.avgpool - - def _forward_impl(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - l1 = checkpoint(self.layer1, x) - l2 = checkpoint(self.layer2, l1) - l3 = checkpoint(self.layer3, l2) - l4 = checkpoint(self.layer4, l3) - - return l1, l2, l3, l4 - - def forward(self, x): - return self._forward_impl(x) - - -def _backbone(arch, block, layers, pretrained, progress, **kwargs): - model = Backbone(block, layers, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], - progress=progress) - 
model.load_state_dict(state_dict) - return model - - -def backbone18(pretrained=False, progress=True, **kwargs): - r"""ResNet-18 model from - `"Deep Residual Learning for Image Recognition" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _backbone('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, - **kwargs) - - -def backbone34(pretrained=False, progress=True, **kwargs): - r"""ResNet-34 model from - `"Deep Residual Learning for Image Recognition" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _backbone('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, - **kwargs) - - -def backbone50(pretrained=False, progress=True, **kwargs): - r"""ResNet-50 model from - `"Deep Residual Learning for Image Recognition" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _backbone('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, - **kwargs) - - -def backbone101(pretrained=False, progress=True, **kwargs): - r"""ResNet-101 model from - `"Deep Residual Learning for Image Recognition" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _backbone('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, - **kwargs) - - -def backbone152(pretrained=False, progress=True, **kwargs): - r"""ResNet-152 model from - `"Deep Residual Learning for Image Recognition" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _backbone('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, - **kwargs) - diff --git a/codes/models/segformer/segformer.py b/codes/models/segformer/segformer.py deleted file mode 100644 index f9555405..00000000 --- a/codes/models/segformer/segformer.py +++ /dev/null @@ -1,131 +0,0 @@ -import math - -import torch -import torch.nn as nn -import torchvision -from tqdm import tqdm -from models.segformer.backbone import backbone50 -from trainer.networks import register_model - - -# torch.gather() which operates as it always fucking should have: pulling indexes from the input. 
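-# (This helper is not lost by the deletion: an identical copy of gather_2d is
-#  added to codes/models/arch_util.py earlier in this patch.)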
-def gather_2d(input, index): - b, c, h, w = input.shape - nodim = input.view(b, c, h * w) - ind_nd = index[:, 0]*w + index[:, 1] - ind_nd = ind_nd.unsqueeze(1) - ind_nd = ind_nd.repeat((1, c)) - ind_nd = ind_nd.unsqueeze(2) - result = torch.gather(nodim, dim=2, index=ind_nd) - result = result.squeeze() - if b == 1: - result = result.unsqueeze(0) - return result - - -class DilatorModule(nn.Module): - def __init__(self, input_channels, output_channels, max_dilation): - super().__init__() - self.max_dilation = max_dilation - self.conv1 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, dilation=1, bias=True) - if max_dilation > 1: - self.bn = nn.BatchNorm2d(input_channels) - self.relu = nn.ReLU() - self.conv2 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=max_dilation, dilation=max_dilation, bias=True) - self.dense = nn.Linear(input_channels, output_channels, bias=True) - - def forward(self, inp, loc): - x = self.conv1(inp) - if self.max_dilation > 1: - x = self.bn(self.relu(x)) - x = self.conv2(x) - - # This can be made more efficient by only computing these convolutions across a subset of the image. Possibly. - x = gather_2d(x, loc).contiguous() - return self.dense(x) - - -# Grabbed from torch examples: https://github.com/pytorch/examples/tree/master/https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65:7 -class PositionalEncoding(nn.Module): - def __init__(self, d_model, max_len=5000): - super(PositionalEncoding, self).__init__() - - pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0).transpose(0, 1) - self.register_buffer('pe', pe) - - def forward(self, x): - x = x + self.pe[:x.size(0), :] - return x - - -# Simple mean() layer encoded into a class so that BYOL can grab it. -class Tail(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x.mean(dim=0) - - -class Segformer(nn.Module): - def __init__(self, latent_channels=1024, layers=8): - super().__init__() - self.backbone = backbone50() - backbone_channels = [256, 512, 1024, 2048] - dilations = [[1,2,3,4],[1,2,3],[1,2],[1]] - final_latent_channels = latent_channels - dilators = [] - for ic, dis in zip(backbone_channels, dilations): - layer_dilators = [] - for di in dis: - layer_dilators.append(DilatorModule(ic, final_latent_channels, di)) - dilators.append(nn.ModuleList(layer_dilators)) - self.dilators = nn.ModuleList(dilators) - - self.token_position_encoder = PositionalEncoding(final_latent_channels, max_len=10) - self.transformer_layers = nn.Sequential(*[nn.TransformerEncoderLayer(final_latent_channels, nhead=4) for _ in range(layers)]) - self.tail = Tail() - - def forward(self, img=None, layers=None, pos=None, return_layers=False): - assert img is not None or layers is not None - if img is not None: - bs = img.shape[0] - layers = self.backbone(img) - else: - bs = layers[0].shape[0] - if return_layers: - return layers - - # A single position can be optionally given, in which case we need to expand it to represent the entire input. 
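-        # (A bare (2,) pos is broadcast to (bs, 2) so every batch element shares
-        #  the same query location; pos is then rescaled to each backbone level
-        #  by the integer divisions below.)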
- if pos.shape == (2,): - pos = pos.unsqueeze(0).repeat(bs, 1) - - set = [] - pos = pos // 4 - for layer_out, dilator in zip(layers, self.dilators): - for subdilator in dilator: - set.append(subdilator(layer_out, pos)) - pos = pos // 2 - - # The torch transformer expects the set dimension to be 0. - set = torch.stack(set, dim=0) - set = self.token_position_encoder(set) - set = self.transformer_layers(set) - return self.tail(set) - - -@register_model -def register_segformer(opt_net, opt): - return Segformer() - - -if __name__ == '__main__': - model = Segformer().to('cuda') - for j in tqdm(range(1000)): - test_tensor = torch.randn(64,3,224,224).cuda() - print(model(img=test_tensor, pos=torch.randint(0,224,(64,2)).cuda()).shape) \ No newline at end of file diff --git a/codes/models/switched_conv/mixture_of_experts.py b/codes/models/switched_conv/mixture_of_experts.py deleted file mode 100644 index bf888a4f..00000000 --- a/codes/models/switched_conv/mixture_of_experts.py +++ /dev/null @@ -1,128 +0,0 @@ -# Contains implementations from the Mixture of Experts paper and Switch Transformers - - -# Implements KeepTopK where k=1 from mixture of experts paper. -import torch -import torch.nn as nn - -from models.switched_conv.switched_conv_hard_routing import RouteTop1 -from trainer.losses import ConfigurableLoss - - -class KeepTop1(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2) - input[mask != 1] = -float('inf') - ctx.save_for_backward(mask) - return input - - @staticmethod - def backward(ctx, grad): - import pydevd - pydevd.settrace(suspend=False, trace_only_current_thread=True) - mask = ctx.saved_tensors - grad_input = grad.clone() - grad_input[mask != 1] = 0 - return grad_input - -class MixtureOfExperts2dRouter(nn.Module): - def __init__(self, num_experts): - super().__init__() - self.num_experts = num_experts - self.wnoise = nn.Parameter(torch.zeros(1, num_experts, 1, 1)) - self.wg = nn.Parameter(torch.zeros(1, num_experts, 1, 1)) - - def forward(self, x): - wg = x * self.wg - wnoise = nn.functional.softplus(x * self.wnoise) - H = wg + torch.randn_like(x) * wnoise - - # Produce the load-balancing loss. - eye = torch.eye(self.num_experts, device=x.device).view(1, self.num_experts, self.num_experts, 1, 1) - mask = torch.abs(1 - eye) - b, c, h, w = H.shape - ninf = torch.zeros_like(eye) - ninf[eye == 1] = -float('inf') - H_masked = H.view(b, c, 1, h, - w) * mask + ninf # ninf is necessary because otherwise torch.max() will not pick up negative numbered maxes. - max_excluding = torch.max(H_masked, dim=2)[0] - - # load_loss and G are stored as local members to facilitate their use by hard routing regularization losses. - # this is a risky op - it can easily result in memory leakage. Clients *must* use self.reset() below. - self.load_loss = torch.erf((wg - max_excluding) / wnoise) - # self.G = nn.functional.softmax(KeepTop1.apply(H), dim=1) The paper proposes this equation, but performing a softmax on a Top-1 per the paper results in zero gradients into H, so: - self.G = RouteTop1.apply(nn.functional.softmax(H, dim=1)) # This variant can route gradients downstream. 
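-        # (Applying softmax before RouteTop1 keeps the backward pass dense:
-        #  RouteTop1 zeroes gradients for unselected entries, while the softmax
-        #  Jacobian still carries signal back to every expert's logit in H.)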
- - return self.G - - # Retrieve the locally stored loss values and delete them from membership (so as to not waste memory) - def reset(self): - G, load = self.G, self.load_loss - del self.G - del self.load_loss - return G, load - - -# Loss that finds instances of MixtureOfExperts2dRouter in the given network and extracts their custom losses. -class MixtureOfExpertsLoss(ConfigurableLoss): - def __init__(self, opt, env): - super().__init__(opt, env) - self.routers = [] # This is filled in during the first forward() pass and cached from there. - self.first_forward_encountered = False - self.load_weight = opt['load_weight'] - self.importance_weight = opt['importance_weight'] - - def forward(self, net, state): - if not self.first_forward_encountered: - for m in net.modules(): - if isinstance(m, MixtureOfExperts2dRouter): - self.routers.append(m) - self.first_forward_encountered = True - - l_importance = 0 - l_load = 0 - for r in self.routers: - G, L = r.reset() - l_importance += G.var().square() - l_load += L.var().square() - return l_importance * self.importance_weight + l_load * self.load_weight - - -class SwitchTransformersLoadBalancer(nn.Module): - def __init__(self): - super().__init__() - self.norm = SwitchNorm(8, accumulator_size=256) - - def forward(self, x): - self.soft = self.norm(nn.functional.softmax(x, dim=1)) - self.hard = RouteTop1.apply(self.soft) # This variant can route gradients downstream. - return self.hard - - def reset(self): - soft, hard = self.soft, self.hard - del self.soft, self.hard - return soft, hard - - -class SwitchTransformersLoadBalancingLoss(ConfigurableLoss): - def __init__(self, opt, env): - super().__init__(opt, env) - self.routers = [] # This is filled in during the first forward() pass and cached from there. - self.first_forward_encountered = False - - def forward(self, net, state): - if not self.first_forward_encountered: - for m in net.modules(): - if isinstance(m, SwitchTransformersLoadBalancer): - self.routers.append(m) - self.first_forward_encountered = True - - loss = 0 - for r in self.routers: - soft, hard = r.reset() - N = hard.shape[1] - h_mean = hard.mean(dim=[0,2,3]) - s_mean = soft.mean(dim=[0,2,3]) - loss += torch.dot(h_mean, s_mean) * N - return loss \ No newline at end of file diff --git a/codes/models/switched_conv/switched_conv.py b/codes/models/switched_conv/switched_conv.py deleted file mode 100644 index 2e2677b7..00000000 --- a/codes/models/switched_conv/switched_conv.py +++ /dev/null @@ -1,135 +0,0 @@ -import functools -import math -from collections import OrderedDict - -import torch -import torch.nn as nn -from lambda_networks import LambdaLayer -from torch.nn import init, Conv2d -import torch.nn.functional as F - - -class SwitchedConv(nn.Module): - def __init__(self, - in_channels: int, - out_channels: int, - kernel_size: int, - switch_breadth: int, - stride: int = 1, - padding: int = 0, - dilation: int = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = 'zeros', - include_coupler: bool = False, # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate. 
- coupler_mode: str = 'standard', - coupler_dim_in: int = 0): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.padding_mode = padding_mode - self.groups = groups - - if include_coupler: - if coupler_mode == 'standard': - self.coupler = Conv2d(coupler_dim_in, switch_breadth, kernel_size=1) - elif coupler_mode == 'lambda': - self.coupler = LambdaLayer(dim=coupler_dim_in, dim_out=switch_breadth, r=23, dim_k=16, heads=2, dim_u=1) - - else: - self.coupler = None - - self.weights = nn.ParameterList([nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size)) for _ in range(switch_breadth)]) - if bias: - self.bias = nn.Parameter(torch.Tensor(out_channels)) - else: - self.register_parameter('bias', None) - self.reset_parameters() - - def reset_parameters(self) -> None: - for w in self.weights: - init.kaiming_uniform_(w, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = init._calculate_fan_in_and_fan_out(self.weights[0]) - bound = 1 / math.sqrt(fan_in) - init.uniform_(self.bias, -bound, bound) - - def forward(self, inp, selector=None): - if self.coupler: - if selector is None: # A coupler can convert from any input to a selector, so 'None' is allowed. - selector = inp - selector = F.softmax(self.coupler(selector), dim=1) - self.last_select = selector.detach().clone() - out_shape = [s // self.stride for s in inp.shape[2:]] - if selector.shape[2] != out_shape[0] or selector.shape[3] != out_shape[1]: - selector = F.interpolate(selector, size=out_shape, mode="nearest") - assert selector is not None - - conv_results = [] - for i, w in enumerate(self.weights): - conv_results.append(F.conv2d(inp, w, self.bias, self.stride, self.padding, self.dilation, self.groups) * selector[:, i].unsqueeze(1)) - return torch.stack(conv_results, dim=-1).sum(dim=-1) - - - -# Given a state_dict and the module that that sd belongs to, strips out all Conv2d.weight parameters and replaces them -# with the equivalent SwitchedConv.weight parameters. Does not create coupler params. -def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_list=[]): - state_dict = module.state_dict() - for name, m in module.named_modules(): - ignored = False - for smod in ignore_list: - if smod in name: - ignored = True - continue - if ignored: - continue - if isinstance(m, nn.Conv2d): - if name == '': - basename = 'weight' - modname = 'weights' - else: - basename = f'{name}.weight' - modname = f'{name}.weights' - cnv_weights = state_dict[basename] - del state_dict[basename] - for j in range(switch_breadth): - state_dict[f'{modname}.{j}'] = cnv_weights.clone() - return state_dict - - -def test_net(): - base_conv = Conv2d(32, 64, 3, stride=2, padding=1, bias=True).to('cuda') - mod_conv = SwitchedConv(32, 64, 3, switch_breadth=8, stride=2, padding=1, bias=True, include_coupler=True, coupler_dim_in=128).to('cuda') - mod_sd = convert_conv_net_state_dict_to_switched_conv(base_conv, 8) - mod_conv.load_state_dict(mod_sd, strict=False) - inp = torch.randn((8,32,128,128), device='cuda') - sel = torch.randn((8,128,32,32), device='cuda') - out1 = base_conv(inp) - out2 = mod_conv(inp, sel) - assert(torch.max(torch.abs(out1-out2)) < 1e-6) - -def perform_conversion(): - sd = torch.load("../experiments/rrdb_imgset_226500_generator.pth") - load_net_clean = OrderedDict() # remove unnecessary 'module.' 
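-    # ('module.' prefixes are an artifact of checkpoints saved from an
-    #  nn.DataParallel / DistributedDataParallel wrapper; stripping them
-    #  restores the plain module's key names.)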
- for k, v in sd.items(): - if k.startswith('module.'): - load_net_clean[k.replace('module.', '')] = v - else: - load_net_clean[k] = v - sd = load_net_clean - import models.RRDBNet_arch as rrdb - block = functools.partial(rrdb.RRDBWithBypass) - mod = rrdb.RRDBNet(in_channels=3, out_channels=3, - mid_channels=64, num_blocks=23, body_block=block, scale=2, initial_stride=2) - mod.load_state_dict(sd) - converted = convert_conv_net_state_dict_to_switched_conv(mod, 8, ['body.','conv_first','resnet_encoder']) - torch.save(converted, "../experiments/rrdb_imgset_226500_generator_converted.pth") - - -if __name__ == '__main__': - perform_conversion() diff --git a/codes/models/switched_conv/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py deleted file mode 100644 index 28342f93..00000000 --- a/codes/models/switched_conv/switched_conv_hard_routing.py +++ /dev/null @@ -1,360 +0,0 @@ -import math - -import torch -import torch.nn as nn -from lambda_networks import LambdaLayer -from torch.nn import init, Conv2d, MSELoss, ZeroPad2d -import torch.nn.functional as F -from tqdm import tqdm -import torch.distributed as dist - -from trainer.losses import ConfigurableLoss - - -def SwitchedConvRoutingNormal(input, selector, weight, bias, stride=1): - convs = [] - b, s, h, w = selector.shape - for sel in range(s): - convs.append(F.conv2d(input, weight[:, :, sel, :, :], bias, stride=stride, padding=weight.shape[-1] // 2)) - output = torch.stack(convs, dim=1) * selector.unsqueeze(dim=2) - return output.sum(dim=1) - - -class SwitchedConvHardRoutingFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, input, selector, weight, bias, stride=1): - # Pre-pad the input. - input = ZeroPad2d(weight.shape[-1]//2)(input) - - # Build hard attention mask from selector input - b, s, h, w = selector.shape - - mask = selector.argmax(dim=1).int() - import switched_conv_cuda_naive - output = switched_conv_cuda_naive.forward(input, mask, weight, bias, stride) - - ctx.stride = stride - ctx.breadth = s - ctx.save_for_backward(*[input, output.detach().clone(), mask, weight, bias]) - return output - - @staticmethod - def backward(ctx, gradIn): - #import pydevd - #pydevd.settrace(suspend=False, trace_only_current_thread=True) - input, output, mask, weight, bias = ctx.saved_tensors - gradIn = gradIn - - # Selector grad is simply the element-wise product of grad with the output of the layer, summed across the channel dimension - # and repeated along the breadth of the switch. (Think of the forward operation using the selector as a simple matrix of 1s - # and zeros that is multiplied by the output.) - grad_sel = (gradIn * output).sum(dim=1, keepdim=True).repeat(1,ctx.breadth,1,1) - - import switched_conv_cuda_naive - grad, grad_w, grad_b = switched_conv_cuda_naive.backward(input, gradIn.contiguous(), mask, weight, bias, ctx.stride) - - # Remove input padding from grad - padding = weight.shape[-1] // 2 - if padding > 0: - grad = grad[:,:,padding:-padding,padding:-padding] - return grad, grad_sel, grad_w, grad_b, None - - -class RouteTop1(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]) - if len(input.shape) > 2: - mask = mask.permute(0, 3, 1, 2) # TODO: Make this more extensible. 
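-        # (F.one_hot appends the class dim last, turning the argmax of a
-        #  (b, s, h, w) input into (b, h, w, s); the permute moves the switch
-        #  dim back to dim 1.)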
- out = torch.ones_like(input) - out[mask != 1] = 0 - ctx.save_for_backward(mask, input.clone()) - return out - - @staticmethod - def backward(ctx, grad): - # Enable breakpoints in this function: (Comment out if not debugging) - #import pydevd - #pydevd.settrace(suspend=False, trace_only_current_thread=True) - - mask, input = ctx.saved_tensors - input[mask != 1] = 1 - grad_input = grad.clone() - grad_input[mask != 1] = 0 - grad_input_n = grad_input / input # Above, we made everything either a zero or a one. Unscale the ones by dividing by the unmasked inputs. - return grad_input_n - - -""" -SwitchNorm is meant to be applied against the Softmax output of a switching function across a large set of -switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude -of switch weights that are over-used and increasing the magnitude of under-used weights. - -The return value has the exact same format as a normal Softmax output and can be used directly into the input of an -switch equation. - -Since the whole point of convolutional switch is to enable training extra-wide networks to operate on a large number -of image categories, it makes almost no sense to perform this type of norm against a single mini-batch of images: some -of the switches will not be used in such a small context - and that's good! This is solved by accumulating. Every -forward pass computes a norm across the current minibatch. That norm is added into a rotating buffer of size -. The actual normalization occurs across the entire rotating buffer. - -You should set accumulator size according to two factors: -- Your batch size. Smaller batch size should mean greater accumulator size. -- Your image diversity. More diverse images have less need for the accumulator. -- How wide your switch/switching group size is. More groups mean you're going to want more accumulation. - -Note: This norm makes the (potentially flawed) assumption that each forward() pass has unique data. For maximum - effectiveness, avoid doing this - or make alterations to work around it. -Note: This norm does nothing for the first iterations. -""" -class SwitchNorm(nn.Module): - def __init__(self, group_size, accumulator_size=128): - super().__init__() - self.accumulator_desired_size = accumulator_size - self.group_size = group_size - self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device='cpu')) - self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device='cpu')) - self.register_buffer("accumulator", torch.zeros(accumulator_size, group_size)) - - def add_norm_to_buffer(self, x): - flatten_dims = [0] + [k+2 for k in range(len(x.shape)-2)] - flat = x.sum(dim=flatten_dims) - norm = flat / torch.mean(flat) - - self.accumulator[self.accumulator_index] = norm.detach().clone() - self.accumulator_index += 1 - if self.accumulator_index >= self.accumulator_desired_size: - self.accumulator_index *= 0 - if self.accumulator_filled <= 0: - self.accumulator_filled += 1 - - # Input into forward is a switching tensor of shape (batch,groups,) - def forward(self, x: torch.Tensor, update_attention_norm=True): - assert len(x.shape) >= 2 - - # Push the accumulator to the right device on the first iteration. - if self.accumulator.device != x.device: - self.accumulator = self.accumulator.to(x.device) - - # In eval, don't change the norm buffer. 
- if self.training and update_attention_norm: - self.add_norm_to_buffer(x) - - # Reduce across all distributed entities, if needed - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(self.accumulator, op=dist.ReduceOp.SUM) - self.accumulator /= dist.get_world_size() - - # Compute the norm factor. - if self.accumulator_filled > 0: - norm = torch.mean(self.accumulator, dim=0) - norm = norm * x.shape[1] / norm.sum() # The resulting norm should sum up to the total breadth: we are just re-weighting here. - else: - norm = torch.ones(self.group_size, device=self.accumulator.device) - - norm = norm.view(1,-1) - while len(x.shape) > len(norm.shape): - norm = norm.unsqueeze(-1) - x = x / norm - - return x - - -class HardRoutingGate(nn.Module): - def __init__(self, breadth, hard_en=True): - super().__init__() - self.norm = SwitchNorm(breadth, accumulator_size=256) - self.hard_en = hard_en - - def forward(self, x): - soft = self.norm(nn.functional.softmax(x, dim=1)) - if self.hard_en: - return RouteTop1.apply(soft) - return soft - - -class SwitchedConvHardRouting(nn.Module): - def __init__(self, - in_c, - out_c, - kernel_sz, - breadth, - stride=1, - bias=True, - dropout_rate=0.0, - include_coupler: bool = False, # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate. - coupler_mode: str = 'standard', - coupler_dim_in: int = 0, - hard_en=True, # A test switch that, when used in 'emulation mode' (where all convs are calculated using torch functions) computes soft-attention instead of hard-attention. - emulate_swconv=True, # When set, performs a nn.Conv2d operation for each breadth. When false, uses the native cuda implementation which computes all switches concurrently. 
- ): - super().__init__() - self.in_channels = in_c - self.out_channels = out_c - self.kernel_size = kernel_sz - self.stride = stride - self.has_bias = bias - self.breadth = breadth - self.dropout_rate = dropout_rate - - if include_coupler: - if coupler_mode == 'standard': - self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride) - elif coupler_mode == 'lambda': - self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1), - nn.BatchNorm2d(coupler_dim_in), - nn.ReLU(), - LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1), - nn.BatchNorm2d(breadth), - nn.ReLU(), - Conv2d(breadth, breadth, 1, stride=self.stride)) - elif coupler_mode == 'lambda2': - self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1), - nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in), - nn.ReLU(), - LambdaLayer(dim=coupler_dim_in, dim_out=coupler_dim_in, r=23, dim_k=16, heads=2, dim_u=1), - nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in), - nn.ReLU(), - LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1), - nn.GroupNorm(num_groups=1, num_channels=breadth), - nn.ReLU(), - Conv2d(breadth, breadth, 1, stride=self.stride)) - else: - self.coupler = None - self.gate = HardRoutingGate(breadth, hard_en=True) - self.hard_en = hard_en - - self.weight = nn.Parameter(torch.empty(out_c, in_c, breadth, kernel_sz, kernel_sz)) - if bias: - self.bias = nn.Parameter(torch.empty(out_c)) - else: - self.bias = torch.zeros(out_c) - self.reset_parameters() - - def reset_parameters(self) -> None: - init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight[:,:,0,:,:]) - bound = 1 / math.sqrt(fan_in) - init.uniform_(self.bias, -bound, bound) - - def load_weights_from_conv(self, cnv): - sd = cnv.state_dict() - sd['weight'] = sd['weight'].unsqueeze(2).repeat(1,1,self.breadth,1,1) - self.load_state_dict(sd) - - def forward(self, input, selector=None): - if self.bias.device != input.device: - self.bias = self.bias.to(input.device) # Because this bias can be a tensor that is not moved with the rest of the module. - - # If a coupler was specified, run that to convert selector into a softmax distribution. - if self.coupler: - if selector is None: # A coupler can convert from any input to a selector, so 'None' is allowed. - selector = input - selector = self.coupler(selector) - assert selector is not None - - # Apply dropout at the batch level per kernel. - if self.training and self.dropout_rate > 0: - b, c, h, w = selector.shape - drop = torch.rand((b, c, 1, 1), device=input.device) > self.dropout_rate - # Ensure that there is always at least one switch left un-dropped out - fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, c, 1, 1) - drop = drop.logical_or(fix_blank) - selector = drop * selector - - selector = self.gate(selector) - - # Debugging variables - self.last_select = selector.detach().clone() - self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,self.breadth,1,1) == selector).float().argmax(dim=1) - - if self.hard_en: - # This is a custom CUDA implementation which should be faster and less memory intensive (once completed). - return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride) - else: - # This composes the switching functionality using raw Torch, which basically consists of computing each of convs separately and combining them. 
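-            # (i.e. out = sum over s of conv2d(input, weight[:, :, s]) *
-            #  selector[:, s:s+1], matching SwitchedConvRoutingNormal above.)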
-            return SwitchedConvRoutingNormal(input, selector, self.weight, self.bias, self.stride)
-
-
-# Given a state_dict and the module that the sd belongs to, strips out all Conv2d.weight parameters and replaces them
-# with the equivalent SwitchedConv.weight parameters. Does not create coupler params.
-def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_list=[]):
-    state_dict = module.state_dict()
-    for name, m in module.named_modules():
-        if not isinstance(m, nn.Conv2d):
-            continue
-        ignored = False
-        for smod in ignore_list:
-            if smod in name:
-                ignored = True
-                break
-        if ignored:
-            continue
-        if name == '':
-            key = 'weight'
-        else:
-            key = f'{name}.weight'
-        state_dict[key] = state_dict[key].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
-    return state_dict
-
-
-# Given a state_dict and the module that the sd belongs to, strips out the specified Conv2d modules and replaces them
-# with equivalent switched_conv modules.
-def convert_net_to_switched_conv(module, switch_breadth, allow_list, dropout_rate=0.4, coupler_mode='lambda'):
-    print("CONVERTING MODEL TO SWITCHED_CONV MODE")
-
-    # Convert each allow-listed conv in the model:
-    full_paths = [n.split('.') for n in allow_list]
-    for modpath in full_paths:
-        mod = module
-        for sub in modpath[:-1]:
-            pmod = mod
-            mod = getattr(mod, sub)
-        old_conv = getattr(mod, modpath[-1])
-        new_conv = SwitchedConvHardRouting(old_conv.in_channels, old_conv.out_channels, old_conv.kernel_size[0], switch_breadth,
-                                           old_conv.stride[0], old_conv.bias is not None, include_coupler=True,
-                                           coupler_dim_in=old_conv.in_channels,  # The coupler reads the conv input when no selector is given.
-                                           dropout_rate=dropout_rate, coupler_mode=coupler_mode)
-        new_conv = new_conv.to(old_conv.weight.device)
-        assert old_conv.dilation == 1 or old_conv.dilation == (1,1) or old_conv.dilation is None
-        if isinstance(mod, nn.Sequential):
-            # If we use the standard logic (in the else case) here, it reorders the sequential.
-            # Instead, extract the OrderedDict from the current sequential, replace the Conv inside that dict, then replace the entire sequential to keep the order.
-            emods = mod._modules
-            emods[modpath[-1]] = new_conv
-            delattr(pmod, modpath[-2])
-            pmod.add_module(modpath[-2], nn.Sequential(emods))
-        else:
-            delattr(mod, modpath[-1])
-            mod.add_module(modpath[-1], new_conv)
-
-
-def convert_state_dict_to_switched_conv(sd_file, switch_breadth, allow_list):
-    save = torch.load(sd_file)
-    sd = save['state_dict']
-    converted = 0
-    for cname in allow_list:
-        for sn in sd.keys():
-            if cname in sn and sn.endswith('weight'):
-                sd[sn] = sd[sn].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
-                converted += 1
-    print(f"Converted {converted} parameters.")
-    torch.save(save, sd_file.replace('.pt', "_converted.pt"))
-
-
-def test_net():
-    for j in tqdm(range(100)):
-        base_conv = Conv2d(32, 64, 3, stride=2, padding=1, bias=True).to('cuda')
-        mod_conv = SwitchedConvHardRouting(32, 64, 3, breadth=8, stride=2, bias=True, include_coupler=True, coupler_dim_in=32, dropout_rate=.2).to('cuda')
-        mod_sd = convert_conv_net_state_dict_to_switched_conv(base_conv, 8)
-        mod_conv.load_state_dict(mod_sd, strict=False)
-        inp = torch.randn((128, 32, 128, 128), device='cuda')
-        out1 = base_conv(inp)
-        out2 = mod_conv(inp, None)
-        compare = (out2+torch.rand_like(out2)*1e-6).detach()
-        MSELoss()(out2, compare).backward()
-        assert(torch.max(torch.abs(out1-out2)) < 1e-5)
-
-
-if __name__ == '__main__':
-    test_net()
\ No newline at end of file
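For reference, the module deleted above implements switched convolution: a per-pixel router (the 'coupler' plus gate) picks one of `breadth` kernels at every output position, with a straight-through top-1 so the hard choice stays differentiable. Below is a minimal sketch of that emulation path in plain PyTorch. It is illustrative only: `ToySwitchedConv` and `RouteTop1STE` are hypothetical names, and a ModuleList of independent convs stands in for the single 5-D weight tensor used by the deleted class.

import torch
import torch.nn as nn
import torch.nn.functional as F

class RouteTop1STE(torch.autograd.Function):
    # Straight-through top-1: hard one-hot selection in forward,
    # gradients passed only through the winning switch in backward.
    @staticmethod
    def forward(ctx, soft):
        hard = torch.zeros_like(soft)
        hard.scatter_(1, soft.argmax(dim=1, keepdim=True), 1.0)
        ctx.save_for_backward(hard)
        return hard

    @staticmethod
    def backward(ctx, grad):
        hard, = ctx.saved_tensors
        return grad * hard

class ToySwitchedConv(nn.Module):
    def __init__(self, in_c, out_c, breadth, k=3):
        super().__init__()
        # One conv per switch; the deleted class kept these as one (out_c, in_c, breadth, k, k) tensor instead.
        self.convs = nn.ModuleList(nn.Conv2d(in_c, out_c, k, padding=k // 2) for _ in range(breadth))
        # 1x1 'coupler': derives per-pixel routing logits from the input itself.
        self.coupler = nn.Conv2d(in_c, breadth, 1)

    def forward(self, x, hard=True):
        sel = F.softmax(self.coupler(x), dim=1)                 # (b, breadth, h, w)
        if hard:
            sel = RouteTop1STE.apply(sel)                       # one-hot per pixel
        outs = torch.stack([c(x) for c in self.convs], dim=1)   # (b, breadth, out_c, h, w)
        return (outs * sel.unsqueeze(2)).sum(dim=1)             # weighted (or selected) combination

x = torch.randn(2, 8, 16, 16)
y = ToySwitchedConv(8, 12, breadth=4)(x)
assert y.shape == (2, 12, 16, 16)

Note that the emulation materializes all `breadth` branch outputs before masking; that memory cost is what the custom CUDA kernel mentioned in the deleted comments was intended to avoid.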
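The `convert_*` helpers above all rest on one invariant: expanding a conv weight of shape (out_c, in_c, k, k) to (out_c, in_c, breadth, k, k) via `unsqueeze(2).repeat(1, 1, breadth, 1, 1)` makes every switch an exact copy of the original kernel, so the converted network computes the same function at initialization no matter which switch the router picks — which is also why `test_net` can assert near-equality against the base conv. A self-contained illustration of just that expansion (shapes chosen arbitrarily):

import torch
import torch.nn as nn

breadth = 8
conv = nn.Conv2d(32, 64, 3, padding=1)

# Expand the 4-D conv weight into a 5-D switched-conv weight: one copy per switch.
sd = conv.state_dict()
switched_weight = sd['weight'].unsqueeze(2).repeat(1, 1, breadth, 1, 1)
assert switched_weight.shape == (64, 32, breadth, 3, 3)

# Every switch starts identical to the source kernel, so routing cannot change
# the output until training diverges the copies.
for b in range(breadth):
    assert torch.equal(switched_weight[:, :, b], sd['weight'])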
diff --git a/codes/scripts/audio/test_audio_similarity.py b/codes/scripts/audio/test_audio_similarity.py
index a68332e2..4f513afb 100644
--- a/codes/scripts/audio/test_audio_similarity.py
+++ b/codes/scripts/audio/test_audio_similarity.py
@@ -4,7 +4,7 @@ import torch
 import torch.nn.functional as F
 
 from data.util import is_wav_file, find_files_of_type
-from models.audio_resnet import resnet50
+from models.audio.audio_resnet import resnet50
 from models.audio.tts.tacotron2.taco_utils import load_wav_to_torch
 from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict
 
diff --git a/codes/scripts/byol/byol_extract_wrapped_model.py b/codes/scripts/byol/byol_extract_wrapped_model.py
index e002870c..6586c053 100644
--- a/codes/scripts/byol/byol_extract_wrapped_model.py
+++ b/codes/scripts/byol/byol_extract_wrapped_model.py
@@ -1,6 +1,5 @@
 import torch
 
-from models.spinenet_arch import SpineNet
 
 def extract_byol_model_from_state_dict(sd):
     wrap_key = 'online_encoder.net.'
diff --git a/codes/scripts/byol/byol_spinenet_playground.py b/codes/scripts/byol/byol_spinenet_playground.py
index 31d0c621..d646ba57 100644
--- a/codes/scripts/byol/byol_spinenet_playground.py
+++ b/codes/scripts/byol/byol_spinenet_playground.py
@@ -1,5 +1,4 @@
 import os
-import shutil
 
 import torch
 import torch.nn as nn
@@ -11,14 +10,12 @@ from torchvision.transforms import ToTensor, Resize
 from tqdm import tqdm
 import numpy as np
 
-import utils
 from data.image_folder_dataset import ImageFolderDataset
-from models.spinenet_arch import SpineNet
+from models.image_latents.spinenet_arch import SpineNet
 
 
 # Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
 # and the distance is computed across the channel dimension.
-from utils import util
 from utils.options import dict_to_nonedict
 
 
diff --git a/codes/scripts/stylegan2/convert_weights_rosinality.py b/codes/scripts/stylegan2/convert_weights_rosinality.py
index d0b505c7..499a0f91 100644
--- a/codes/scripts/stylegan2/convert_weights_rosinality.py
+++ b/codes/scripts/stylegan2/convert_weights_rosinality.py
@@ -13,7 +13,7 @@ import torch
 import numpy as np
 from torchvision import utils
 
-from models.stylegan.stylegan2_rosinality import Generator, Discriminator
+from models.image_generation.stylegan.stylegan2_rosinality import Generator, Discriminator
 
 
 # Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter.
@@ -237,7 +237,6 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     sys.path.append('scripts\\stylegan2')
-    import dnnlib
     from dnnlib.tflib.network import generator, discriminator, gen_ema
 
     with open(args.path, "rb") as f:
diff --git a/codes/trainer/custom_training_components/tecogan_losses.py b/codes/trainer/custom_training_components/tecogan_losses.py
index 31154138..2d9c0315 100644
--- a/codes/trainer/custom_training_components/tecogan_losses.py
+++ b/codes/trainer/custom_training_components/tecogan_losses.py
@@ -1,6 +1,6 @@
 from torch.cuda.amp import autocast
 
-from models.stylegan.stylegan2_lucidrains import gradient_penalty
+from models.image_generation.stylegan.stylegan2_lucidrains import gradient_penalty
 from trainer.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
 from models.flownet2.networks import Resample2d
 from trainer.inject import Injector
diff --git a/codes/trainer/eval/flow_gaussian_nll.py b/codes/trainer/eval/flow_gaussian_nll.py
index d55ffd7c..237d40e8 100644
--- a/codes/trainer/eval/flow_gaussian_nll.py
+++ b/codes/trainer/eval/flow_gaussian_nll.py
@@ -6,7 +6,7 @@ import trainer.eval.evaluator as evaluator
 
 # Evaluate how close to true Gaussian a flow network predicts in a "normal" pass given a LQ/HQ image pair.
 from data.image_folder_dataset import ImageFolderDataset
-from models.srflow.flow import GaussianDiag
+from models.image_generation.srflow.flow import GaussianDiag
 
 
 class FlowGaussianNll(evaluator.Evaluator):
diff --git a/codes/trainer/losses.py b/codes/trainer/losses.py
index 0352fc52..fadc0bd0 100644
--- a/codes/trainer/losses.py
+++ b/codes/trainer/losses.py
@@ -16,13 +16,13 @@ def create_loss(opt_loss, env):
         from trainer.custom_training_components import create_teco_loss
         return create_teco_loss(opt_loss, env)
     elif 'stylegan2_' in type:
-        from models.stylegan import create_stylegan2_loss
+        from models.image_generation.stylegan import create_stylegan2_loss
         return create_stylegan2_loss(opt_loss, env)
     elif 'style_sr_' in type:
         from models.styled_sr import create_stylesr_loss
         return create_stylesr_loss(opt_loss, env)
     elif 'lightweight_gan_divergence' == type:
-        from models.lightweight_gan import LightweightGanDivergenceLoss
+        from models.image_generation.lightweight_gan import LightweightGanDivergenceLoss
        return LightweightGanDivergenceLoss(opt_loss, env)
     elif type == 'crossentropy' or type == 'cross_entropy':
         return CrossEntropy(opt_loss, env)
@@ -401,7 +401,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
             if self.gradient_penalty:
                 # Apply gradient penalty. TODO: migrate this elsewhere.
- from models.stylegan.stylegan2_lucidrains import gradient_penalty + from models.image_generation.stylegan.stylegan2_lucidrains import gradient_penalty assert len(real) == 1 # Grad penalty doesn't currently support multi-input discriminators. gp, gp_structure = gradient_penalty(real[0], d_real, return_structured_grads=True) self.metrics.append(("gradient_penalty", gp.clone().detach())) diff --git a/codes/trainer/networks.py b/codes/trainer/networks.py index f73a15c0..723389ff 100644 --- a/codes/trainer/networks.py +++ b/codes/trainer/networks.py @@ -5,8 +5,6 @@ import pkgutil import sys from collections import OrderedDict from inspect import isfunction, getmembers, signature -import torch -import models.feature_arch as feature_arch logger = logging.getLogger('base') diff --git a/codes/utils/numeric_stability.py b/codes/utils/numeric_stability.py index 63cc763f..6949ce5a 100644 --- a/codes/utils/numeric_stability.py +++ b/codes/utils/numeric_stability.py @@ -1,7 +1,6 @@ import torch from torch import nn -import models.SwitchedResidualGenerator_arch as srg -import models.discriminator_vgg_arch as disc +import models.image_generation.discriminator_vgg_arch as disc import functools blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]
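The `gradient_penalty` import repointed in the losses.py hunk above refers to the lucidrains StyleGAN2 utility, which penalizes the norm of the discriminator's gradients with respect to its input (the `gp, gp_structure = gradient_penalty(real[0], d_real, ...)` call). That implementation is not reproduced here; the following is only a generic sketch of the technique in its WGAN-GP form, with assumed names, to make the usage concrete.

import torch

def generic_gradient_penalty(images, d_out, weight=10.0):
    # Differentiate the discriminator's per-sample outputs w.r.t. its input images...
    grads, = torch.autograd.grad(outputs=d_out.sum(), inputs=images,
                                 create_graph=True)
    grads = grads.reshape(grads.shape[0], -1)
    # ...and penalize deviation of the per-sample gradient norm from 1.
    return weight * ((grads.norm(2, dim=1) - 1) ** 2).mean()

# Usage sketch: the input must require grad before the discriminator pass.
images = torch.randn(4, 3, 32, 32, requires_grad=True)
d_out = images.mean(dim=(1, 2, 3))  # stand-in for a discriminator forward pass
loss = generic_gradient_penalty(images, d_out)
loss.backward()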