diff --git a/codes/models/switched_conv/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py
index 689504da..4a48e650 100644
--- a/codes/models/switched_conv/switched_conv_hard_routing.py
+++ b/codes/models/switched_conv/switched_conv_hard_routing.py
@@ -2,9 +2,8 @@ import math
 
 import torch
 import torch.nn as nn
-import switched_conv_cuda_naive
 from lambda_networks import LambdaLayer
-from torch.nn import init, Conv2d, MSELoss
+from torch.nn import init, Conv2d, MSELoss, ZeroPad2d
 import torch.nn.functional as F
 from tqdm import tqdm
 import torch.distributed as dist
@@ -24,10 +23,14 @@ def SwitchedConvRoutingNormal(input, selector, weight, bias, stride=1):
 class SwitchedConvHardRoutingFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, selector, weight, bias, stride=1):
+        # Pre-pad the input.
+        input = ZeroPad2d(weight.shape[-1]//2)(input)
+
         # Build hard attention mask from selector input
         b, s, h, w = selector.shape
         mask = selector.argmax(dim=1).int()
 
+        import switched_conv_cuda_naive
         output = switched_conv_cuda_naive.forward(input, mask, weight, bias, stride)
 
         ctx.stride = stride
@@ -47,7 +50,13 @@ class SwitchedConvHardRoutingFunction(torch.autograd.Function):
         # and zeros that is multiplied by the output.)
         grad_sel = (gradIn * output).sum(dim=1, keepdim=True).repeat(1,ctx.breadth,1,1)
 
+        import switched_conv_cuda_naive
         grad, grad_w, grad_b = switched_conv_cuda_naive.backward(input, gradIn.contiguous(), mask, weight, bias, ctx.stride)
+
+        # Remove input padding from grad
+        padding = weight.shape[-1] // 2
+        if padding > 0:
+            grad = grad[:,:,padding:-padding,padding:-padding]
 
         return grad, grad_sel, grad_w, grad_b, None
 
@@ -204,7 +213,8 @@ class SwitchedConvHardRouting(nn.Module):
                                           Conv2d(breadth, breadth, 1, stride=self.stride))
         else:
             self.coupler = None
-        self.gate = HardRoutingGate(breadth, hard_en=hard_en)
+        self.gate = HardRoutingGate(breadth, hard_en=True)
+        self.hard_en = hard_en
 
         self.weight = nn.Parameter(torch.empty(out_c, in_c, breadth, kernel_sz, kernel_sz))
         if bias:
@@ -251,7 +261,7 @@ class SwitchedConvHardRouting(nn.Module):
         self.last_select = selector.detach().clone()
         self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,self.breadth,1,1) == selector).float().argmax(dim=1)
 
-        if False:
+        if self.hard_en:
             # This is a custom CUDA implementation which should be faster and less memory intensive (once completed).
             return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
         else:
diff --git a/codes/models/vqvae/vqvae_3_hardswitch.py b/codes/models/vqvae/vqvae_3_hardswitch.py
index ff6a950e..fe118407 100644
--- a/codes/models/vqvae/vqvae_3_hardswitch.py
+++ b/codes/models/vqvae/vqvae_3_hardswitch.py
@@ -232,27 +232,41 @@ def register_vqvae3_hard_switch(opt_net, opt):
 
 def performance_test():
     # For breadth=32:
-    # Custom_cuda_naive: 28.9s
+    # Custom_cuda_naive: 15.4s
     # Torch_native: 29.2s
     #
     # For breadth=8
-    # Custom_cuda_naive: 18.4s
+    # Custom_cuda_naive: 9.8s
     # Torch_native: 10s
     cfg = {
         'mode': 'lambda',
-        'breadth': 8,
+        'breadth': 16,
         'hard_enabled': True,
-        'dropout': 0.4
+        'dropout': 0,
     }
-    net = VQVAE3HardSwitch(cfg=cfg).to('cuda')
+    net = VQVAE3HardSwitch(cfg=cfg).to('cuda').double()
+    cfg['hard_enabled'] = False
+    netO = VQVAE3HardSwitch(cfg=cfg).double()
+    netO.load_state_dict(net.state_dict())
+    netO = netO.cpu()
+
     loss = nn.L1Loss()
     opt = torch.optim.Adam(net.parameters(), lr=1e-4)
     started = time()
     for j in tqdm(range(10)):
-        inp = torch.rand((8, 3, 256, 256), device='cuda')
+        inp = torch.rand((4, 3, 64, 64), device='cuda', dtype=torch.double)
         res = net(inp)[0]
         l = loss(res, inp)
         l.backward()
+
+        res2 = netO(inp.cpu())[0]
+        l = loss(res2, inp.cpu())
+        l.backward()
+
+        for p, op in zip(net.parameters(), netO.parameters()):
+            diff = p.grad.cpu() - op.grad
+            print(diff.max())
+
         opt.step()
         net.zero_grad()
     print("Elapsed: ", (time()-started))
diff --git a/codes/models/vqvae/vqvae_3_separated_coupler.py b/codes/models/vqvae/vqvae_3_separated_coupler.py
new file mode 100644
index 00000000..dc3f92cb
--- /dev/null
+++ b/codes/models/vqvae/vqvae_3_separated_coupler.py
@@ -0,0 +1,180 @@
+import torch
+from kornia import filter2D
+from torch import nn
+from torch.nn import functional as F
+
+import torch.distributed as distributed
+
+from models.vqvae.vqvae import ResBlock, Quantize
+from trainer.networks import register_model
+from utils.util import checkpoint, opt_get
+
+
+# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper.
+class UpsampleConv(nn.Module):
+    def __init__(self, in_filters, out_filters, kernel_size, padding):
+        super().__init__()
+        self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding)
+
+    def forward(self, x):
+        up = torch.nn.functional.interpolate(x, scale_factor=2)
+        return self.conv(up)
+
+
+class Encoder(nn.Module):
+    def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
+        super().__init__()
+
+        if stride == 4:
+            blocks = [
+                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
+                nn.LeakyReLU(inplace=True),
+                nn.Conv2d(channel // 2, channel, 5, stride=2, padding=2),
+                nn.LeakyReLU(inplace=True),
+                nn.Conv2d(channel, channel, 3, padding=1),
+            ]
+
+        elif stride == 2:
+            blocks = [
+                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
+                nn.LeakyReLU(inplace=True),
+                nn.Conv2d(channel // 2, channel, 3, padding=1),
+            ]
+
+        for i in range(n_res_block):
+            blocks.append(ResBlock(channel, n_res_channel))
+
+        blocks.append(nn.LeakyReLU(inplace=True))
+
+        self.blocks = nn.Sequential(*blocks)
+
+    def forward(self, input):
+        return self.blocks(input)
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
+    ):
+        super().__init__()
+
+        blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]
+
+        for i in range(n_res_block):
+            blocks.append(ResBlock(channel, n_res_channel))
+
+        blocks.append(nn.LeakyReLU(inplace=True))
+
+        if stride == 4:
+            blocks.extend(
+                [
+                    UpsampleConv(channel, channel // 2, 5, padding=2),
+                    nn.LeakyReLU(inplace=True),
+                    UpsampleConv(
+                        channel // 2, out_channel, 5, padding=2
+                    ),
+                ]
+            )
+
+        elif stride == 2:
+            blocks.append(
+                UpsampleConv(channel, out_channel, 5, padding=2)
+            )
+
+        self.blocks = nn.Sequential(*blocks)
+
+    def forward(self, input):
+        return self.blocks(input)
+
+
+class VQVAE3(nn.Module):
+    def __init__(
+        self,
+        in_channel=3,
+        channel=128,
+        n_res_block=2,
+        n_res_channel=32,
+        codebook_dim=64,
+        codebook_size=512,
+        decay=0.99,
+    ):
+        super().__init__()
+
+        self.initial_conv = nn.Sequential(*[nn.Conv2d(in_channel, 32, 3, padding=1),
+                                            nn.LeakyReLU(inplace=True)])
+        self.enc_b = Encoder(32, channel, n_res_block, n_res_channel, stride=4)
+        self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
+        self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
+        self.quantize_t = Quantize(codebook_dim, codebook_size)
+        self.dec_t = Decoder(
+            codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2
+        )
+        self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
+        self.quantize_b = Quantize(codebook_dim, codebook_size)
+        self.upsample_t = UpsampleConv(
+            codebook_dim, codebook_dim, 5, padding=2
+        )
+        self.dec = Decoder(
+            codebook_dim + codebook_dim,
+            32,
+            channel,
+            n_res_block,
+            n_res_channel,
+            stride=4,
+        )
+        self.final_conv = nn.Conv2d(32, in_channel, 3, padding=1)
+
+    def forward(self, input):
+        quant_t, quant_b, diff, _, _ = self.encode(input)
+        dec = self.decode(quant_t, quant_b)
+
+        return dec, diff
+
+    def encode(self, input):
+        fea = self.initial_conv(input)
+        enc_b = checkpoint(self.enc_b, fea)
+        enc_t = checkpoint(self.enc_t, enc_b)
+
+        quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
+        quant_t, diff_t, id_t = self.quantize_t(quant_t)
+        quant_t = quant_t.permute(0, 3, 1, 2)
+        diff_t = diff_t.unsqueeze(0)
+
+        dec_t = checkpoint(self.dec_t, quant_t)
+        enc_b = torch.cat([dec_t, enc_b], 1)
+
+        quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
+        quant_b, diff_b, id_b = self.quantize_b(quant_b)
+        quant_b = quant_b.permute(0, 3, 1, 2)
+        diff_b = diff_b.unsqueeze(0)
+
+        return quant_t, quant_b, diff_t + diff_b, id_t, id_b
+
+    def decode(self, quant_t, quant_b):
+        upsample_t = self.upsample_t(quant_t)
+        quant = torch.cat([upsample_t, quant_b], 1)
+        dec = checkpoint(self.dec, quant)
+        dec = checkpoint(self.final_conv, dec)
+
+        return dec
+
+    def decode_code(self, code_t, code_b):
+        quant_t = self.quantize_t.embed_code(code_t)
+        quant_t = quant_t.permute(0, 3, 1, 2)
+        quant_b = self.quantize_b.embed_code(code_b)
+        quant_b = quant_b.permute(0, 3, 1, 2)
+
+        dec = self.decode(quant_t, quant_b)
+
+        return dec
+
+
+@register_model
+def register_vqvae3(opt_net, opt):
+    kw = opt_get(opt_net, ['kwargs'], {})
+    return VQVAE3(**kw)
+
+
+if __name__ == '__main__':
+    v = VQVAE3()
+    print(v(torch.randn(1,3,128,128))[0].shape)
diff --git a/codes/requirements.txt b/codes/requirements.txt
index 03656bb3..5a8d5955 100644
--- a/codes/requirements.txt
+++ b/codes/requirements.txt
@@ -17,4 +17,5 @@ linear_attention_transformer
 vector_quantize_pytorch
 orjson
 einops
-gsa-pytorch
\ No newline at end of file
+gsa-pytorch
+lambda-networks
\ No newline at end of file
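
The padding logic added to SwitchedConvHardRoutingFunction above (pre-pad the input with ZeroPad2d(weight.shape[-1] // 2) in forward, crop the same border off the input gradient in backward) can be sanity-checked with plain conv2d, independent of the switched_conv_cuda_naive extension. A minimal sketch, assuming only stock PyTorch; the shapes and variable names here are illustrative and not part of the patch:

# Illustrative standalone check (not part of the patch): verifies that
# "zero-pad the input, run an unpadded conv, then crop the padding off the
# input gradient" matches an ordinary padded conv2d.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8, dtype=torch.double, requires_grad=True)
w = torch.randn(4, 3, 3, 3, dtype=torch.double)
b = torch.randn(4, dtype=torch.double)
p = w.shape[-1] // 2  # same kernel_size // 2 padding as the patch

# Reference: conv with implicit zero padding.
ref = F.conv2d(x, w, b, padding=p)
ref.sum().backward()
ref_grad = x.grad.clone()

# Pre-padded variant: pad explicitly, convolve without padding, then crop the
# gradient w.r.t. the padded input back to the original spatial size.
x_pad = F.pad(x.detach(), [p, p, p, p]).requires_grad_(True)
out = F.conv2d(x_pad, w, b, padding=0)
out.sum().backward()
grad = x_pad.grad[:, :, p:-p, p:-p] if p > 0 else x_pad.grad

print(torch.allclose(ref, out))        # True
print(torch.allclose(ref_grad, grad))  # True

If both checks hold in double precision, the cropped gradient from the pre-padded path matches what a normally padded convolution produces; the new gradient-comparison loop in performance_test() exercises the same idea end to end against the torch-native routing path.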