Move switched_conv logic around a bit
This commit is contained in:
parent
0dca36946f
commit
320edbaa3c
|
@ -9,11 +9,10 @@ import torchvision
|
||||||
from torchvision.models.resnet import Bottleneck
|
from torchvision.models.resnet import Bottleneck
|
||||||
|
|
||||||
from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
|
from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
|
||||||
from models.pixel_level_contrastive_learning.resnet_unet_2 import UResNet50_2
|
|
||||||
from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
|
from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import checkpoint, sequential_checkpoint, opt_get
|
from utils.util import checkpoint, sequential_checkpoint, opt_get
|
||||||
from models.switched_conv import SwitchedConv
|
from models.switched_conv.switched_conv import SwitchedConv
|
||||||
|
|
||||||
|
|
||||||
class ResidualDenseBlock(nn.Module):
|
class ResidualDenseBlock(nn.Module):
|
||||||
|
|
0
codes/models/switched_conv/__init__.py
Normal file
0
codes/models/switched_conv/__init__.py
Normal file
128
codes/models/switched_conv/mixture_of_experts.py
Normal file
128
codes/models/switched_conv/mixture_of_experts.py
Normal file
|
@ -0,0 +1,128 @@
|
||||||
|
# Contains implementations from the Mixture of Experts paper and Switch Transformers
|
||||||
|
|
||||||
|
|
||||||
|
# Implements KeepTopK where k=1 from mixture of experts paper.
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from models.switched_conv.switched_conv_hard_routing import RouteTop1
|
||||||
|
from trainer.losses import ConfigurableLoss
|
||||||
|
|
||||||
|
|
||||||
|
class KeepTop1(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input):
|
||||||
|
mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2)
|
||||||
|
input[mask != 1] = -float('inf')
|
||||||
|
ctx.save_for_backward(mask)
|
||||||
|
return input
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad):
|
||||||
|
import pydevd
|
||||||
|
pydevd.settrace(suspend=False, trace_only_current_thread=True)
|
||||||
|
mask = ctx.saved_tensors
|
||||||
|
grad_input = grad.clone()
|
||||||
|
grad_input[mask != 1] = 0
|
||||||
|
return grad_input
|
||||||
|
|
||||||
|
class MixtureOfExperts2dRouter(nn.Module):
|
||||||
|
def __init__(self, num_experts):
|
||||||
|
super().__init__()
|
||||||
|
self.num_experts = num_experts
|
||||||
|
self.wnoise = nn.Parameter(torch.zeros(1, num_experts, 1, 1))
|
||||||
|
self.wg = nn.Parameter(torch.zeros(1, num_experts, 1, 1))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
wg = x * self.wg
|
||||||
|
wnoise = nn.functional.softplus(x * self.wnoise)
|
||||||
|
H = wg + torch.randn_like(x) * wnoise
|
||||||
|
|
||||||
|
# Produce the load-balancing loss.
|
||||||
|
eye = torch.eye(self.num_experts, device=x.device).view(1, self.num_experts, self.num_experts, 1, 1)
|
||||||
|
mask = torch.abs(1 - eye)
|
||||||
|
b, c, h, w = H.shape
|
||||||
|
ninf = torch.zeros_like(eye)
|
||||||
|
ninf[eye == 1] = -float('inf')
|
||||||
|
H_masked = H.view(b, c, 1, h,
|
||||||
|
w) * mask + ninf # ninf is necessary because otherwise torch.max() will not pick up negative numbered maxes.
|
||||||
|
max_excluding = torch.max(H_masked, dim=2)[0]
|
||||||
|
|
||||||
|
# load_loss and G are stored as local members to facilitate their use by hard routing regularization losses.
|
||||||
|
# this is a risky op - it can easily result in memory leakage. Clients *must* use self.reset() below.
|
||||||
|
self.load_loss = torch.erf((wg - max_excluding) / wnoise)
|
||||||
|
# self.G = nn.functional.softmax(KeepTop1.apply(H), dim=1) The paper proposes this equation, but performing a softmax on a Top-1 per the paper results in zero gradients into H, so:
|
||||||
|
self.G = RouteTop1.apply(nn.functional.softmax(H, dim=1)) # This variant can route gradients downstream.
|
||||||
|
|
||||||
|
return self.G
|
||||||
|
|
||||||
|
# Retrieve the locally stored loss values and delete them from membership (so as to not waste memory)
|
||||||
|
def reset(self):
|
||||||
|
G, load = self.G, self.load_loss
|
||||||
|
del self.G
|
||||||
|
del self.load_loss
|
||||||
|
return G, load
|
||||||
|
|
||||||
|
|
||||||
|
# Loss that finds instances of MixtureOfExperts2dRouter in the given network and extracts their custom losses.
|
||||||
|
class MixtureOfExpertsLoss(ConfigurableLoss):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super().__init__(opt, env)
|
||||||
|
self.routers = [] # This is filled in during the first forward() pass and cached from there.
|
||||||
|
self.first_forward_encountered = False
|
||||||
|
self.load_weight = opt['load_weight']
|
||||||
|
self.importance_weight = opt['importance_weight']
|
||||||
|
|
||||||
|
def forward(self, net, state):
|
||||||
|
if not self.first_forward_encountered:
|
||||||
|
for m in net.modules():
|
||||||
|
if isinstance(m, MixtureOfExperts2dRouter):
|
||||||
|
self.routers.append(m)
|
||||||
|
self.first_forward_encountered = True
|
||||||
|
|
||||||
|
l_importance = 0
|
||||||
|
l_load = 0
|
||||||
|
for r in self.routers:
|
||||||
|
G, L = r.reset()
|
||||||
|
l_importance += G.var().square()
|
||||||
|
l_load += L.var().square()
|
||||||
|
return l_importance * self.importance_weight + l_load * self.load_weight
|
||||||
|
|
||||||
|
|
||||||
|
class SwitchTransformersLoadBalancer(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = SwitchNorm(8, accumulator_size=256)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
self.soft = self.norm(nn.functional.softmax(x, dim=1))
|
||||||
|
self.hard = RouteTop1.apply(self.soft) # This variant can route gradients downstream.
|
||||||
|
return self.hard
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
soft, hard = self.soft, self.hard
|
||||||
|
del self.soft, self.hard
|
||||||
|
return soft, hard
|
||||||
|
|
||||||
|
|
||||||
|
class SwitchTransformersLoadBalancingLoss(ConfigurableLoss):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super().__init__(opt, env)
|
||||||
|
self.routers = [] # This is filled in during the first forward() pass and cached from there.
|
||||||
|
self.first_forward_encountered = False
|
||||||
|
|
||||||
|
def forward(self, net, state):
|
||||||
|
if not self.first_forward_encountered:
|
||||||
|
for m in net.modules():
|
||||||
|
if isinstance(m, SwitchTransformersLoadBalancer):
|
||||||
|
self.routers.append(m)
|
||||||
|
self.first_forward_encountered = True
|
||||||
|
|
||||||
|
loss = 0
|
||||||
|
for r in self.routers:
|
||||||
|
soft, hard = r.reset()
|
||||||
|
N = hard.shape[1]
|
||||||
|
h_mean = hard.mean(dim=[0,2,3])
|
||||||
|
s_mean = soft.mean(dim=[0,2,3])
|
||||||
|
loss += torch.dot(h_mean, s_mean) * N
|
||||||
|
return loss
|
|
@ -42,25 +42,6 @@ class SwitchedConvHardRoutingFunction(torch.autograd.Function):
|
||||||
return grad, grad_sel, grad_w, grad_b, None
|
return grad, grad_sel, grad_w, grad_b, None
|
||||||
|
|
||||||
|
|
||||||
# Implements KeepTopK where k=1 from mixture of experts paper.
|
|
||||||
class KeepTop1(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, input):
|
|
||||||
mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2)
|
|
||||||
input[mask != 1] = -float('inf')
|
|
||||||
ctx.save_for_backward(mask)
|
|
||||||
return input
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad):
|
|
||||||
import pydevd
|
|
||||||
pydevd.settrace(suspend=False, trace_only_current_thread=True)
|
|
||||||
mask = ctx.saved_tensors
|
|
||||||
grad_input = grad.clone()
|
|
||||||
grad_input[mask != 1] = 0
|
|
||||||
return grad_input
|
|
||||||
|
|
||||||
|
|
||||||
class RouteTop1(torch.autograd.Function):
|
class RouteTop1(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, input):
|
def forward(ctx, input):
|
||||||
|
@ -155,105 +136,15 @@ class SwitchNorm(nn.Module):
|
||||||
return x / x.sum(dim=1, keepdim=True)
|
return x / x.sum(dim=1, keepdim=True)
|
||||||
|
|
||||||
|
|
||||||
class MixtureOfExperts2dRouter(nn.Module):
|
class HardRoutingGate(nn.Module):
|
||||||
def __init__(self, num_experts):
|
def __init__(self, breadth):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_experts = num_experts
|
self.norm = SwitchNorm(breadth, accumulator_size=256)
|
||||||
self.wnoise = nn.Parameter(torch.zeros(1,num_experts,1,1))
|
|
||||||
self.wg = nn.Parameter(torch.zeros(1,num_experts,1,1))
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
wg = x * self.wg
|
soft = self.norm(nn.functional.softmax(x, dim=1))
|
||||||
wnoise = nn.functional.softplus(x * self.wnoise)
|
hard = RouteTop1.apply(soft) # This variant can route gradients downstream.
|
||||||
H = wg + torch.randn_like(x) * wnoise
|
return hard
|
||||||
|
|
||||||
# Produce the load-balancing loss.
|
|
||||||
eye = torch.eye(self.num_experts, device=x.device).view(1,self.num_experts,self.num_experts,1,1)
|
|
||||||
mask=torch.abs(1-eye)
|
|
||||||
b,c,h,w=H.shape
|
|
||||||
ninf = torch.zeros_like(eye)
|
|
||||||
ninf[eye==1] = -float('inf')
|
|
||||||
H_masked=H.view(b,c,1,h,w)*mask+ninf # ninf is necessary because otherwise torch.max() will not pick up negative numbered maxes.
|
|
||||||
max_excluding=torch.max(H_masked,dim=2)[0]
|
|
||||||
|
|
||||||
# load_loss and G are stored as local members to facilitate their use by hard routing regularization losses.
|
|
||||||
# this is a risky op - it can easily result in memory leakage. Clients *must* use self.reset() below.
|
|
||||||
self.load_loss = torch.erf((wg - max_excluding)/wnoise)
|
|
||||||
#self.G = nn.functional.softmax(KeepTop1.apply(H), dim=1) The paper proposes this equation, but performing a softmax on a Top-1 per the paper results in zero gradients into H, so:
|
|
||||||
self.G = RouteTop1.apply(nn.functional.softmax(H, dim=1)) # This variant can route gradients downstream.
|
|
||||||
|
|
||||||
return self.G
|
|
||||||
|
|
||||||
# Retrieve the locally stored loss values and delete them from membership (so as to not waste memory)
|
|
||||||
def reset(self):
|
|
||||||
G, load = self.G, self.load_loss
|
|
||||||
del self.G
|
|
||||||
del self.load_loss
|
|
||||||
return G, load
|
|
||||||
|
|
||||||
|
|
||||||
# Loss that finds instances of MixtureOfExperts2dRouter in the given network and extracts their custom losses.
|
|
||||||
class MixtureOfExpertsLoss(ConfigurableLoss):
|
|
||||||
def __init__(self, opt, env):
|
|
||||||
super().__init__(opt, env)
|
|
||||||
self.routers = [] # This is filled in during the first forward() pass and cached from there.
|
|
||||||
self.first_forward_encountered = False
|
|
||||||
self.load_weight = opt['load_weight']
|
|
||||||
self.importance_weight = opt['importance_weight']
|
|
||||||
|
|
||||||
def forward(self, net, state):
|
|
||||||
if not self.first_forward_encountered:
|
|
||||||
for m in net.modules():
|
|
||||||
if isinstance(m, MixtureOfExperts2dRouter):
|
|
||||||
self.routers.append(m)
|
|
||||||
self.first_forward_encountered = True
|
|
||||||
|
|
||||||
l_importance = 0
|
|
||||||
l_load = 0
|
|
||||||
for r in self.routers:
|
|
||||||
G, L = r.reset()
|
|
||||||
l_importance += G.var().square()
|
|
||||||
l_load += L.var().square()
|
|
||||||
return l_importance * self.importance_weight + l_load * self.load_weight
|
|
||||||
|
|
||||||
|
|
||||||
class SwitchTransformersLoadBalancer(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.norm = SwitchNorm(8, accumulator_size=256)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
self.soft = self.norm(nn.functional.softmax(x, dim=1))
|
|
||||||
self.hard = RouteTop1.apply(self.soft) # This variant can route gradients downstream.
|
|
||||||
return self.hard
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
soft, hard = self.soft, self.hard
|
|
||||||
del self.soft, self.hard
|
|
||||||
return soft, hard
|
|
||||||
|
|
||||||
|
|
||||||
class SwitchTransformersLoadBalancingLoss(ConfigurableLoss):
|
|
||||||
def __init__(self, opt, env):
|
|
||||||
super().__init__(opt, env)
|
|
||||||
self.routers = [] # This is filled in during the first forward() pass and cached from there.
|
|
||||||
self.first_forward_encountered = False
|
|
||||||
|
|
||||||
def forward(self, net, state):
|
|
||||||
if not self.first_forward_encountered:
|
|
||||||
for m in net.modules():
|
|
||||||
if isinstance(m, SwitchTransformersLoadBalancer):
|
|
||||||
self.routers.append(m)
|
|
||||||
self.first_forward_encountered = True
|
|
||||||
|
|
||||||
loss = 0
|
|
||||||
for r in self.routers:
|
|
||||||
soft, hard = r.reset()
|
|
||||||
N = hard.shape[1]
|
|
||||||
h_mean = hard.mean(dim=[0,2,3])
|
|
||||||
s_mean = soft.mean(dim=[0,2,3])
|
|
||||||
loss += torch.dot(h_mean, s_mean) * N
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
class SwitchedConvHardRouting(nn.Module):
|
class SwitchedConvHardRouting(nn.Module):
|
||||||
|
@ -290,8 +181,7 @@ class SwitchedConvHardRouting(nn.Module):
|
||||||
Conv2d(breadth, breadth, 1, stride=self.stride))
|
Conv2d(breadth, breadth, 1, stride=self.stride))
|
||||||
else:
|
else:
|
||||||
self.coupler = None
|
self.coupler = None
|
||||||
#self.gate = MixtureOfExperts2dRouter(breadth)
|
self.gate = HardRoutingGate(breadth)
|
||||||
self.gate = SwitchTransformersLoadBalancer()
|
|
||||||
|
|
||||||
self.weight = nn.Parameter(torch.empty(out_c, in_c, breadth, kernel_sz, kernel_sz))
|
self.weight = nn.Parameter(torch.empty(out_c, in_c, breadth, kernel_sz, kernel_sz))
|
||||||
if bias:
|
if bias:
|
|
@ -7,7 +7,7 @@ from torch.nn import functional as F
|
||||||
|
|
||||||
import torch.distributed as distributed
|
import torch.distributed as distributed
|
||||||
|
|
||||||
from models.switched_conv_hard_routing import SwitchedConvHardRouting, \
|
from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \
|
||||||
convert_conv_net_state_dict_to_switched_conv
|
convert_conv_net_state_dict_to_switched_conv
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import checkpoint, opt_get
|
from utils.util import checkpoint, opt_get
|
||||||
|
|
|
@ -7,7 +7,7 @@ from torch.nn import functional as F
|
||||||
|
|
||||||
import torch.distributed as distributed
|
import torch.distributed as distributed
|
||||||
|
|
||||||
from models.switched_conv import SwitchedConv, convert_conv_net_state_dict_to_switched_conv
|
from models.switched_conv.switched_conv import SwitchedConv, convert_conv_net_state_dict_to_switched_conv
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import checkpoint, opt_get
|
from utils.util import checkpoint, opt_get
|
||||||
|
|
||||||
|
|
|
@ -48,10 +48,10 @@ def create_loss(opt_loss, env):
|
||||||
elif type == 'for_element':
|
elif type == 'for_element':
|
||||||
return ForElementLoss(opt_loss, env)
|
return ForElementLoss(opt_loss, env)
|
||||||
elif type == 'mixture_of_experts':
|
elif type == 'mixture_of_experts':
|
||||||
from models.switched_conv_hard_routing import MixtureOfExpertsLoss
|
from models.switched_conv.mixture_of_experts import MixtureOfExpertsLoss
|
||||||
return MixtureOfExpertsLoss(opt_loss, env)
|
return MixtureOfExpertsLoss(opt_loss, env)
|
||||||
elif type == 'switch_transformer_balance':
|
elif type == 'switch_transformer_balance':
|
||||||
from models.switched_conv_hard_routing import SwitchTransformersLoadBalancingLoss
|
from models.switched_conv.mixture_of_experts import SwitchTransformersLoadBalancingLoss
|
||||||
return SwitchTransformersLoadBalancingLoss(opt_loss, env)
|
return SwitchTransformersLoadBalancingLoss(opt_loss, env)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
Loading…
Reference in New Issue
Block a user