Hard Routing mods

- Turns out my custom convolution was RIDDLED with backwards bugs, which is
   why the existing implementation wasn't working so well.
- Implements the switch logic from both Mixture of Experts and Switch Transformers
  for testing purposes.
This commit is contained in:
James Betker 2021-02-02 20:35:58 -07:00
parent 29c1c3bede
commit 0dca36946f
4 changed files with 223 additions and 95 deletions

View File

@ -9,35 +9,81 @@ import torch.nn.functional as F
from tqdm import tqdm
import torch.distributed as dist
from trainer.losses import ConfigurableLoss
def SwitchedConvRoutingNormal(input, selector, weight, bias, stride=1):
convs = []
b, s, h, w = selector.shape
for sel in range(s):
convs.append(F.conv2d(input, weight[:, :, sel, :, :], bias, stride=stride, padding=weight.shape[-1] // 2))
output = torch.stack(convs, dim=1) * selector.unsqueeze(dim=2)
return output.sum(dim=1)
class SwitchedConvHardRoutingFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, selector, weight, bias, stride=1):
# Build hard attention mask from selector input
b, s, h, w = selector.shape
selector_mask = (selector.max(dim=1, keepdim=True)[0].repeat(1,s,1,1) == selector).float()
mask = selector_mask.argmax(dim=1).int()
# Compute the convolution using the mask.
outputs = switched_conv_cuda_naive.forward(input, mask, weight, bias, stride)
mask = selector.argmax(dim=1).int()
output = switched_conv_cuda_naive.forward(input, mask, weight, bias, stride)
ctx.stride = stride
ctx.breadth = s
ctx.save_for_backward(*[input, mask, weight, bias])
return outputs
return output
@staticmethod
def backward(ctx, grad):
input, mask, weight, bias = ctx.saved_tensors
# Get the grads for the convolution.
grad, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)
# Get the selector grads
selector_mask = torch.eye(ctx.breadth, device=input.device)[mask.long()].permute(0,3,1,2).unsqueeze(2) # Note that this is not necessarily equivalent to the selector_mask from above, because under certain circumstances, two values could take on the value '1' in the above instance, whereas this is a true one-hot representation.
grad_sel = ((grad * input).unsqueeze(1) * selector_mask).sum(2)
grad, grad_sel, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)
return grad, grad_sel, grad_w, grad_b, None
# Implements KeepTopK where k=1 from mixture of experts paper.
class KeepTop1(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2)
input[mask != 1] = -float('inf')
ctx.save_for_backward(mask)
return input
@staticmethod
def backward(ctx, grad):
import pydevd
pydevd.settrace(suspend=False, trace_only_current_thread=True)
mask = ctx.saved_tensors
grad_input = grad.clone()
grad_input[mask != 1] = 0
return grad_input
class RouteTop1(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2)
out = torch.ones_like(input)
out[mask != 1] = 0
ctx.save_for_backward(mask, input.clone())
return out
@staticmethod
def backward(ctx, grad):
# Enable breakpoints in this function: (Comment out if not debugging)
#import pydevd
#pydevd.settrace(suspend=False, trace_only_current_thread=True)
mask, input = ctx.saved_tensors
input[mask != 1] = 1
grad_input = grad.clone()
grad_input[mask != 1] = 0
grad_input_n = grad_input / input # Above, we made everything either a zero or a one. Unscale the ones by dividing by the unmasked inputs.
return grad_input_n
"""
SwitchNorm is meant to be applied against the Softmax output of an switching function across a large set of
switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude
@ -109,6 +155,107 @@ class SwitchNorm(nn.Module):
return x / x.sum(dim=1, keepdim=True)
class MixtureOfExperts2dRouter(nn.Module):
def __init__(self, num_experts):
super().__init__()
self.num_experts = num_experts
self.wnoise = nn.Parameter(torch.zeros(1,num_experts,1,1))
self.wg = nn.Parameter(torch.zeros(1,num_experts,1,1))
def forward(self, x):
wg = x * self.wg
wnoise = nn.functional.softplus(x * self.wnoise)
H = wg + torch.randn_like(x) * wnoise
# Produce the load-balancing loss.
eye = torch.eye(self.num_experts, device=x.device).view(1,self.num_experts,self.num_experts,1,1)
mask=torch.abs(1-eye)
b,c,h,w=H.shape
ninf = torch.zeros_like(eye)
ninf[eye==1] = -float('inf')
H_masked=H.view(b,c,1,h,w)*mask+ninf # ninf is necessary because otherwise torch.max() will not pick up negative numbered maxes.
max_excluding=torch.max(H_masked,dim=2)[0]
# load_loss and G are stored as local members to facilitate their use by hard routing regularization losses.
# this is a risky op - it can easily result in memory leakage. Clients *must* use self.reset() below.
self.load_loss = torch.erf((wg - max_excluding)/wnoise)
#self.G = nn.functional.softmax(KeepTop1.apply(H), dim=1) The paper proposes this equation, but performing a softmax on a Top-1 per the paper results in zero gradients into H, so:
self.G = RouteTop1.apply(nn.functional.softmax(H, dim=1)) # This variant can route gradients downstream.
return self.G
# Retrieve the locally stored loss values and delete them from membership (so as to not waste memory)
def reset(self):
G, load = self.G, self.load_loss
del self.G
del self.load_loss
return G, load
# Loss that finds instances of MixtureOfExperts2dRouter in the given network and extracts their custom losses.
class MixtureOfExpertsLoss(ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.routers = [] # This is filled in during the first forward() pass and cached from there.
self.first_forward_encountered = False
self.load_weight = opt['load_weight']
self.importance_weight = opt['importance_weight']
def forward(self, net, state):
if not self.first_forward_encountered:
for m in net.modules():
if isinstance(m, MixtureOfExperts2dRouter):
self.routers.append(m)
self.first_forward_encountered = True
l_importance = 0
l_load = 0
for r in self.routers:
G, L = r.reset()
l_importance += G.var().square()
l_load += L.var().square()
return l_importance * self.importance_weight + l_load * self.load_weight
class SwitchTransformersLoadBalancer(nn.Module):
def __init__(self):
super().__init__()
self.norm = SwitchNorm(8, accumulator_size=256)
def forward(self, x):
self.soft = self.norm(nn.functional.softmax(x, dim=1))
self.hard = RouteTop1.apply(self.soft) # This variant can route gradients downstream.
return self.hard
def reset(self):
soft, hard = self.soft, self.hard
del self.soft, self.hard
return soft, hard
class SwitchTransformersLoadBalancingLoss(ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.routers = [] # This is filled in during the first forward() pass and cached from there.
self.first_forward_encountered = False
def forward(self, net, state):
if not self.first_forward_encountered:
for m in net.modules():
if isinstance(m, SwitchTransformersLoadBalancer):
self.routers.append(m)
self.first_forward_encountered = True
loss = 0
for r in self.routers:
soft, hard = r.reset()
N = hard.shape[1]
h_mean = hard.mean(dim=[0,2,3])
s_mean = soft.mean(dim=[0,2,3])
loss += torch.dot(h_mean, s_mean) * N
return loss
class SwitchedConvHardRouting(nn.Module):
def __init__(self,
in_c,
@ -120,8 +267,7 @@ class SwitchedConvHardRouting(nn.Module):
dropout_rate=0.0,
include_coupler: bool = False, # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate.
coupler_mode: str = 'standard',
coupler_dim_in: int = 0,
switch_norm: bool = True):
coupler_dim_in: int = 0):
super().__init__()
self.in_channels = in_c
self.out_channels = out_c
@ -130,14 +276,10 @@ class SwitchedConvHardRouting(nn.Module):
self.has_bias = bias
self.breadth = breadth
self.dropout_rate = dropout_rate
if switch_norm:
self.switch_norm = SwitchNorm(breadth, accumulator_size=512)
else:
self.switch_norm = None
if include_coupler:
if coupler_mode == 'standard':
self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1)
self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
elif coupler_mode == 'lambda':
self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
nn.BatchNorm2d(coupler_dim_in),
@ -145,9 +287,11 @@ class SwitchedConvHardRouting(nn.Module):
LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
nn.BatchNorm2d(breadth),
nn.ReLU(),
Conv2d(breadth, breadth, 1))
Conv2d(breadth, breadth, 1, stride=self.stride))
else:
self.coupler = None
#self.gate = MixtureOfExperts2dRouter(breadth)
self.gate = SwitchTransformersLoadBalancer()
self.weight = nn.Parameter(torch.empty(out_c, in_c, breadth, kernel_sz, kernel_sz))
if bias:
@ -175,14 +319,10 @@ class SwitchedConvHardRouting(nn.Module):
# If a coupler was specified, run that to convert selector into a softmax distribution.
if self.coupler:
if selector is None: # A coupler can convert from any input to a selector, so 'None' is allowed.
selector = input.detach()
selector = F.softmax(self.coupler(selector), dim=1)
selector = input
selector = self.coupler(selector)
assert selector is not None
# Perform normalization on the selector if applicable.
if self.switch_norm:
selector = self.switch_norm(selector)
# Apply dropout at the batch level per kernel.
if self.training and self.dropout_rate > 0:
b, c, h, w = selector.shape
@ -192,11 +332,18 @@ class SwitchedConvHardRouting(nn.Module):
drop = drop.logical_or(fix_blank)
selector = drop * selector
selector = self.gate(selector)
# Debugging variables
self.last_select = selector.detach().clone()
self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,self.breadth,1,1) == selector).float().argmax(dim=1)
return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
if False:
# This is a custom CUDA implementation which should be faster and less memory intensive (once completed).
return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
else:
# This composes the switching functionality using raw Torch, which basically consists of computing each of <breadth> convs separately and combining them.
return SwitchedConvRoutingNormal(input, selector, self.weight, self.bias, self.stride)
# Given a state_dict and the module that that sd belongs to, strips out all Conv2d.weight parameters and replaces them
@ -213,7 +360,11 @@ def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_
continue
if ignored:
continue
state_dict[f'{name}.weight'] = state_dict[f'{name}.weight'].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
if name == '':
key = 'weight'
else:
key = f'{name}.weight'
state_dict[key] = state_dict[key].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
return state_dict

View File

@ -17,7 +17,7 @@ from utils.util import checkpoint, opt_get
class UpsampleConv(nn.Module):
def __init__(self, in_filters, out_filters, breadth, kernel_size, padding):
super().__init__()
self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_filters, dropout_rate=0.4)
self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_filters, dropout_rate=0.4)
def forward(self, x):
up = torch.nn.functional.interpolate(x, scale_factor=2)
@ -104,16 +104,16 @@ class Encoder(nn.Module):
blocks = [
nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
nn.ReLU(inplace=True),
SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.4),
SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4),
nn.ReLU(inplace=True),
SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel, dropout_rate=0.4),
SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel, dropout_rate=0.4),
]
elif stride == 2:
blocks = [
nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
nn.ReLU(inplace=True),
SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.4),
SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4),
]
for i in range(n_res_block):
@ -133,7 +133,7 @@ class Decoder(nn.Module):
):
super().__init__()
blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel, dropout_rate=0.4)]
blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_channel, dropout_rate=0.4)]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel, breadth))

View File

@ -1,78 +1,49 @@
import os.path as osp
import logging
import time
import argparse
import os
import shutil
import utils
from trainer.ExtensibleTrainer import ExtensibleTrainer
from trainer.networks import define_F
from utils import options as option
import utils.util as util
from data import create_dataset, create_dataloader
from torch.utils.data import DataLoader
from data.single_image_dataset import SingleImageDataset
from tqdm import tqdm
import torch
from models.vqvae.vqvae_no_conv_transpose import VQVAE
if __name__ == "__main__":
bin_path = "f:\\binned"
good_path = "f:\\good"
os.makedirs(bin_path, exist_ok=True)
os.makedirs(good_path, exist_ok=True)
torch.backends.cudnn.benchmark = True
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/generator_filter.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
opt['dist'] = False
util.mkdirs(
(path for key, path in opt['path'].items()
if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
screen=True, tofile=True)
logger = logging.getLogger('base')
logger.info(option.dict2str(opt))
model = VQVAE().cuda()
model.load_state_dict(torch.load('../experiments/nvqvae_imgset.pth'))
ds = SingleImageDataset({
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\256_with_ref_v5'],
'weights': [1],
'target_size': 128,
'force_multiple': 32,
'scale': 1,
'eval': False
})
dl = DataLoader(ds, batch_size=256, num_workers=1)
#### Create test dataset and dataloader
test_loaders = []
for phase, dataset_opt in sorted(opt['datasets'].items()):
test_set = create_dataset(dataset_opt)
test_loader = create_dataloader(test_set, dataset_opt, opt=opt)
logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
test_loaders.append(test_loader)
means = []
model.eval()
with torch.no_grad():
for i, data in enumerate(tqdm(dl)):
hq = data['hq'].cuda()
gen = model(hq)[0]
l2 = torch.mean(torch.square(hq - gen), dim=[1,2,3])
for b in range(len(l2)):
if l2[b] > .0004:
shutil.copy(data['GT_path'][b], good_path)
#else:
# shutil.copy(data['GT_path'][b], bin_path)
model = ExtensibleTrainer(opt)
utils.util.loaded_options = opt
fea_loss = 0
for test_loader in test_loaders:
test_set_name = test_loader.dataset.opt['name']
logger.info('\nTesting [{:s}]...'.format(test_set_name))
test_start_time = time.time()
dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
util.mkdir(dataset_dir)
netF = define_F(which_model='vgg').to(model.env['device'])
tq = tqdm(test_loader)
removed = 0
means = []
for data in tq:
model.feed_data(data, need_GT=True)
model.test()
gen = model.eval_state['gen'][0].to(model.env['device'])
feagen = netF(gen)
feareal = netF(data['hq'].to(model.env['device']))
losses = torch.sum(torch.abs(feareal - feagen), dim=(1,2,3))
means.append(torch.mean(losses).item())
#print(sum(means)/len(means), torch.mean(losses), torch.max(losses), torch.min(losses))
for i in range(losses.shape[0]):
if losses[i] < 25000:
os.remove(data['GT_path'][i])
removed += 1
#imname = osp.basename(data['GT_path'][i])
#if losses[i] < 25000:
# torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
print("Removed %i/%i images" % (removed, len(test_set)))
#means.append(l2.cpu())
#if i % 10 == 0:
# print(torch.stack(means, dim=0).mean())

View File

@ -47,6 +47,12 @@ def create_loss(opt_loss, env):
return RecurrentLoss(opt_loss, env)
elif type == 'for_element':
return ForElementLoss(opt_loss, env)
elif type == 'mixture_of_experts':
from models.switched_conv_hard_routing import MixtureOfExpertsLoss
return MixtureOfExpertsLoss(opt_loss, env)
elif type == 'switch_transformer_balance':
from models.switched_conv_hard_routing import SwitchTransformersLoadBalancingLoss
return SwitchTransformersLoadBalancingLoss(opt_loss, env)
else:
raise NotImplementedError