diff --git a/codes/models/RRDBNet_arch.py b/codes/models/RRDBNet_arch.py
index faeec8b7..a70572f0 100644
--- a/codes/models/RRDBNet_arch.py
+++ b/codes/models/RRDBNet_arch.py
@@ -9,11 +9,10 @@ import torchvision
 from torchvision.models.resnet import Bottleneck
 
 from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
-from models.pixel_level_contrastive_learning.resnet_unet_2 import UResNet50_2
 from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
 from trainer.networks import register_model
 from utils.util import checkpoint, sequential_checkpoint, opt_get
-from models.switched_conv import SwitchedConv
+from models.switched_conv.switched_conv import SwitchedConv
 
 
 class ResidualDenseBlock(nn.Module):
diff --git a/codes/models/switched_conv/__init__.py b/codes/models/switched_conv/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/codes/models/switched_conv/mixture_of_experts.py b/codes/models/switched_conv/mixture_of_experts.py
new file mode 100644
index 00000000..bf888a4f
--- /dev/null
+++ b/codes/models/switched_conv/mixture_of_experts.py
@@ -0,0 +1,128 @@
+# Contains implementations from the Mixture of Experts paper and Switch Transformers
+
+
+import torch
+import torch.nn as nn
+
+from models.switched_conv.switched_conv_hard_routing import RouteTop1
+from trainer.losses import ConfigurableLoss
+
+
+# Implements KeepTopK where k=1 from the Mixture of Experts paper.
+class KeepTop1(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        # Keep only the per-pixel argmax channel; every other channel is forced to -inf.
+        mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2)
+        input[mask != 1] = -float('inf')
+        ctx.save_for_backward(mask)
+        return input
+
+    @staticmethod
+    def backward(ctx, grad):
+        # Only propagate gradient through the entries that survived forward().
+        mask, = ctx.saved_tensors
+        grad_input = grad.clone()
+        grad_input[mask != 1] = 0
+        return grad_input
+
+class MixtureOfExperts2dRouter(nn.Module):
+    def __init__(self, num_experts):
+        super().__init__()
+        self.num_experts = num_experts
+        self.wnoise = nn.Parameter(torch.zeros(1, num_experts, 1, 1))
+        self.wg = nn.Parameter(torch.zeros(1, num_experts, 1, 1))
+
+    def forward(self, x):
+        wg = x * self.wg
+        wnoise = nn.functional.softplus(x * self.wnoise)
+        H = wg + torch.randn_like(x) * wnoise
+
+        # Produce the load-balancing loss.
+        eye = torch.eye(self.num_experts, device=x.device).view(1, self.num_experts, self.num_experts, 1, 1)
+        mask = torch.abs(1 - eye)
+        b, c, h, w = H.shape
+        ninf = torch.zeros_like(eye)
+        ninf[eye == 1] = -float('inf')
+        H_masked = H.view(b, c, 1, h,
+                          w) * mask + ninf  # ninf is necessary because otherwise torch.max() will not pick up negative-valued maxes.
+        max_excluding = torch.max(H_masked, dim=2)[0]
+
+        # load_loss and G are stored as local members to facilitate their use by hard routing regularization losses.
+        # This is a risky op - it can easily leak memory. Clients *must* call self.reset() below.
+        self.load_loss = torch.erf((wg - max_excluding) / wnoise)
+        # self.G = nn.functional.softmax(KeepTop1.apply(H), dim=1)  The paper proposes this equation, but a softmax over the Top-1 output yields zero gradients into H, so:
+        self.G = RouteTop1.apply(nn.functional.softmax(H, dim=1))  # This variant can route gradients downstream.
+ + return self.G + + # Retrieve the locally stored loss values and delete them from membership (so as to not waste memory) + def reset(self): + G, load = self.G, self.load_loss + del self.G + del self.load_loss + return G, load + + +# Loss that finds instances of MixtureOfExperts2dRouter in the given network and extracts their custom losses. +class MixtureOfExpertsLoss(ConfigurableLoss): + def __init__(self, opt, env): + super().__init__(opt, env) + self.routers = [] # This is filled in during the first forward() pass and cached from there. + self.first_forward_encountered = False + self.load_weight = opt['load_weight'] + self.importance_weight = opt['importance_weight'] + + def forward(self, net, state): + if not self.first_forward_encountered: + for m in net.modules(): + if isinstance(m, MixtureOfExperts2dRouter): + self.routers.append(m) + self.first_forward_encountered = True + + l_importance = 0 + l_load = 0 + for r in self.routers: + G, L = r.reset() + l_importance += G.var().square() + l_load += L.var().square() + return l_importance * self.importance_weight + l_load * self.load_weight + + +class SwitchTransformersLoadBalancer(nn.Module): + def __init__(self): + super().__init__() + self.norm = SwitchNorm(8, accumulator_size=256) + + def forward(self, x): + self.soft = self.norm(nn.functional.softmax(x, dim=1)) + self.hard = RouteTop1.apply(self.soft) # This variant can route gradients downstream. + return self.hard + + def reset(self): + soft, hard = self.soft, self.hard + del self.soft, self.hard + return soft, hard + + +class SwitchTransformersLoadBalancingLoss(ConfigurableLoss): + def __init__(self, opt, env): + super().__init__(opt, env) + self.routers = [] # This is filled in during the first forward() pass and cached from there. + self.first_forward_encountered = False + + def forward(self, net, state): + if not self.first_forward_encountered: + for m in net.modules(): + if isinstance(m, SwitchTransformersLoadBalancer): + self.routers.append(m) + self.first_forward_encountered = True + + loss = 0 + for r in self.routers: + soft, hard = r.reset() + N = hard.shape[1] + h_mean = hard.mean(dim=[0,2,3]) + s_mean = soft.mean(dim=[0,2,3]) + loss += torch.dot(h_mean, s_mean) * N + return loss \ No newline at end of file diff --git a/codes/models/switched_conv.py b/codes/models/switched_conv/switched_conv.py similarity index 100% rename from codes/models/switched_conv.py rename to codes/models/switched_conv/switched_conv.py diff --git a/codes/models/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py similarity index 71% rename from codes/models/switched_conv_hard_routing.py rename to codes/models/switched_conv/switched_conv_hard_routing.py index 2e7ea9a6..4cc9d03b 100644 --- a/codes/models/switched_conv_hard_routing.py +++ b/codes/models/switched_conv/switched_conv_hard_routing.py @@ -42,25 +42,6 @@ class SwitchedConvHardRoutingFunction(torch.autograd.Function): return grad, grad_sel, grad_w, grad_b, None -# Implements KeepTopK where k=1 from mixture of experts paper. 
-class KeepTop1(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2) - input[mask != 1] = -float('inf') - ctx.save_for_backward(mask) - return input - - @staticmethod - def backward(ctx, grad): - import pydevd - pydevd.settrace(suspend=False, trace_only_current_thread=True) - mask = ctx.saved_tensors - grad_input = grad.clone() - grad_input[mask != 1] = 0 - return grad_input - - class RouteTop1(torch.autograd.Function): @staticmethod def forward(ctx, input): @@ -155,105 +136,15 @@ class SwitchNorm(nn.Module): return x / x.sum(dim=1, keepdim=True) -class MixtureOfExperts2dRouter(nn.Module): - def __init__(self, num_experts): +class HardRoutingGate(nn.Module): + def __init__(self, breadth): super().__init__() - self.num_experts = num_experts - self.wnoise = nn.Parameter(torch.zeros(1,num_experts,1,1)) - self.wg = nn.Parameter(torch.zeros(1,num_experts,1,1)) + self.norm = SwitchNorm(breadth, accumulator_size=256) def forward(self, x): - wg = x * self.wg - wnoise = nn.functional.softplus(x * self.wnoise) - H = wg + torch.randn_like(x) * wnoise - - # Produce the load-balancing loss. - eye = torch.eye(self.num_experts, device=x.device).view(1,self.num_experts,self.num_experts,1,1) - mask=torch.abs(1-eye) - b,c,h,w=H.shape - ninf = torch.zeros_like(eye) - ninf[eye==1] = -float('inf') - H_masked=H.view(b,c,1,h,w)*mask+ninf # ninf is necessary because otherwise torch.max() will not pick up negative numbered maxes. - max_excluding=torch.max(H_masked,dim=2)[0] - - # load_loss and G are stored as local members to facilitate their use by hard routing regularization losses. - # this is a risky op - it can easily result in memory leakage. Clients *must* use self.reset() below. - self.load_loss = torch.erf((wg - max_excluding)/wnoise) - #self.G = nn.functional.softmax(KeepTop1.apply(H), dim=1) The paper proposes this equation, but performing a softmax on a Top-1 per the paper results in zero gradients into H, so: - self.G = RouteTop1.apply(nn.functional.softmax(H, dim=1)) # This variant can route gradients downstream. - - return self.G - - # Retrieve the locally stored loss values and delete them from membership (so as to not waste memory) - def reset(self): - G, load = self.G, self.load_loss - del self.G - del self.load_loss - return G, load - - -# Loss that finds instances of MixtureOfExperts2dRouter in the given network and extracts their custom losses. -class MixtureOfExpertsLoss(ConfigurableLoss): - def __init__(self, opt, env): - super().__init__(opt, env) - self.routers = [] # This is filled in during the first forward() pass and cached from there. 
- self.first_forward_encountered = False - self.load_weight = opt['load_weight'] - self.importance_weight = opt['importance_weight'] - - def forward(self, net, state): - if not self.first_forward_encountered: - for m in net.modules(): - if isinstance(m, MixtureOfExperts2dRouter): - self.routers.append(m) - self.first_forward_encountered = True - - l_importance = 0 - l_load = 0 - for r in self.routers: - G, L = r.reset() - l_importance += G.var().square() - l_load += L.var().square() - return l_importance * self.importance_weight + l_load * self.load_weight - - -class SwitchTransformersLoadBalancer(nn.Module): - def __init__(self): - super().__init__() - self.norm = SwitchNorm(8, accumulator_size=256) - - def forward(self, x): - self.soft = self.norm(nn.functional.softmax(x, dim=1)) - self.hard = RouteTop1.apply(self.soft) # This variant can route gradients downstream. - return self.hard - - def reset(self): - soft, hard = self.soft, self.hard - del self.soft, self.hard - return soft, hard - - -class SwitchTransformersLoadBalancingLoss(ConfigurableLoss): - def __init__(self, opt, env): - super().__init__(opt, env) - self.routers = [] # This is filled in during the first forward() pass and cached from there. - self.first_forward_encountered = False - - def forward(self, net, state): - if not self.first_forward_encountered: - for m in net.modules(): - if isinstance(m, SwitchTransformersLoadBalancer): - self.routers.append(m) - self.first_forward_encountered = True - - loss = 0 - for r in self.routers: - soft, hard = r.reset() - N = hard.shape[1] - h_mean = hard.mean(dim=[0,2,3]) - s_mean = soft.mean(dim=[0,2,3]) - loss += torch.dot(h_mean, s_mean) * N - return loss + soft = self.norm(nn.functional.softmax(x, dim=1)) + hard = RouteTop1.apply(soft) # This variant can route gradients downstream. 
+ return hard class SwitchedConvHardRouting(nn.Module): @@ -290,8 +181,7 @@ class SwitchedConvHardRouting(nn.Module): Conv2d(breadth, breadth, 1, stride=self.stride)) else: self.coupler = None - #self.gate = MixtureOfExperts2dRouter(breadth) - self.gate = SwitchTransformersLoadBalancer() + self.gate = HardRoutingGate(breadth) self.weight = nn.Parameter(torch.empty(out_c, in_c, breadth, kernel_sz, kernel_sz)) if bias: diff --git a/codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py b/codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py index 9c29648b..6e3fd071 100644 --- a/codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py +++ b/codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py @@ -7,7 +7,7 @@ from torch.nn import functional as F import torch.distributed as distributed -from models.switched_conv_hard_routing import SwitchedConvHardRouting, \ +from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \ convert_conv_net_state_dict_to_switched_conv from trainer.networks import register_model from utils.util import checkpoint, opt_get diff --git a/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py b/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py index 84c5170a..6a9f380a 100644 --- a/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py +++ b/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py @@ -7,7 +7,7 @@ from torch.nn import functional as F import torch.distributed as distributed -from models.switched_conv import SwitchedConv, convert_conv_net_state_dict_to_switched_conv +from models.switched_conv.switched_conv import SwitchedConv, convert_conv_net_state_dict_to_switched_conv from trainer.networks import register_model from utils.util import checkpoint, opt_get diff --git a/codes/trainer/losses.py b/codes/trainer/losses.py index 33aae699..37e916cf 100644 --- a/codes/trainer/losses.py +++ b/codes/trainer/losses.py @@ -48,10 +48,10 @@ def create_loss(opt_loss, env): elif type == 'for_element': return ForElementLoss(opt_loss, env) elif type == 'mixture_of_experts': - from models.switched_conv_hard_routing import MixtureOfExpertsLoss + from models.switched_conv.mixture_of_experts import MixtureOfExpertsLoss return MixtureOfExpertsLoss(opt_loss, env) elif type == 'switch_transformer_balance': - from models.switched_conv_hard_routing import SwitchTransformersLoadBalancingLoss + from models.switched_conv.mixture_of_experts import SwitchTransformersLoadBalancingLoss return SwitchTransformersLoadBalancingLoss(opt_loss, env) else: raise NotImplementedError
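
The gate swap above replaces SwitchTransformersLoadBalancer with HardRoutingGate, which keeps the same softmax -> SwitchNorm -> RouteTop1 pipeline but takes its breadth from the convolution instead of a hard-coded 8. Each pixel is routed to exactly one expert while gradients still reach the routing logits. The sketch below is a minimal, self-contained illustration of that hard top-1 routing idea, assuming (B, breadth, H, W) gate tensors; ToyRouteTop1 and its straight-through-style backward are illustrative assumptions, not the repository's RouteTop1 or SwitchNorm implementations.

# Minimal standalone sketch of hard top-1 routing with gradient pass-through.
# NOTE: illustrative only; it does not reproduce the repo's RouteTop1 or SwitchNorm.
import torch
import torch.nn as nn


class ToyRouteTop1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, probs):
        # probs: (B, breadth, H, W) soft routing weights; emit a one-hot selection per pixel.
        one_hot = nn.functional.one_hot(probs.argmax(dim=1), num_classes=probs.shape[1])
        one_hot = one_hot.permute(0, 3, 1, 2).to(probs.dtype)
        ctx.save_for_backward(one_hot)
        return one_hot

    @staticmethod
    def backward(ctx, grad):
        one_hot, = ctx.saved_tensors
        # Route the incoming gradient only through the selected expert at each pixel.
        return grad * one_hot


if __name__ == '__main__':
    breadth = 8
    logits = torch.randn(2, breadth, 16, 16, requires_grad=True)
    soft = nn.functional.softmax(logits, dim=1)   # soft routing weights
    hard = ToyRouteTop1.apply(soft)               # hard one-hot selection per pixel
    hard.sum().backward()                         # gradients still reach the routing logits
    print(hard.shape, bool(logits.grad.abs().sum() > 0))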
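
SwitchTransformersLoadBalancingLoss, now housed in mixture_of_experts.py, reduces per router to loss = N * dot(mean hard assignment per expert, mean soft probability per expert). The sketch below works that expression through on random placeholder tensors laid out as (B, N, H, W) like the gate outputs above; the shapes and inputs are assumptions for illustration, not real network activations.

# Worked sketch of the balance term computed by SwitchTransformersLoadBalancingLoss.
# Inputs are random placeholders shaped like the gate's (B, N, H, W) tensors.
import torch

B, N, H, W = 4, 8, 16, 16
soft = torch.softmax(torch.randn(B, N, H, W), dim=1)    # soft routing probabilities
hard = torch.nn.functional.one_hot(soft.argmax(dim=1), N).permute(0, 3, 1, 2).float()

h_mean = hard.mean(dim=[0, 2, 3])   # fraction of pixels routed to each expert
s_mean = soft.mean(dim=[0, 2, 3])   # average routing probability per expert
loss = torch.dot(h_mean, s_mean) * N
# Roughly: uniform routing gives loss ~ 1, fully collapsed routing gives loss ~ N.
print(float(loss))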