Hard Routing mods
- Turns out my custom convolution was RIDDLED with backwards bugs, which is why the existing implementation wasn't working so well. - Implements the switch logic from both Mixture of Experts and Switch Transformers for testing purposes.
This commit is contained in:
parent
29c1c3bede
commit
0dca36946f
|
@ -9,35 +9,81 @@ import torch.nn.functional as F
|
|||
from tqdm import tqdm
|
||||
import torch.distributed as dist
|
||||
|
||||
from trainer.losses import ConfigurableLoss
|
||||
|
||||
|
||||
def SwitchedConvRoutingNormal(input, selector, weight, bias, stride=1):
|
||||
convs = []
|
||||
b, s, h, w = selector.shape
|
||||
for sel in range(s):
|
||||
convs.append(F.conv2d(input, weight[:, :, sel, :, :], bias, stride=stride, padding=weight.shape[-1] // 2))
|
||||
output = torch.stack(convs, dim=1) * selector.unsqueeze(dim=2)
|
||||
return output.sum(dim=1)
|
||||
|
||||
|
||||
class SwitchedConvHardRoutingFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, selector, weight, bias, stride=1):
|
||||
# Build hard attention mask from selector input
|
||||
b, s, h, w = selector.shape
|
||||
selector_mask = (selector.max(dim=1, keepdim=True)[0].repeat(1,s,1,1) == selector).float()
|
||||
mask = selector_mask.argmax(dim=1).int()
|
||||
|
||||
# Compute the convolution using the mask.
|
||||
outputs = switched_conv_cuda_naive.forward(input, mask, weight, bias, stride)
|
||||
mask = selector.argmax(dim=1).int()
|
||||
output = switched_conv_cuda_naive.forward(input, mask, weight, bias, stride)
|
||||
|
||||
ctx.stride = stride
|
||||
ctx.breadth = s
|
||||
ctx.save_for_backward(*[input, mask, weight, bias])
|
||||
return outputs
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
input, mask, weight, bias = ctx.saved_tensors
|
||||
|
||||
# Get the grads for the convolution.
|
||||
grad, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)
|
||||
|
||||
# Get the selector grads
|
||||
selector_mask = torch.eye(ctx.breadth, device=input.device)[mask.long()].permute(0,3,1,2).unsqueeze(2) # Note that this is not necessarily equivalent to the selector_mask from above, because under certain circumstances, two values could take on the value '1' in the above instance, whereas this is a true one-hot representation.
|
||||
grad_sel = ((grad * input).unsqueeze(1) * selector_mask).sum(2)
|
||||
grad, grad_sel, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)
|
||||
return grad, grad_sel, grad_w, grad_b, None
|
||||
|
||||
|
||||
# Implements KeepTopK where k=1 from mixture of experts paper.
|
||||
class KeepTop1(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2)
|
||||
input[mask != 1] = -float('inf')
|
||||
ctx.save_for_backward(mask)
|
||||
return input
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
import pydevd
|
||||
pydevd.settrace(suspend=False, trace_only_current_thread=True)
|
||||
mask = ctx.saved_tensors
|
||||
grad_input = grad.clone()
|
||||
grad_input[mask != 1] = 0
|
||||
return grad_input
|
||||
|
||||
|
||||
class RouteTop1(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2)
|
||||
out = torch.ones_like(input)
|
||||
out[mask != 1] = 0
|
||||
ctx.save_for_backward(mask, input.clone())
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
# Enable breakpoints in this function: (Comment out if not debugging)
|
||||
#import pydevd
|
||||
#pydevd.settrace(suspend=False, trace_only_current_thread=True)
|
||||
|
||||
mask, input = ctx.saved_tensors
|
||||
input[mask != 1] = 1
|
||||
grad_input = grad.clone()
|
||||
grad_input[mask != 1] = 0
|
||||
grad_input_n = grad_input / input # Above, we made everything either a zero or a one. Unscale the ones by dividing by the unmasked inputs.
|
||||
return grad_input_n
|
||||
|
||||
|
||||
"""
|
||||
SwitchNorm is meant to be applied against the Softmax output of an switching function across a large set of
|
||||
switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude
|
||||
|
@ -109,6 +155,107 @@ class SwitchNorm(nn.Module):
|
|||
return x / x.sum(dim=1, keepdim=True)
|
||||
|
||||
|
||||
class MixtureOfExperts2dRouter(nn.Module):
|
||||
def __init__(self, num_experts):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.wnoise = nn.Parameter(torch.zeros(1,num_experts,1,1))
|
||||
self.wg = nn.Parameter(torch.zeros(1,num_experts,1,1))
|
||||
|
||||
def forward(self, x):
|
||||
wg = x * self.wg
|
||||
wnoise = nn.functional.softplus(x * self.wnoise)
|
||||
H = wg + torch.randn_like(x) * wnoise
|
||||
|
||||
# Produce the load-balancing loss.
|
||||
eye = torch.eye(self.num_experts, device=x.device).view(1,self.num_experts,self.num_experts,1,1)
|
||||
mask=torch.abs(1-eye)
|
||||
b,c,h,w=H.shape
|
||||
ninf = torch.zeros_like(eye)
|
||||
ninf[eye==1] = -float('inf')
|
||||
H_masked=H.view(b,c,1,h,w)*mask+ninf # ninf is necessary because otherwise torch.max() will not pick up negative numbered maxes.
|
||||
max_excluding=torch.max(H_masked,dim=2)[0]
|
||||
|
||||
# load_loss and G are stored as local members to facilitate their use by hard routing regularization losses.
|
||||
# this is a risky op - it can easily result in memory leakage. Clients *must* use self.reset() below.
|
||||
self.load_loss = torch.erf((wg - max_excluding)/wnoise)
|
||||
#self.G = nn.functional.softmax(KeepTop1.apply(H), dim=1) The paper proposes this equation, but performing a softmax on a Top-1 per the paper results in zero gradients into H, so:
|
||||
self.G = RouteTop1.apply(nn.functional.softmax(H, dim=1)) # This variant can route gradients downstream.
|
||||
|
||||
return self.G
|
||||
|
||||
# Retrieve the locally stored loss values and delete them from membership (so as to not waste memory)
|
||||
def reset(self):
|
||||
G, load = self.G, self.load_loss
|
||||
del self.G
|
||||
del self.load_loss
|
||||
return G, load
|
||||
|
||||
|
||||
# Loss that finds instances of MixtureOfExperts2dRouter in the given network and extracts their custom losses.
|
||||
class MixtureOfExpertsLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.routers = [] # This is filled in during the first forward() pass and cached from there.
|
||||
self.first_forward_encountered = False
|
||||
self.load_weight = opt['load_weight']
|
||||
self.importance_weight = opt['importance_weight']
|
||||
|
||||
def forward(self, net, state):
|
||||
if not self.first_forward_encountered:
|
||||
for m in net.modules():
|
||||
if isinstance(m, MixtureOfExperts2dRouter):
|
||||
self.routers.append(m)
|
||||
self.first_forward_encountered = True
|
||||
|
||||
l_importance = 0
|
||||
l_load = 0
|
||||
for r in self.routers:
|
||||
G, L = r.reset()
|
||||
l_importance += G.var().square()
|
||||
l_load += L.var().square()
|
||||
return l_importance * self.importance_weight + l_load * self.load_weight
|
||||
|
||||
|
||||
class SwitchTransformersLoadBalancer(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.norm = SwitchNorm(8, accumulator_size=256)
|
||||
|
||||
def forward(self, x):
|
||||
self.soft = self.norm(nn.functional.softmax(x, dim=1))
|
||||
self.hard = RouteTop1.apply(self.soft) # This variant can route gradients downstream.
|
||||
return self.hard
|
||||
|
||||
def reset(self):
|
||||
soft, hard = self.soft, self.hard
|
||||
del self.soft, self.hard
|
||||
return soft, hard
|
||||
|
||||
|
||||
class SwitchTransformersLoadBalancingLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.routers = [] # This is filled in during the first forward() pass and cached from there.
|
||||
self.first_forward_encountered = False
|
||||
|
||||
def forward(self, net, state):
|
||||
if not self.first_forward_encountered:
|
||||
for m in net.modules():
|
||||
if isinstance(m, SwitchTransformersLoadBalancer):
|
||||
self.routers.append(m)
|
||||
self.first_forward_encountered = True
|
||||
|
||||
loss = 0
|
||||
for r in self.routers:
|
||||
soft, hard = r.reset()
|
||||
N = hard.shape[1]
|
||||
h_mean = hard.mean(dim=[0,2,3])
|
||||
s_mean = soft.mean(dim=[0,2,3])
|
||||
loss += torch.dot(h_mean, s_mean) * N
|
||||
return loss
|
||||
|
||||
|
||||
class SwitchedConvHardRouting(nn.Module):
|
||||
def __init__(self,
|
||||
in_c,
|
||||
|
@ -120,8 +267,7 @@ class SwitchedConvHardRouting(nn.Module):
|
|||
dropout_rate=0.0,
|
||||
include_coupler: bool = False, # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate.
|
||||
coupler_mode: str = 'standard',
|
||||
coupler_dim_in: int = 0,
|
||||
switch_norm: bool = True):
|
||||
coupler_dim_in: int = 0):
|
||||
super().__init__()
|
||||
self.in_channels = in_c
|
||||
self.out_channels = out_c
|
||||
|
@ -130,14 +276,10 @@ class SwitchedConvHardRouting(nn.Module):
|
|||
self.has_bias = bias
|
||||
self.breadth = breadth
|
||||
self.dropout_rate = dropout_rate
|
||||
if switch_norm:
|
||||
self.switch_norm = SwitchNorm(breadth, accumulator_size=512)
|
||||
else:
|
||||
self.switch_norm = None
|
||||
|
||||
if include_coupler:
|
||||
if coupler_mode == 'standard':
|
||||
self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1)
|
||||
self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
|
||||
elif coupler_mode == 'lambda':
|
||||
self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
|
||||
nn.BatchNorm2d(coupler_dim_in),
|
||||
|
@ -145,9 +287,11 @@ class SwitchedConvHardRouting(nn.Module):
|
|||
LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
|
||||
nn.BatchNorm2d(breadth),
|
||||
nn.ReLU(),
|
||||
Conv2d(breadth, breadth, 1))
|
||||
Conv2d(breadth, breadth, 1, stride=self.stride))
|
||||
else:
|
||||
self.coupler = None
|
||||
#self.gate = MixtureOfExperts2dRouter(breadth)
|
||||
self.gate = SwitchTransformersLoadBalancer()
|
||||
|
||||
self.weight = nn.Parameter(torch.empty(out_c, in_c, breadth, kernel_sz, kernel_sz))
|
||||
if bias:
|
||||
|
@ -175,14 +319,10 @@ class SwitchedConvHardRouting(nn.Module):
|
|||
# If a coupler was specified, run that to convert selector into a softmax distribution.
|
||||
if self.coupler:
|
||||
if selector is None: # A coupler can convert from any input to a selector, so 'None' is allowed.
|
||||
selector = input.detach()
|
||||
selector = F.softmax(self.coupler(selector), dim=1)
|
||||
selector = input
|
||||
selector = self.coupler(selector)
|
||||
assert selector is not None
|
||||
|
||||
# Perform normalization on the selector if applicable.
|
||||
if self.switch_norm:
|
||||
selector = self.switch_norm(selector)
|
||||
|
||||
# Apply dropout at the batch level per kernel.
|
||||
if self.training and self.dropout_rate > 0:
|
||||
b, c, h, w = selector.shape
|
||||
|
@ -192,11 +332,18 @@ class SwitchedConvHardRouting(nn.Module):
|
|||
drop = drop.logical_or(fix_blank)
|
||||
selector = drop * selector
|
||||
|
||||
selector = self.gate(selector)
|
||||
|
||||
# Debugging variables
|
||||
self.last_select = selector.detach().clone()
|
||||
self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,self.breadth,1,1) == selector).float().argmax(dim=1)
|
||||
|
||||
return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
|
||||
if False:
|
||||
# This is a custom CUDA implementation which should be faster and less memory intensive (once completed).
|
||||
return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
|
||||
else:
|
||||
# This composes the switching functionality using raw Torch, which basically consists of computing each of <breadth> convs separately and combining them.
|
||||
return SwitchedConvRoutingNormal(input, selector, self.weight, self.bias, self.stride)
|
||||
|
||||
|
||||
# Given a state_dict and the module that that sd belongs to, strips out all Conv2d.weight parameters and replaces them
|
||||
|
@ -213,7 +360,11 @@ def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_
|
|||
continue
|
||||
if ignored:
|
||||
continue
|
||||
state_dict[f'{name}.weight'] = state_dict[f'{name}.weight'].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
|
||||
if name == '':
|
||||
key = 'weight'
|
||||
else:
|
||||
key = f'{name}.weight'
|
||||
state_dict[key] = state_dict[key].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
|
||||
return state_dict
|
||||
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ from utils.util import checkpoint, opt_get
|
|||
class UpsampleConv(nn.Module):
|
||||
def __init__(self, in_filters, out_filters, breadth, kernel_size, padding):
|
||||
super().__init__()
|
||||
self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_filters, dropout_rate=0.4)
|
||||
self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_filters, dropout_rate=0.4)
|
||||
|
||||
def forward(self, x):
|
||||
up = torch.nn.functional.interpolate(x, scale_factor=2)
|
||||
|
@ -104,16 +104,16 @@ class Encoder(nn.Module):
|
|||
blocks = [
|
||||
nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
|
||||
nn.ReLU(inplace=True),
|
||||
SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.4),
|
||||
SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4),
|
||||
nn.ReLU(inplace=True),
|
||||
SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel, dropout_rate=0.4),
|
||||
SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel, dropout_rate=0.4),
|
||||
]
|
||||
|
||||
elif stride == 2:
|
||||
blocks = [
|
||||
nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
|
||||
nn.ReLU(inplace=True),
|
||||
SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.4),
|
||||
SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4),
|
||||
]
|
||||
|
||||
for i in range(n_res_block):
|
||||
|
@ -133,7 +133,7 @@ class Decoder(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
|
||||
blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel, dropout_rate=0.4)]
|
||||
blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_channel, dropout_rate=0.4)]
|
||||
|
||||
for i in range(n_res_block):
|
||||
blocks.append(ResBlock(channel, n_res_channel, breadth))
|
||||
|
|
|
@ -1,78 +1,49 @@
|
|||
import os.path as osp
|
||||
import logging
|
||||
import time
|
||||
import argparse
|
||||
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import utils
|
||||
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||
from trainer.networks import define_F
|
||||
from utils import options as option
|
||||
import utils.util as util
|
||||
from data import create_dataset, create_dataloader
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from data.single_image_dataset import SingleImageDataset
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
from models.vqvae.vqvae_no_conv_transpose import VQVAE
|
||||
|
||||
if __name__ == "__main__":
|
||||
bin_path = "f:\\binned"
|
||||
good_path = "f:\\good"
|
||||
os.makedirs(bin_path, exist_ok=True)
|
||||
os.makedirs(good_path, exist_ok=True)
|
||||
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/generator_filter.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
opt['dist'] = False
|
||||
|
||||
util.mkdirs(
|
||||
(path for key, path in opt['path'].items()
|
||||
if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
|
||||
util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
|
||||
screen=True, tofile=True)
|
||||
logger = logging.getLogger('base')
|
||||
logger.info(option.dict2str(opt))
|
||||
model = VQVAE().cuda()
|
||||
model.load_state_dict(torch.load('../experiments/nvqvae_imgset.pth'))
|
||||
ds = SingleImageDataset({
|
||||
'name': 'amalgam',
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\256_with_ref_v5'],
|
||||
'weights': [1],
|
||||
'target_size': 128,
|
||||
'force_multiple': 32,
|
||||
'scale': 1,
|
||||
'eval': False
|
||||
})
|
||||
dl = DataLoader(ds, batch_size=256, num_workers=1)
|
||||
|
||||
#### Create test dataset and dataloader
|
||||
test_loaders = []
|
||||
for phase, dataset_opt in sorted(opt['datasets'].items()):
|
||||
test_set = create_dataset(dataset_opt)
|
||||
test_loader = create_dataloader(test_set, dataset_opt, opt=opt)
|
||||
logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
|
||||
test_loaders.append(test_loader)
|
||||
means = []
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for i, data in enumerate(tqdm(dl)):
|
||||
hq = data['hq'].cuda()
|
||||
gen = model(hq)[0]
|
||||
l2 = torch.mean(torch.square(hq - gen), dim=[1,2,3])
|
||||
for b in range(len(l2)):
|
||||
if l2[b] > .0004:
|
||||
shutil.copy(data['GT_path'][b], good_path)
|
||||
#else:
|
||||
# shutil.copy(data['GT_path'][b], bin_path)
|
||||
|
||||
model = ExtensibleTrainer(opt)
|
||||
utils.util.loaded_options = opt
|
||||
fea_loss = 0
|
||||
for test_loader in test_loaders:
|
||||
test_set_name = test_loader.dataset.opt['name']
|
||||
logger.info('\nTesting [{:s}]...'.format(test_set_name))
|
||||
test_start_time = time.time()
|
||||
dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
|
||||
util.mkdir(dataset_dir)
|
||||
netF = define_F(which_model='vgg').to(model.env['device'])
|
||||
|
||||
tq = tqdm(test_loader)
|
||||
removed = 0
|
||||
means = []
|
||||
for data in tq:
|
||||
model.feed_data(data, need_GT=True)
|
||||
model.test()
|
||||
gen = model.eval_state['gen'][0].to(model.env['device'])
|
||||
feagen = netF(gen)
|
||||
feareal = netF(data['hq'].to(model.env['device']))
|
||||
losses = torch.sum(torch.abs(feareal - feagen), dim=(1,2,3))
|
||||
means.append(torch.mean(losses).item())
|
||||
#print(sum(means)/len(means), torch.mean(losses), torch.max(losses), torch.min(losses))
|
||||
for i in range(losses.shape[0]):
|
||||
if losses[i] < 25000:
|
||||
os.remove(data['GT_path'][i])
|
||||
removed += 1
|
||||
#imname = osp.basename(data['GT_path'][i])
|
||||
#if losses[i] < 25000:
|
||||
# torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
|
||||
|
||||
print("Removed %i/%i images" % (removed, len(test_set)))
|
||||
#means.append(l2.cpu())
|
||||
#if i % 10 == 0:
|
||||
# print(torch.stack(means, dim=0).mean())
|
||||
|
|
|
@ -47,6 +47,12 @@ def create_loss(opt_loss, env):
|
|||
return RecurrentLoss(opt_loss, env)
|
||||
elif type == 'for_element':
|
||||
return ForElementLoss(opt_loss, env)
|
||||
elif type == 'mixture_of_experts':
|
||||
from models.switched_conv_hard_routing import MixtureOfExpertsLoss
|
||||
return MixtureOfExpertsLoss(opt_loss, env)
|
||||
elif type == 'switch_transformer_balance':
|
||||
from models.switched_conv_hard_routing import SwitchTransformersLoadBalancingLoss
|
||||
return SwitchTransformersLoadBalancingLoss(opt_loss, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user