diff --git a/.idea/other.xml b/.idea/other.xml
new file mode 100644
index 00000000..58daadce
--- /dev/null
+++ b/.idea/other.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/codes/models/styled_sr/discriminator.py b/codes/models/styled_sr/discriminator.py
index 73ba6a83..44fd83f6 100644
--- a/codes/models/styled_sr/discriminator.py
+++ b/codes/models/styled_sr/discriminator.py
@@ -1,4 +1,6 @@
 # Heavily based on the lucidrains stylegan2 discriminator implementation.
+import math
+import os
 from functools import partial
 from math import log2
 from random import random
@@ -6,31 +8,33 @@ from random import random
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torchvision
 from torch.autograd import grad as torch_grad
 
 import trainer.losses as L
 from vector_quantize_pytorch import VectorQuantize
 
 from models.styled_sr.stylegan2_base import attn_and_ff, PermuteToFrom, Blur, leaky_relu, exists
+from models.styled_sr.transfer_primitives import TransferConv2d, TransferLinear
 from trainer.networks import register_model
 from utils.util import checkpoint, opt_get
 
 
 class DiscriminatorBlock(nn.Module):
-    def __init__(self, input_channels, filters, downsample=True):
+    def __init__(self, input_channels, filters, downsample=True, transfer_mode=False):
         super().__init__()
         self.filters = filters
-        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
+        self.conv_res = TransferConv2d(input_channels, filters, 1, stride=(2 if downsample else 1), transfer_mode=transfer_mode)
 
         self.net = nn.Sequential(
-            nn.Conv2d(input_channels, filters, 3, padding=1),
+            TransferConv2d(input_channels, filters, 3, padding=1, transfer_mode=transfer_mode),
             leaky_relu(),
-            nn.Conv2d(filters, filters, 3, padding=1),
+            TransferConv2d(filters, filters, 3, padding=1, transfer_mode=transfer_mode),
             leaky_relu()
         )
 
         self.downsample = nn.Sequential(
             Blur(),
-            nn.Conv2d(filters, filters, 3, padding=1, stride=2)
+            TransferConv2d(filters, filters, 3, padding=1, stride=2, transfer_mode=transfer_mode)
         ) if downsample else None
 
     def forward(self, x):
@@ -42,9 +46,10 @@ class DiscriminatorBlock(nn.Module):
         return x
 
 
-class StyleGan2Discriminator(nn.Module):
+class StyleSrDiscriminator(nn.Module):
     def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
-                 transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False):
+                 transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False,
+                 transfer_mode=False):
         super().__init__()
         num_layers = int(log2(image_size) - 1)
 
@@ -63,7 +68,7 @@ class StyleGan2Discriminator(nn.Module):
             num_layer = ind + 1
             is_not_last = ind != (len(chan_in_out) - 1)
 
-            block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last)
+            block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last, transfer_mode=transfer_mode)
             blocks.append(block)
 
             attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
@@ -84,17 +89,23 @@ class StyleGan2Discriminator(nn.Module):
         chan_last = filters[-1]
         latent_dim = 2 * 2 * chan_last
 
-        self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
+        self.final_conv = TransferConv2d(chan_last, chan_last, 3, padding=1, transfer_mode=transfer_mode)
         self.flatten = nn.Flatten()
         if mlp:
-            self.to_logit = nn.Sequential(nn.Linear(latent_dim, 100),
+            self.to_logit = nn.Sequential(TransferLinear(latent_dim, 100, transfer_mode=transfer_mode),
                                           leaky_relu(),
-                                          nn.Linear(100, 1))
+                                          TransferLinear(100, 1, transfer_mode=transfer_mode))
         else:
-            self.to_logit = nn.Linear(latent_dim, 1)
+            self.to_logit = TransferLinear(latent_dim, 1, transfer_mode=transfer_mode)
         self._init_weights()
 
+        self.transfer_mode = transfer_mode
+        if transfer_mode:
+            for p in self.parameters():
+                if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
+                    p.DO_NOT_TRAIN = True
+
     def forward(self, x):
         b, *_ = x.shape
 
@@ -123,12 +134,12 @@ class StyleGan2Discriminator(nn.Module):
 
     def _init_weights(self):
         for m in self.modules():
-            if type(m) in {nn.Conv2d, nn.Linear}:
+            if type(m) in {TransferConv2d, TransferLinear}:
                 nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
 
     # Configures the network as partially pre-trained. This means:
     # 1) The top (high-resolution) `num_blocks` will have their weights re-initialized.
-    # 2) The haed (linear layers) will also have their weights re-initialized
+    # 2) The head (linear layers) will also have their weights re-initialized
     # 3) All intermediate blocks will be frozen until step `frozen_until_step`
     # These settings will be applied after the weights have been loaded (network_loaded())
     def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0):
@@ -150,7 +161,7 @@ class StyleGan2Discriminator(nn.Module):
             reset_blocks.append(self.blocks[i])
         for bl in reset_blocks:
             for m in bl.modules():
-                if type(m) in {nn.Conv2d, nn.Linear}:
+                if type(m) in {TransferConv2d, TransferLinear}:
                     nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
                 for p in m.parameters(recurse=True):
                     p._NEW_BLOCK = True
@@ -237,13 +248,15 @@ class DiscAugmentor(nn.Module):
         self.prob = prob
         self.types = types
 
-    def forward(self, images, detach=False):
+    def forward(self, images, real_images=False):
         if random() < self.prob:
             images = random_hflip(images, prob=0.5)
             images = DiffAugment(images, types=self.types)
 
-        if detach:
-            images = images.detach()
+        if real_images:
+            self.hq_aug = images.detach().clone()
+        else:
+            self.gen_aug = images.detach().clone()
 
         # Save away for use elsewhere (e.g. unet loss)
         self.aug_images = images
@@ -253,6 +266,11 @@ class DiscAugmentor(nn.Module):
     def network_loaded(self):
         self.D.network_loaded()
 
+    # Allows visualizing what the augmentor is up to.
+    def visual_dbg(self, step, path):
+        torchvision.utils.save_image(self.gen_aug, os.path.join(path, "%i_gen_aug.png" % (step)))
+        torchvision.utils.save_image(self.hq_aug, os.path.join(path, "%i_hq_aug.png" % (step)))
+
 
 def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
     if fp16:
@@ -294,12 +312,12 @@ class StyleSrGanDivergenceLoss(L.ConfigurableLoss):
             real_input = real_input + torch.rand_like(real_input) * self.noise
 
         D = self.env['discriminators'][self.discriminator]
-        fake = D(fake_input)
+        fake = D(fake_input, real_images=False)
        if self.for_gen:
             return fake.mean()
         else:
             real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
-            real = D(real_input)
+            real = D(real_input, real_images=True)
             divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
 
             # Apply gradient penalty. TODO: migrate this elsewhere.
@@ -315,10 +333,12 @@ class StyleSrGanDivergenceLoss(L.ConfigurableLoss):
 @register_model
 def register_styledsr_discriminator(opt_net, opt):
     attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
-    disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
-                                  do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
-                                  quantize=opt_get(opt_net, ['quantize'], False),
-                                  mlp=opt_get(opt_net, ['mlp_head'], True))
+    disc = StyleSrDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
+                                do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
+                                quantize=opt_get(opt_net, ['quantize'], False),
+                                mlp=opt_get(opt_net, ['mlp_head'], True),
+                                transfer_mode=opt_get(opt_net, ['transfer_mode'], False)
+                                )
     if 'use_partial_pretrained' in opt_net.keys():
         disc.configure_partial_training(opt_net['bypass_blocks'], opt_net['partial_training_blocks'], opt_net['intermediate_blocks_frozen_until'])
     return DiscAugmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
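
Note on the freeze mechanism above: transfer_mode tags parameters with attribute flags (DO_NOT_TRAIN / FOR_TRANSFER_LEARNING) rather than flipping requires_grad, so the training loop is expected to filter them when it builds its optimizer. A minimal sketch of that convention, assuming a plain Adam setup (the actual trainer wiring in this repo may differ):

import torch

def build_optimizer(model, lr=1e-4):
    # Skip parameters the network tagged DO_NOT_TRAIN; the transfer
    # scale/shift parameters carry FOR_TRANSFER_LEARNING and stay trainable.
    params = [p for p in model.parameters() if not getattr(p, 'DO_NOT_TRAIN', False)]
    return torch.optim.Adam(params, lr=lr)
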
diff --git a/codes/models/styled_sr/styled_sr.py b/codes/models/styled_sr/styled_sr.py
index f11d7b78..c40bb00b 100644
--- a/codes/models/styled_sr/styled_sr.py
+++ b/codes/models/styled_sr/styled_sr.py
@@ -1,29 +1,37 @@
-from math import log2
 from random import random
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from models.RRDBNet_arch import RRDB
-from models.arch_util import ConvGnLelu, default_init_weights
-from models.styled_sr.stylegan2_base import StyleVectorizer, GeneratorBlock, Conv2DMod, leaky_relu, Blur
+from models.arch_util import kaiming_init
+from models.styled_sr.stylegan2_base import StyleVectorizer, GeneratorBlock
+from models.styled_sr.transfer_primitives import TransferConvGnLelu, TransferConv2d, TransferLinear
 from trainer.networks import register_model
 from utils.util import checkpoint, opt_get
 
 
+def rrdb_init_weights(module, scale=1):
+    for m in module.modules():
+        if isinstance(m, TransferConv2d):
+            kaiming_init(m, a=0, mode='fan_in', bias=0)
+            m.weight.data *= scale
+        elif isinstance(m, TransferLinear):
+            kaiming_init(m, a=0, mode='fan_in', bias=0)
+            m.weight.data *= scale
+
+
 class EncoderRRDB(nn.Module):
-    def __init__(self, mid_channels=64, output_channels=32, growth_channels=32, init_weight=.1):
+    def __init__(self, mid_channels=64, output_channels=32, growth_channels=32, init_weight=.1, transfer_mode=False):
         super(EncoderRRDB, self).__init__()
         for i in range(5):
             out_channels = output_channels if i == 4 else growth_channels
             self.add_module(
                 f'conv{i+1}',
-                nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3,
-                          1, 1))
+                TransferConv2d(mid_channels + i * growth_channels, out_channels, 3, 1, 1, transfer_mode=transfer_mode))
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
         for i in range(5):
-            default_init_weights(getattr(self, f'conv{i+1}'), init_weight)
+            rrdb_init_weights(getattr(self, f'conv{i+1}'), init_weight)
 
     def forward(self, x):
         x1 = self.lrelu(self.conv1(x))
@@ -35,18 +43,18 @@ class EncoderRRDB(nn.Module):
 
 
 class StyledSrEncoder(nn.Module):
-    def __init__(self, fea_out=256, initial_stride=1):
+    def __init__(self, fea_out=256, initial_stride=1, transfer_mode=False):
         super().__init__()
         # Current assumes fea_out=256.
-        self.initial_conv = ConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True)
+        self.initial_conv = TransferConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True, transfer_mode=transfer_mode)
         self.rrdbs = nn.ModuleList([
-            EncoderRRDB(32),
-            EncoderRRDB(64),
-            EncoderRRDB(96),
-            EncoderRRDB(128),
-            EncoderRRDB(160),
-            EncoderRRDB(192),
-            EncoderRRDB(224)])
+            EncoderRRDB(32, transfer_mode=transfer_mode),
+            EncoderRRDB(64, transfer_mode=transfer_mode),
+            EncoderRRDB(96, transfer_mode=transfer_mode),
+            EncoderRRDB(128, transfer_mode=transfer_mode),
+            EncoderRRDB(160, transfer_mode=transfer_mode),
+            EncoderRRDB(192, transfer_mode=transfer_mode),
+            EncoderRRDB(224, transfer_mode=transfer_mode)])
 
     def forward(self, x):
         fea = self.initial_conv(x)
@@ -56,13 +64,14 @@ class StyledSrEncoder(nn.Module):
 
 
 class Generator(nn.Module):
-    def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2):
+    def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2, transfer_mode=False):
         super().__init__()
         total_levels = upsample_levels + 1  # The first level handles the raw encoder output and doesn't upsample.
         self.image_size = image_size
         self.scale = 2 ** upsample_levels
         self.latent_dim = latent_dim
         self.num_layers = total_levels
+        self.transfer_mode = transfer_mode
         filters = [
             512,  # 4x4
             512,  # 8x8
@@ -75,7 +84,8 @@ class Generator(nn.Module):
             8,    # 1024x1024
         ]
 
-        self.encoder = StyledSrEncoder(filters[start_level], initial_stride)
+        # I'm making a guess here that the encoder does not need transfer learning, hence the fixed transfer_mode=False. This should be vetted.
+        self.encoder = StyledSrEncoder(filters[start_level], initial_stride, transfer_mode=False)
 
         in_out_pairs = list(zip(filters[:-1], filters[1:]))
         self.blocks = nn.ModuleList([])
@@ -88,13 +98,18 @@
                 in_chan,
                 out_chan,
                 upsample=not_first,
-                upsample_rgb=not_last
+                upsample_rgb=not_last,
+                transfer_learning_mode=transfer_mode
             )
             self.blocks.append(block)
 
     def forward(self, lr, styles):
         b, c, h, w = lr.shape
-        x = self.encoder(lr)
+        if self.transfer_mode:
+            with torch.no_grad():
+                x = self.encoder(lr)
+        else:
+            x = self.encoder(lr)
         styles = styles.transpose(0, 1)
         input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device)
@@ -102,6 +117,7 @@
         if h != x.shape[-2]:
             rgb = F.interpolate(lr, size=x.shape[2:], mode="area")
         else:
             rgb = lr
+
         for style, block in zip(styles, self.blocks):
             x, rgb = checkpoint(block, x, rgb, style, input_noise)
 
@@ -109,16 +125,23 @@ class StyledSrGenerator(nn.Module):
-    def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1):
+    def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1, transfer_mode=False):
         super().__init__()
-        self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
-        self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride)
+        # Assume the vectorizer doesn't need transfer_mode=True. Re-evaluate this later.
+        self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp, transfer_mode=False)
+        self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride, transfer_mode=transfer_mode)
         self.mixed_prob = .9
         self._init_weights()
 
+        self.transfer_mode = transfer_mode
+        if transfer_mode:
+            for p in self.parameters():
+                if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
+                    p.DO_NOT_TRAIN = True
+
     def _init_weights(self):
         for m in self.modules():
-            if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'):
+            if type(m) in {TransferConv2d, TransferLinear} and hasattr(m, 'weight'):
                 nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
 
         for block in self.gen.blocks:
@@ -132,7 +155,11 @@ class StyledSrGenerator(nn.Module):
         # Synthesize style latents from noise.
         style = torch.randn(b*2, self.gen.latent_dim).to(x.device)
-        w = self.vectorizer(style)
+        if self.transfer_mode:
+            with torch.no_grad():
+                w = self.vectorizer(style)
+        else:
+            w = self.vectorizer(style)
 
         # Randomly distribute styles across layers
         w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone()
@@ -162,4 +189,6 @@ if __name__ == '__main__':
 
 @register_model
 def register_styled_sr(opt_net, opt):
-    return StyledSrGenerator(128, initial_stride=opt_get(opt_net, ['initial_stride'], 1))
+    return StyledSrGenerator(128,
+                             initial_stride=opt_get(opt_net, ['initial_stride'], 1),
+                             transfer_mode=opt_get(opt_net, ['transfer_mode'], False))
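
The generator takes the complementary approach for its frozen submodules: in transfer mode the encoder and vectorizer run under torch.no_grad(), so no autograd graph is built for them at all. A self-contained illustration of that pattern (class and attribute names here are illustrative, not from this repo):

import torch
import torch.nn as nn

class FrozenBackboneHead(nn.Module):
    # In transfer mode the backbone acts as a fixed feature extractor:
    # its activations are computed without autograd bookkeeping.
    def __init__(self, backbone: nn.Module, head: nn.Module, transfer_mode: bool = False):
        super().__init__()
        self.backbone = backbone
        self.head = head
        self.transfer_mode = transfer_mode

    def forward(self, x):
        if self.transfer_mode:
            with torch.no_grad():
                feats = self.backbone(x)
        else:
            feats = self.backbone(x)
        return self.head(feats)
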
diff --git a/codes/models/styled_sr/stylegan2_base.py b/codes/models/styled_sr/stylegan2_base.py
index a0285895..2738725a 100644
--- a/codes/models/styled_sr/stylegan2_base.py
+++ b/codes/models/styled_sr/stylegan2_base.py
@@ -6,8 +6,12 @@ import torch
 import torch.nn.functional as F
 from kornia.filters import filter2D
 from linear_attention_transformer import ImageLinearAttention
-from torch import nn
+from torch import nn, Tensor
 from torch.autograd import grad as torch_grad
+from torch.nn import Parameter, init
+from torch.nn.modules.conv import _ConvNd
+
+from models.styled_sr.transfer_primitives import TransferLinear
 
 assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
@@ -196,10 +200,9 @@ def slerp(val, low, high):
     res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
     return res
 
-# stylegan2 classes
 
 class EqualLinear(nn.Module):
-    def __init__(self, in_dim, out_dim, lr_mul=1, bias=True):
+    def __init__(self, in_dim, out_dim, lr_mul=1, bias=True, transfer_mode=False):
         super().__init__()
         self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
         if bias:
@@ -207,17 +210,28 @@ class EqualLinear(nn.Module):
 
         self.lr_mul = lr_mul
 
+        self.transfer_mode = transfer_mode
+        if transfer_mode:
+            self.transfer_scale = nn.Parameter(torch.ones(out_dim, in_dim))
+            self.transfer_scale.FOR_TRANSFER_LEARNING = True
+            self.transfer_shift = nn.Parameter(torch.zeros(out_dim, in_dim))
+            self.transfer_shift.FOR_TRANSFER_LEARNING = True
+
     def forward(self, input):
-        return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
+        if self.transfer_mode:
+            weight = self.weight * self.transfer_scale + self.transfer_shift
+        else:
+            weight = self.weight
+        return F.linear(input, weight * self.lr_mul, bias=self.bias * self.lr_mul)
 
 
 class StyleVectorizer(nn.Module):
-    def __init__(self, emb, depth, lr_mul=0.1):
+    def __init__(self, emb, depth, lr_mul=0.1, transfer_mode=False):
         super().__init__()
 
         layers = []
         for i in range(depth):
-            layers.extend([EqualLinear(emb, emb, lr_mul), leaky_relu()])
+            layers.extend([EqualLinear(emb, emb, lr_mul, transfer_mode=transfer_mode), leaky_relu()])
 
         self.net = nn.Sequential(*layers)
 
@@ -227,13 +241,13 @@ class StyleVectorizer(nn.Module):
 
 
 class RGBBlock(nn.Module):
-    def __init__(self, latent_dim, input_channel, upsample, rgba=False):
+    def __init__(self, latent_dim, input_channel, upsample, rgba=False, transfer_mode=False):
         super().__init__()
         self.input_channel = input_channel
         self.to_style = nn.Linear(latent_dim, input_channel)
 
         out_filters = 3 if not rgba else 4
-        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)
+        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False, transfer_mode=transfer_mode)
 
         self.upsample = nn.Sequential(
             nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
@@ -307,25 +321,11 @@ class EqualLR:
 
 def equal_lr(module, name='weight'):
     EqualLR.apply(module, name)
-
     return module
 
 
-class EqualConv2d(nn.Module):
-    def __init__(self, *args, **kwargs):
-        super().__init__()
-
-        conv = nn.Conv2d(*args, **kwargs)
-        conv.weight.data.normal_()
-        conv.bias.data.zero_()
-
-        self.conv = equal_lr(conv)
-
-    def forward(self, input):
-        return self.conv(input)
-
-
 class Conv2DMod(nn.Module):
-    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs):
+    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, transfer_mode=False, **kwargs):
         super().__init__()
         self.filters = out_chan
         self.demod = demod
@@ -334,6 +334,12 @@ class Conv2DMod(nn.Module):
         self.dilation = dilation
         self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
         nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
+        self.transfer_mode = transfer_mode
+        if transfer_mode:
+            self.transfer_scale = nn.Parameter(torch.ones(out_chan, in_chan, 1, 1))
+            self.transfer_scale.FOR_TRANSFER_LEARNING = True
+            self.transfer_shift = nn.Parameter(torch.zeros(out_chan, in_chan, 1, 1))
+            self.transfer_shift.FOR_TRANSFER_LEARNING = True
 
     def _get_same_padding(self, size, kernel, dilation, stride):
         return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
 
@@ -341,8 +347,13 @@ class Conv2DMod(nn.Module):
     def forward(self, x, y):
         b, c, h, w = x.shape
 
+        if self.transfer_mode:
+            weight = self.weight * self.transfer_scale + self.transfer_shift
+        else:
+            weight = self.weight
+
         w1 = y[:, None, :, None, None]
-        w2 = self.weight[None, :, :, :, :]
+        w2 = weight[None, :, :, :, :]
         weights = w2 * (w1 + 1)
 
         if self.demod:
@@ -362,34 +373,28 @@ class Conv2DMod(nn.Module):
 
 
 class GeneratorBlock(nn.Module):
-    def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, structure_input=False):
+    def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False,
+                 transfer_learning_mode=False):
         super().__init__()
         self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
 
-        self.structure_input = structure_input
-        if self.structure_input:
-            self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1)
-            input_channels = input_channels * 2
+        self.to_style1 = TransferLinear(latent_dim, input_channels, transfer_mode=transfer_learning_mode)
+        self.to_noise1 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode)
+        self.conv1 = Conv2DMod(input_channels, filters, 3, transfer_mode=transfer_learning_mode)
 
-        self.to_style1 = nn.Linear(latent_dim, input_channels)
-        self.to_noise1 = nn.Linear(1, filters)
-        self.conv1 = Conv2DMod(input_channels, filters, 3)
-
-        self.to_style2 = nn.Linear(latent_dim, filters)
-        self.to_noise2 = nn.Linear(1, filters)
-        self.conv2 = Conv2DMod(filters, filters, 3)
+        self.to_style2 = TransferLinear(latent_dim, filters, transfer_mode=transfer_learning_mode)
+        self.to_noise2 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode)
+        self.conv2 = Conv2DMod(filters, filters, 3, transfer_mode=transfer_learning_mode)
 
         self.activation = leaky_relu()
-        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)
+        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba, transfer_mode=transfer_learning_mode)
 
-    def forward(self, x, prev_rgb, istyle, inoise, structure_input=None):
+        self.transfer_learning_mode = transfer_learning_mode
+
+    def forward(self, x, prev_rgb, istyle, inoise):
         if exists(self.upsample):
             x = self.upsample(x)
 
-        if self.structure_input:
-            s = self.structure_conv(structure_input)
-            x = torch.cat([x, s], dim=1)
-
         inoise = inoise[:, :x.shape[2], :x.shape[3], :]
         noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
         noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))
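
All three transfer-mode layers above share one reparameterization: the pretrained weight is perturbed as W_eff = W * scale + shift, with scale initialized to ones and shift to zeros, so each layer starts out exactly equal to its pretrained counterpart and training adapts only the scale/shift deltas while the original weights stay untouched. A standalone sketch of the idea for a linear layer (names are illustrative, not from this repo):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaleShiftLinear(nn.Module):
    # Wraps a pretrained weight; only the elementwise scale/shift train.
    def __init__(self, weight: torch.Tensor, bias: torch.Tensor = None):
        super().__init__()
        self.weight = nn.Parameter(weight.detach().clone(), requires_grad=False)
        self.bias = nn.Parameter(bias.detach().clone(), requires_grad=False) if bias is not None else None
        self.scale = nn.Parameter(torch.ones_like(self.weight))   # identity at init
        self.shift = nn.Parameter(torch.zeros_like(self.weight))  # identity at init

    def forward(self, x):
        return F.linear(x, self.weight * self.scale + self.shift, self.bias)
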
diff --git a/codes/models/styled_sr/transfer_primitives.py b/codes/models/styled_sr/transfer_primitives.py
new file mode 100644
index 00000000..93af5391
--- /dev/null
+++ b/codes/models/styled_sr/transfer_primitives.py
@@ -0,0 +1,136 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter, init
+from torch.nn.modules.conv import _ConvNd
+from torch.nn.modules.utils import _ntuple
+
+_pair = _ntuple(2)
+
+
+class TransferConv2d(_ConvNd):
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size,
+            stride=1,
+            padding=0,
+            dilation=1,
+            groups: int = 1,
+            bias: bool = True,
+            padding_mode: str = 'zeros',
+            transfer_mode: bool = False
+    ):
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            False, _pair(0), groups, bias, padding_mode)
+
+        self.transfer_mode = transfer_mode
+        if transfer_mode:
+            self.transfer_scale = nn.Parameter(torch.ones(out_channels, in_channels, 1, 1))
+            self.transfer_scale.FOR_TRANSFER_LEARNING = True
+            self.transfer_shift = nn.Parameter(torch.zeros(out_channels, in_channels, 1, 1))
+            self.transfer_shift.FOR_TRANSFER_LEARNING = True
+
+    def _conv_forward(self, input, weight):
+        if self.transfer_mode:
+            weight = weight * self.transfer_scale + self.transfer_shift
+        else:
+            weight = weight
+
+        if self.padding_mode != 'zeros':
+            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
+                            weight, self.bias, self.stride,
+                            _pair(0), self.dilation, self.groups)
+        return F.conv2d(input, weight, self.bias, self.stride,
+                        self.padding, self.dilation, self.groups)
+
+    def forward(self, input: Tensor) -> Tensor:
+        return self._conv_forward(input, self.weight)
+
+
+class TransferLinear(nn.Module):
+    __constants__ = ['in_features', 'out_features']
+    in_features: int
+    out_features: int
+    weight: Tensor
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True, transfer_mode: bool = False) -> None:
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = Parameter(torch.Tensor(out_features, in_features))
+        if bias:
+            self.bias = Parameter(torch.Tensor(out_features))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_parameters()
+        self.transfer_mode = transfer_mode
+        if transfer_mode:
+            self.transfer_scale = nn.Parameter(torch.ones(out_features, in_features))
+            self.transfer_scale.FOR_TRANSFER_LEARNING = True
+            self.transfer_shift = nn.Parameter(torch.zeros(out_features, in_features))
+            self.transfer_shift.FOR_TRANSFER_LEARNING = True
+
+    def reset_parameters(self) -> None:
+        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+            bound = 1 / math.sqrt(fan_in)
+            init.uniform_(self.bias, -bound, bound)
+
+    def forward(self, input: Tensor) -> Tensor:
+        if self.transfer_mode:
+            weight = self.weight * self.transfer_scale + self.transfer_shift
+        else:
+            weight = self.weight
+        return F.linear(input, weight, self.bias)
+
+    def extra_repr(self) -> str:
+        return 'in_features={}, out_features={}, bias={}'.format(
+            self.in_features, self.out_features, self.bias is not None
+        )
+
+
+class TransferConvGnLelu(nn.Module):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True,
+                 num_groups=8, weight_init_factor=1, transfer_mode=False):
+        super().__init__()
+        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
+        assert kernel_size in padding_map.keys()
+        self.conv = TransferConv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias,
+                                   transfer_mode=transfer_mode)
+        if norm:
+            self.gn = nn.GroupNorm(num_groups, filters_out)
+        else:
+            self.gn = None
+        if activation:
+            self.lelu = nn.LeakyReLU(negative_slope=.2)
+        else:
+            self.lelu = None
+
+        # Init params.
+        for m in self.modules():
+            if isinstance(m, TransferConv2d):
+                nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
+                                        nonlinearity='leaky_relu' if self.lelu else 'linear')
+                m.weight.data *= weight_init_factor
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.gn:
+            x = self.gn(x)
+        if self.lelu:
+            return self.lelu(x)
+        else:
+            return x
\ No newline at end of file
diff --git a/codes/scripts/tsne_torch.py b/codes/scripts/tsne_torch.py
new file mode 100644
index 00000000..e69de29b
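
Taken together, the intended flow appears to be: load pretrained weights, construct the networks with transfer_mode=True, and train only the scale/shift deltas. A quick sanity-check sketch of the new primitives (assumes the codes/ directory of this repo is on the import path):

import torch
from models.styled_sr.transfer_primitives import TransferLinear

layer = TransferLinear(16, 4, transfer_mode=True)
y = layer(torch.randn(2, 16))  # scale/shift start as identity, so this matches the plain layer
tagged = [n for n, p in layer.named_parameters() if hasattr(p, 'FOR_TRANSFER_LEARNING')]
# tagged == ['transfer_scale', 'transfer_shift']; the wrapping networks mark
# every other parameter DO_NOT_TRAIN when transfer_mode is enabled.
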