Add switchnorm to gumbel_quantizer

This commit is contained in:
James Betker 2021-09-24 18:49:25 -06:00
parent ac57cdc794
commit 4d1a42e944
2 changed files with 5 additions and 2 deletions

View File

@ -86,7 +86,7 @@ class RouteTop1(torch.autograd.Function):
""" """
SwitchNorm is meant to be applied against the Softmax output of an switching function across a large set of SwitchNorm is meant to be applied against the Softmax output of a switching function across a large set of
switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude
of switch weights that are over-used and increasing the magnitude of under-used weights. of switch weights that are over-used and increasing the magnitude of under-used weights.
@ -154,7 +154,7 @@ class SwitchNorm(nn.Module):
norm = torch.ones(self.group_size, device=self.accumulator.device) norm = torch.ones(self.group_size, device=self.accumulator.device)
norm = norm.view(1,-1) norm = norm.view(1,-1)
while len(x.shape) < len(norm.shape): while len(x.shape) > len(norm.shape):
norm = norm.unsqueeze(-1) norm = norm.unsqueeze(-1)
x = x / norm x = x / norm

View File

@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import einsum from torch import einsum
from models.switched_conv.switched_conv_hard_routing import SwitchNorm
from utils.weight_scheduler import LinearDecayWeightScheduler from utils.weight_scheduler import LinearDecayWeightScheduler
@ -14,6 +15,7 @@ class GumbelQuantizer(nn.Module):
self.straight_through = straight_through self.straight_through = straight_through
self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000) self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000)
self.step = 0 self.step = 0
self.norm = SwitchNorm(num_tokens)
def get_temperature(self, step): def get_temperature(self, step):
self.step = step # VERY POOR DESIGN. WHEN WILL HE EVER LEARN??? self.step = step # VERY POOR DESIGN. WHEN WILL HE EVER LEARN???
@ -40,6 +42,7 @@ class GumbelQuantizer(nn.Module):
h = h.permute(0,2,1) h = h.permute(0,2,1)
logits = self.to_logits(h) logits = self.to_logits(h)
logits = self.gumbel_softmax(logits, tau=self.temperature_scheduler.get_weight_for_step(self.step), dim=1, hard=self.straight_through) logits = self.gumbel_softmax(logits, tau=self.temperature_scheduler.get_weight_for_step(self.step), dim=1, hard=self.straight_through)
logits = self.norm(logits)
codes = logits.argmax(dim=1).flatten(1) codes = logits.argmax(dim=1).flatten(1)
sampled = einsum('b n l, n d -> b d l', logits, self.codebook.weight) sampled = einsum('b n l, n d -> b d l', logits, self.codebook.weight)
return sampled.permute(0,2,1), 0, codes return sampled.permute(0,2,1), 0, codes