Add switched_conv with hard routing and make vqvae use it.
This commit is contained in:
parent
ae4ff4a1e7
commit
51b63b2aa6
136
codes/models/switched_conv_hard_routing.py
Normal file
136
codes/models/switched_conv_hard_routing.py
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import switched_conv_cuda_naive
|
||||||
|
from lambda_networks import LambdaLayer
|
||||||
|
from torch.nn import init, Conv2d, MSELoss
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class SwitchedConvHardRoutingFunction(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input, selector, weight, bias, stride=1):
|
||||||
|
# Build hard attention mask from selector input
|
||||||
|
b, s, h, w = selector.shape
|
||||||
|
selector_mask = (selector.max(dim=1, keepdim=True)[0].repeat(1,s,1,1) == selector).float()
|
||||||
|
mask = selector_mask.argmax(dim=1).int()
|
||||||
|
|
||||||
|
# Compute the convolution using the mask.
|
||||||
|
outputs = switched_conv_cuda_naive.forward(input, mask, weight, bias, stride)
|
||||||
|
ctx.stride = stride
|
||||||
|
ctx.breadth = s
|
||||||
|
ctx.save_for_backward(*[input, mask, weight, bias])
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad):
|
||||||
|
input, mask, weight, bias = ctx.saved_tensors
|
||||||
|
|
||||||
|
# Get the grads for the convolution.
|
||||||
|
grad, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)
|
||||||
|
|
||||||
|
# Get the selector grads
|
||||||
|
selector_mask = torch.eye(ctx.breadth, device=input.device)[mask.long()].permute(0,3,1,2).unsqueeze(2) # Note that this is not necessarily equivalent to the selector_mask from above, because under certain circumstances, two values could take on the value '1' in the above instance, whereas this is a true one-hot representation.
|
||||||
|
grad_sel = ((grad * input).unsqueeze(1) * selector_mask).sum(2)
|
||||||
|
return grad, grad_sel, grad_w, grad_b, None
|
||||||
|
|
||||||
|
|
||||||
|
class SwitchedConvHardRouting(nn.Module):
|
||||||
|
def __init__(self, in_c, out_c, kernel_sz, breadth, stride=1, bias=True, dropout_rate=0.0,
|
||||||
|
include_coupler: bool = False, # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate.
|
||||||
|
coupler_mode: str = 'standard',
|
||||||
|
coupler_dim_in: int = 0,):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_c
|
||||||
|
self.out_channels = out_c
|
||||||
|
self.kernel_size = kernel_sz
|
||||||
|
self.stride = stride
|
||||||
|
self.has_bias = bias
|
||||||
|
self.breadth = breadth
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
|
||||||
|
if include_coupler:
|
||||||
|
if coupler_mode == 'standard':
|
||||||
|
self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1)
|
||||||
|
elif coupler_mode == 'lambda':
|
||||||
|
self.coupler = LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1)
|
||||||
|
else:
|
||||||
|
self.coupler = None
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(torch.empty(out_c, in_c, breadth, kernel_sz, kernel_sz))
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.empty(out_c))
|
||||||
|
else:
|
||||||
|
self.bias = torch.zeros(out_c)
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self) -> None:
|
||||||
|
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||||
|
if self.bias is not None:
|
||||||
|
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight[:,:,0,:,:])
|
||||||
|
bound = 1 / math.sqrt(fan_in)
|
||||||
|
init.uniform_(self.bias, -bound, bound)
|
||||||
|
|
||||||
|
def load_weights_from_conv(self, cnv):
|
||||||
|
sd = cnv.state_dict()
|
||||||
|
sd['weight'] = sd['weight'].unsqueeze(2).repeat(1,1,self.breadth,1,1)
|
||||||
|
self.load_state_dict(sd)
|
||||||
|
|
||||||
|
def forward(self, input, selector=None):
|
||||||
|
if self.bias.device != input.device:
|
||||||
|
self.bias = self.bias.to(input.device) # Because this bias can be a tensor that is not moved with the rest of the module.
|
||||||
|
|
||||||
|
# If a coupler was specified, run that to convert selector into a softmax distribution.
|
||||||
|
if self.coupler:
|
||||||
|
if selector is None: # A coupler can convert from any input to a selector, so 'None' is allowed.
|
||||||
|
selector = input
|
||||||
|
selector = F.softmax(self.coupler(selector), dim=1)
|
||||||
|
self.last_select = selector.detach().clone()
|
||||||
|
assert selector is not None
|
||||||
|
|
||||||
|
# Apply dropout at the batch level per kernel.
|
||||||
|
if self.training and self.dropout_rate > 0:
|
||||||
|
b, c, h, w = selector.shape
|
||||||
|
drop = torch.rand((b, c, 1, 1), device=input.device) > self.dropout_rate
|
||||||
|
# Ensure that there is always at least one switch left un-dropped out
|
||||||
|
fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, c, 1, 1)
|
||||||
|
drop = drop.logical_or(fix_blank)
|
||||||
|
selector = drop * selector
|
||||||
|
|
||||||
|
return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
|
||||||
|
|
||||||
|
|
||||||
|
# Given a state_dict and the module that that sd belongs to, strips out all Conv2d.weight parameters and replaces them
|
||||||
|
# with the equivalent SwitchedConv.weight parameters. Does not create coupler params.
|
||||||
|
def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_list=[]):
|
||||||
|
state_dict = module.state_dict()
|
||||||
|
for name, m in module.named_modules():
|
||||||
|
ignored = False
|
||||||
|
for smod in ignore_list:
|
||||||
|
if smod in name:
|
||||||
|
ignored = True
|
||||||
|
continue
|
||||||
|
if ignored:
|
||||||
|
continue
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
state_dict[f'{name}.weight'] = state_dict[f'{name}.weight'].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def test_net():
|
||||||
|
for j in tqdm(range(100)):
|
||||||
|
base_conv = Conv2d(32, 64, 3, stride=2, padding=1, bias=True).to('cuda')
|
||||||
|
mod_conv = SwitchedConvHardRouting(32, 64, 3, breadth=8, stride=2, bias=True, include_coupler=True, coupler_dim_in=32, dropout_rate=.2).to('cuda')
|
||||||
|
mod_sd = convert_conv_net_state_dict_to_switched_conv(base_conv, 8)
|
||||||
|
mod_conv.load_state_dict(mod_sd, strict=False)
|
||||||
|
inp = torch.randn((128,32,128,128), device='cuda')
|
||||||
|
out1 = base_conv(inp)
|
||||||
|
out2 = mod_conv(inp, None)
|
||||||
|
compare = (out2+torch.rand_like(out2)*1e-6).detach()
|
||||||
|
MSELoss()(out2, compare).backward()
|
||||||
|
assert(torch.max(torch.abs(out1-out2)) < 1e-5)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_net()
|
|
@ -7,7 +7,8 @@ from torch.nn import functional as F
|
||||||
|
|
||||||
import torch.distributed as distributed
|
import torch.distributed as distributed
|
||||||
|
|
||||||
from models.switched_conv import SwitchedConv, convert_conv_net_state_dict_to_switched_conv
|
from models.switched_conv_hard_routing import SwitchedConvHardRouting, \
|
||||||
|
convert_conv_net_state_dict_to_switched_conv
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import checkpoint, opt_get
|
from utils.util import checkpoint, opt_get
|
||||||
|
|
||||||
|
@ -16,7 +17,7 @@ from utils.util import checkpoint, opt_get
|
||||||
class UpsampleConv(nn.Module):
|
class UpsampleConv(nn.Module):
|
||||||
def __init__(self, in_filters, out_filters, breadth, kernel_size, padding):
|
def __init__(self, in_filters, out_filters, breadth, kernel_size, padding):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = SwitchedConv(in_filters, out_filters, kernel_size, breadth, padding=padding, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_filters)
|
self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_filters, dropout_rate=0.2)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
up = torch.nn.functional.interpolate(x, scale_factor=2)
|
up = torch.nn.functional.interpolate(x, scale_factor=2)
|
||||||
|
@ -83,9 +84,9 @@ class ResBlock(nn.Module):
|
||||||
|
|
||||||
self.conv = nn.Sequential(
|
self.conv = nn.Sequential(
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
SwitchedConv(in_channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel),
|
nn.Conv2d(in_channel, channel, 3, padding=1),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
SwitchedConv(channel, in_channel, 1, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel),
|
nn.Conv2d(channel, in_channel, 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
@ -101,18 +102,18 @@ class Encoder(nn.Module):
|
||||||
|
|
||||||
if stride == 4:
|
if stride == 4:
|
||||||
blocks = [
|
blocks = [
|
||||||
SwitchedConv(in_channel, channel // 2, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel),
|
SwitchedConvHardRouting(in_channel, channel // 2, 5, breadth, stride=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel, dropout_rate=0.2),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
SwitchedConv(channel // 2, channel, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2),
|
SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.2),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
SwitchedConv(channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel),
|
SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel, dropout_rate=0.2),
|
||||||
]
|
]
|
||||||
|
|
||||||
elif stride == 2:
|
elif stride == 2:
|
||||||
blocks = [
|
blocks = [
|
||||||
SwitchedConv(in_channel, channel // 2, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel),
|
SwitchedConvHardRouting(in_channel, channel // 2, 5, breadth, stride=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel, dropout_rate=0.2),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
SwitchedConv(channel // 2, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2),
|
SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.2),
|
||||||
]
|
]
|
||||||
|
|
||||||
for i in range(n_res_block):
|
for i in range(n_res_block):
|
||||||
|
@ -132,7 +133,7 @@ class Decoder(nn.Module):
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
blocks = [SwitchedConv(in_channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel)]
|
blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel, dropout_rate=0.2)]
|
||||||
|
|
||||||
for i in range(n_res_block):
|
for i in range(n_res_block):
|
||||||
blocks.append(ResBlock(channel, n_res_channel, breadth))
|
blocks.append(ResBlock(channel, n_res_channel, breadth))
|
||||||
|
@ -171,7 +172,7 @@ class VQVAE(nn.Module):
|
||||||
codebook_dim=64,
|
codebook_dim=64,
|
||||||
codebook_size=512,
|
codebook_size=512,
|
||||||
decay=0.99,
|
decay=0.99,
|
||||||
breadth=4,
|
breadth=8,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -260,7 +261,8 @@ def convert_weights(weights_file):
|
||||||
import models.vqvae.vqvae_no_conv_transpose as stdvq
|
import models.vqvae.vqvae_no_conv_transpose as stdvq
|
||||||
std_model = stdvq.VQVAE()
|
std_model = stdvq.VQVAE()
|
||||||
std_model.load_state_dict(sd)
|
std_model.load_state_dict(sd)
|
||||||
nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 4, ['quantize_conv_t', 'quantize_conv_b'])
|
nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 1, ['quantize_conv_t', 'quantize_conv_b',
|
||||||
|
'conv.1', 'conv.3'])
|
||||||
torch.save(nsd, "converted.pth")
|
torch.save(nsd, "converted.pth")
|
||||||
|
|
||||||
|
|
||||||
|
@ -271,6 +273,6 @@ def register_vqvae_norm_switched_conv_lambda(opt_net, opt):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
#v = VQVAE()
|
v = VQVAE(breadth=8).cuda()
|
||||||
#print(v(torch.randn(1,3,128,128))[0].shape)
|
print(v(torch.randn(1,3,128,128).cuda())[0].shape)
|
||||||
convert_weights("../../../experiments/4000_generator.pth")
|
#convert_weights("../../../experiments/50000_generator.pth")
|
||||||
|
|
|
@ -295,7 +295,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_tiled_nvqvae_stage1.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_tiled_nvqvae_stage1_lambda.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user