Hard Routing mods

- Turns out my custom convolution was RIDDLED with backward-pass bugs, which is
  why the existing implementation wasn't working so well.
- Implements the switch logic from both Mixture of Experts and Switch Transformers
  for testing purposes; a minimal sketch of the shared routing idea follows the commit metadata below.
James Betker 2021-02-02 20:35:58 -07:00
parent 29c1c3bede
commit 0dca36946f
4 changed files with 223 additions and 95 deletions
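The core trick shared by both routing variants in this commit is a straight-through style top-1 selection: the forward pass hard-selects one conv kernel (expert) per spatial position, while the backward pass still routes a gradient into the soft selector. Below is a minimal, self-contained sketch of that idea; it mirrors the RouteTop1 function added in this commit, but the class name and the demo at the bottom are illustrative, not the exact commit code.

import torch

class RouteTop1Sketch(torch.autograd.Function):
    # Hard top-1 routing with a straight-through-style backward pass.
    @staticmethod
    def forward(ctx, probs):
        # probs: (b, breadth, h, w), the softmax output of a selector network.
        mask = torch.nn.functional.one_hot(probs.argmax(dim=1), num_classes=probs.shape[1]).permute(0, 3, 1, 2)
        ctx.save_for_backward(mask, probs.clone())
        return mask.to(probs.dtype)  # 1 for the selected kernel, 0 everywhere else

    @staticmethod
    def backward(ctx, grad):
        mask, probs = ctx.saved_tensors
        # Only the selected kernel receives gradient, rescaled by its soft probability.
        return grad * mask / probs.clamp(min=1e-8)

if __name__ == '__main__':
    probs = torch.softmax(torch.randn(2, 8, 16, 16, requires_grad=True), dim=1)
    hard = RouteTop1Sketch.apply(probs)
    hard.sum().backward()  # gradients still flow back into the selector despite the hard routing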

View File

@@ -9,35 +9,81 @@ import torch.nn.functional as F
from tqdm import tqdm
import torch.distributed as dist
from trainer.losses import ConfigurableLoss


def SwitchedConvRoutingNormal(input, selector, weight, bias, stride=1):
    convs = []
    b, s, h, w = selector.shape
    for sel in range(s):
        convs.append(F.conv2d(input, weight[:, :, sel, :, :], bias, stride=stride, padding=weight.shape[-1] // 2))
    output = torch.stack(convs, dim=1) * selector.unsqueeze(dim=2)
    return output.sum(dim=1)


class SwitchedConvHardRoutingFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, selector, weight, bias, stride=1):
        # Build hard attention mask from selector input
        b, s, h, w = selector.shape
-        selector_mask = (selector.max(dim=1, keepdim=True)[0].repeat(1,s,1,1) == selector).float()
-        mask = selector_mask.argmax(dim=1).int()
-        # Compute the convolution using the mask.
        mask = selector.argmax(dim=1).int()
-        outputs = switched_conv_cuda_naive.forward(input, mask, weight, bias, stride)
        output = switched_conv_cuda_naive.forward(input, mask, weight, bias, stride)

        ctx.stride = stride
        ctx.breadth = s
        ctx.save_for_backward(*[input, mask, weight, bias])
-        return outputs
        return output

    @staticmethod
    def backward(ctx, grad):
        input, mask, weight, bias = ctx.saved_tensors
-        grad, grad_sel, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)
        # Get the grads for the convolution.
        grad, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)

        # Get the selector grads
        selector_mask = torch.eye(ctx.breadth, device=input.device)[mask.long()].permute(0,3,1,2).unsqueeze(2)  # Note that this is not necessarily equivalent to the selector_mask from above, because under certain circumstances, two values could take on the value '1' in the above instance, whereas this is a true one-hot representation.
        grad_sel = ((grad * input).unsqueeze(1) * selector_mask).sum(2)
        return grad, grad_sel, grad_w, grad_b, None

# Implements KeepTopK where k=1 from the mixture of experts paper.
class KeepTop1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2)
        input[mask != 1] = -float('inf')
        ctx.save_for_backward(mask)
        return input

    @staticmethod
    def backward(ctx, grad):
        # Enable breakpoints in this function: (Comment out if not debugging)
        #import pydevd
        #pydevd.settrace(suspend=False, trace_only_current_thread=True)
        mask, = ctx.saved_tensors  # saved_tensors is a tuple; unpack the single saved mask.
        grad_input = grad.clone()
        grad_input[mask != 1] = 0
        return grad_input

class RouteTop1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2)
        out = torch.ones_like(input)
        out[mask != 1] = 0
        ctx.save_for_backward(mask, input.clone())
        return out

    @staticmethod
    def backward(ctx, grad):
        # Enable breakpoints in this function: (Comment out if not debugging)
        #import pydevd
        #pydevd.settrace(suspend=False, trace_only_current_thread=True)
        mask, input = ctx.saved_tensors
        input[mask != 1] = 1
        grad_input = grad.clone()
        grad_input[mask != 1] = 0
        grad_input_n = grad_input / input  # Above, we made everything either a zero or a one. Unscale the ones by dividing by the unmasked inputs.
        return grad_input_n

""" """
SwitchNorm is meant to be applied against the Softmax output of an switching function across a large set of SwitchNorm is meant to be applied against the Softmax output of an switching function across a large set of
switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude
@@ -109,6 +155,107 @@ class SwitchNorm(nn.Module):
        return x / x.sum(dim=1, keepdim=True)

class MixtureOfExperts2dRouter(nn.Module):
    def __init__(self, num_experts):
        super().__init__()
        self.num_experts = num_experts
        self.wnoise = nn.Parameter(torch.zeros(1,num_experts,1,1))
        self.wg = nn.Parameter(torch.zeros(1,num_experts,1,1))

    def forward(self, x):
        wg = x * self.wg
        wnoise = nn.functional.softplus(x * self.wnoise)
        H = wg + torch.randn_like(x) * wnoise

        # Produce the load-balancing loss.
        eye = torch.eye(self.num_experts, device=x.device).view(1,self.num_experts,self.num_experts,1,1)
        mask = torch.abs(1 - eye)
        b, c, h, w = H.shape
        ninf = torch.zeros_like(eye)
        ninf[eye == 1] = -float('inf')
        H_masked = H.view(b,c,1,h,w) * mask + ninf  # ninf is necessary because otherwise torch.max() will not pick up negative numbered maxes.
        max_excluding = torch.max(H_masked, dim=2)[0]

        # load_loss and G are stored as local members to facilitate their use by hard routing regularization losses.
        # This is a risky op - it can easily result in memory leakage. Clients *must* use self.reset() below.
        self.load_loss = torch.erf((wg - max_excluding) / wnoise)
        #self.G = nn.functional.softmax(KeepTop1.apply(H), dim=1)  # The paper proposes this equation, but performing a softmax on a Top-1 per the paper results in zero gradients into H, so:
        self.G = RouteTop1.apply(nn.functional.softmax(H, dim=1))  # This variant can route gradients downstream.

        return self.G

    # Retrieve the locally stored loss values and delete them from membership (so as to not waste memory).
    def reset(self):
        G, load = self.G, self.load_loss
        del self.G
        del self.load_loss
        return G, load

# Loss that finds instances of MixtureOfExperts2dRouter in the given network and extracts their custom losses.
class MixtureOfExpertsLoss(ConfigurableLoss):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.routers = []  # This is filled in during the first forward() pass and cached from there.
        self.first_forward_encountered = False
        self.load_weight = opt['load_weight']
        self.importance_weight = opt['importance_weight']

    def forward(self, net, state):
        if not self.first_forward_encountered:
            for m in net.modules():
                if isinstance(m, MixtureOfExperts2dRouter):
                    self.routers.append(m)
            self.first_forward_encountered = True

        l_importance = 0
        l_load = 0
        for r in self.routers:
            G, L = r.reset()
            l_importance += G.var().square()
            l_load += L.var().square()
        return l_importance * self.importance_weight + l_load * self.load_weight

class SwitchTransformersLoadBalancer(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = SwitchNorm(8, accumulator_size=256)

    def forward(self, x):
        self.soft = self.norm(nn.functional.softmax(x, dim=1))
        self.hard = RouteTop1.apply(self.soft)  # This variant can route gradients downstream.
        return self.hard

    def reset(self):
        soft, hard = self.soft, self.hard
        del self.soft, self.hard
        return soft, hard


class SwitchTransformersLoadBalancingLoss(ConfigurableLoss):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.routers = []  # This is filled in during the first forward() pass and cached from there.
        self.first_forward_encountered = False

    def forward(self, net, state):
        if not self.first_forward_encountered:
            for m in net.modules():
                if isinstance(m, SwitchTransformersLoadBalancer):
                    self.routers.append(m)
            self.first_forward_encountered = True

        loss = 0
        for r in self.routers:
            soft, hard = r.reset()
            N = hard.shape[1]
            h_mean = hard.mean(dim=[0,2,3])
            s_mean = soft.mean(dim=[0,2,3])
            loss += torch.dot(h_mean, s_mean) * N
        return loss

class SwitchedConvHardRouting(nn.Module):
    def __init__(self,
                 in_c,
@@ -120,8 +267,7 @@ class SwitchedConvHardRouting(nn.Module):
                 dropout_rate=0.0,
                 include_coupler: bool = False,  # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate.
                 coupler_mode: str = 'standard',
-                 coupler_dim_in: int = 0,
-                 switch_norm: bool = True):
                 coupler_dim_in: int = 0):
        super().__init__()
        self.in_channels = in_c
        self.out_channels = out_c
@@ -130,14 +276,10 @@ class SwitchedConvHardRouting(nn.Module):
        self.has_bias = bias
        self.breadth = breadth
        self.dropout_rate = dropout_rate
-        if switch_norm:
-            self.switch_norm = SwitchNorm(breadth, accumulator_size=512)
-        else:
-            self.switch_norm = None
        if include_coupler:
            if coupler_mode == 'standard':
-                self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1)
                self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
            elif coupler_mode == 'lambda':
                self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
                                             nn.BatchNorm2d(coupler_dim_in),
@@ -145,9 +287,11 @@ class SwitchedConvHardRouting(nn.Module):
                                             LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
                                             nn.BatchNorm2d(breadth),
                                             nn.ReLU(),
-                                             Conv2d(breadth, breadth, 1))
                                             Conv2d(breadth, breadth, 1, stride=self.stride))
        else:
            self.coupler = None
        #self.gate = MixtureOfExperts2dRouter(breadth)
        self.gate = SwitchTransformersLoadBalancer()

        self.weight = nn.Parameter(torch.empty(out_c, in_c, breadth, kernel_sz, kernel_sz))
        if bias:
@@ -175,14 +319,10 @@ class SwitchedConvHardRouting(nn.Module):
        # If a coupler was specified, run that to convert selector into a softmax distribution.
        if self.coupler:
            if selector is None:  # A coupler can convert from any input to a selector, so 'None' is allowed.
-                selector = input.detach()
                selector = input
-            selector = F.softmax(self.coupler(selector), dim=1)
            selector = self.coupler(selector)
        assert selector is not None

-        # Perform normalization on the selector if applicable.
-        if self.switch_norm:
-            selector = self.switch_norm(selector)
        # Apply dropout at the batch level per kernel.
        if self.training and self.dropout_rate > 0:
            b, c, h, w = selector.shape
@@ -192,11 +332,18 @@ class SwitchedConvHardRouting(nn.Module):
            drop = drop.logical_or(fix_blank)
            selector = drop * selector

        selector = self.gate(selector)

        # Debugging variables
        self.last_select = selector.detach().clone()
        self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,self.breadth,1,1) == selector).float().argmax(dim=1)

-        return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
        if False:
            # This is a custom CUDA implementation which should be faster and less memory intensive (once completed).
            return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
        else:
            # This composes the switching functionality using raw Torch, which basically consists of computing each of <breadth> convs separately and combining them.
            return SwitchedConvRoutingNormal(input, selector, self.weight, self.bias, self.stride)

# Given a state_dict and the module that that sd belongs to, strips out all Conv2d.weight parameters and replaces them
@@ -213,7 +360,11 @@ def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_
                continue
        if ignored:
            continue
-        state_dict[f'{name}.weight'] = state_dict[f'{name}.weight'].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
        if name == '':
            key = 'weight'
        else:
            key = f'{name}.weight'
        state_dict[key] = state_dict[key].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
    return state_dict
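For reference, a hedged usage sketch of the module this file adds. The constructor call mirrors the usage in the VQVAE changes below, and the import path matches the one used in trainer/losses.py later in this commit; the shape check and the assumption that this module imports cleanly (including the switched_conv_cuda_naive extension referenced at the top of the file, even though the raw-Torch path is currently selected) are mine.

import torch
from models.switched_conv_hard_routing import SwitchedConvHardRouting

# 8-way switched 3x3 conv whose selector is derived from the input itself via a 1x1 'standard' coupler.
conv = SwitchedConvHardRouting(64, 128, 3, 8, include_coupler=True, coupler_mode='standard',
                               coupler_dim_in=64, dropout_rate=0.4)
x = torch.randn(4, 64, 32, 32)
y = conv(x, None)     # passing selector=None is allowed because a coupler was configured
print(y.shape)        # expected: torch.Size([4, 128, 32, 32])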

View File

@@ -17,7 +17,7 @@ from utils.util import checkpoint, opt_get
class UpsampleConv(nn.Module):
    def __init__(self, in_filters, out_filters, breadth, kernel_size, padding):
        super().__init__()
-        self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_filters, dropout_rate=0.4)
        self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_filters, dropout_rate=0.4)

    def forward(self, x):
        up = torch.nn.functional.interpolate(x, scale_factor=2)
@@ -104,16 +104,16 @@ class Encoder(nn.Module):
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                nn.ReLU(inplace=True),
-                SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.4),
                SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4),
                nn.ReLU(inplace=True),
-                SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel, dropout_rate=0.4),
                SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel, dropout_rate=0.4),
            ]
        elif stride == 2:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                nn.ReLU(inplace=True),
-                SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.4),
                SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4),
            ]

        for i in range(n_res_block):
@@ -133,7 +133,7 @@ class Decoder(nn.Module):
    ):
        super().__init__()

-        blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel, dropout_rate=0.4)]
        blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_channel, dropout_rate=0.4)]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel, breadth))

View File

@@ -1,78 +1,49 @@
-import os.path as osp
-import logging
-import time
-import argparse
import os
import shutil
-import utils
-from trainer.ExtensibleTrainer import ExtensibleTrainer
-from trainer.networks import define_F
-from utils import options as option
-import utils.util as util
-from data import create_dataset, create_dataloader
from torch.utils.data import DataLoader
from data.single_image_dataset import SingleImageDataset
from tqdm import tqdm
import torch
from models.vqvae.vqvae_no_conv_transpose import VQVAE


if __name__ == "__main__":
    bin_path = "f:\\binned"
    good_path = "f:\\good"
    os.makedirs(bin_path, exist_ok=True)
    os.makedirs(good_path, exist_ok=True)

    torch.backends.cudnn.benchmark = True

-    parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/generator_filter.yml')
-    opt = option.parse(parser.parse_args().opt, is_train=False)
-    opt = option.dict_to_nonedict(opt)
-    opt['dist'] = False
-    util.mkdirs(
-        (path for key, path in opt['path'].items()
-         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
-    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
-                      screen=True, tofile=True)
-    logger = logging.getLogger('base')
-    logger.info(option.dict2str(opt))
-    #### Create test dataset and dataloader
-    test_loaders = []
-    for phase, dataset_opt in sorted(opt['datasets'].items()):
-        test_set = create_dataset(dataset_opt)
-        test_loader = create_dataloader(test_set, dataset_opt, opt=opt)
-        logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
-        test_loaders.append(test_loader)
-    model = ExtensibleTrainer(opt)
-    utils.util.loaded_options = opt
-    fea_loss = 0
-    for test_loader in test_loaders:
-        test_set_name = test_loader.dataset.opt['name']
-        logger.info('\nTesting [{:s}]...'.format(test_set_name))
-        test_start_time = time.time()
-        dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
-        util.mkdir(dataset_dir)
-        netF = define_F(which_model='vgg').to(model.env['device'])
-        tq = tqdm(test_loader)
-        removed = 0
-        means = []
-        for data in tq:
-            model.feed_data(data, need_GT=True)
-            model.test()
-            gen = model.eval_state['gen'][0].to(model.env['device'])
-            feagen = netF(gen)
-            feareal = netF(data['hq'].to(model.env['device']))
-            losses = torch.sum(torch.abs(feareal - feagen), dim=(1,2,3))
-            means.append(torch.mean(losses).item())
-            #print(sum(means)/len(means), torch.mean(losses), torch.max(losses), torch.min(losses))
-            for i in range(losses.shape[0]):
-                if losses[i] < 25000:
-                    os.remove(data['GT_path'][i])
-                    removed += 1
-                #imname = osp.basename(data['GT_path'][i])
-                #if losses[i] < 25000:
-                #    torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
-    print("Removed %i/%i images" % (removed, len(test_set)))
    model = VQVAE().cuda()
    model.load_state_dict(torch.load('../experiments/nvqvae_imgset.pth'))
    ds = SingleImageDataset({
        'name': 'amalgam',
        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\256_with_ref_v5'],
        'weights': [1],
        'target_size': 128,
        'force_multiple': 32,
        'scale': 1,
        'eval': False
    })
    dl = DataLoader(ds, batch_size=256, num_workers=1)

    means = []
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(tqdm(dl)):
            hq = data['hq'].cuda()
            gen = model(hq)[0]
            l2 = torch.mean(torch.square(hq - gen), dim=[1,2,3])
            for b in range(len(l2)):
                if l2[b] > .0004:
                    shutil.copy(data['GT_path'][b], good_path)
                #else:
                #    shutil.copy(data['GT_path'][b], bin_path)

            #means.append(l2.cpu())
            #if i % 10 == 0:
            #    print(torch.stack(means, dim=0).mean())

View File

@@ -47,6 +47,12 @@ def create_loss(opt_loss, env):
        return RecurrentLoss(opt_loss, env)
    elif type == 'for_element':
        return ForElementLoss(opt_loss, env)
    elif type == 'mixture_of_experts':
        from models.switched_conv_hard_routing import MixtureOfExpertsLoss
        return MixtureOfExpertsLoss(opt_loss, env)
    elif type == 'switch_transformer_balance':
        from models.switched_conv_hard_routing import SwitchTransformersLoadBalancingLoss
        return SwitchTransformersLoadBalancingLoss(opt_loss, env)
    else:
        raise NotImplementedError
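For intuition on the newly registered 'switch_transformer_balance' loss: per SwitchTransformersLoadBalancingLoss above, the penalty is the dot product between the mean hard assignment and the mean soft probability of each switch slot, scaled by the number of slots, so it sits near 1.0 under roughly uniform routing and grows when a few slots dominate. A self-contained sketch on dummy tensors (independent of the trainer classes; the shapes are illustrative):

import torch

breadth = 8
soft = torch.softmax(torch.randn(4, breadth, 16, 16), dim=1)   # normalized selector output
hard = torch.nn.functional.one_hot(soft.argmax(dim=1), breadth).permute(0, 3, 1, 2).float()  # top-1 routing

h_mean = hard.mean(dim=[0, 2, 3])   # fraction of positions routed to each slot
s_mean = soft.mean(dim=[0, 2, 3])   # mean probability mass given to each slot
balance = torch.dot(h_mean, s_mean) * breadth
print(balance)   # ~1.0 for near-uniform routing; larger when routing collapses onto a few slots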