Add switchnorm to gumbel_quantizer

This commit is contained in:
James Betker 2021-09-24 18:49:25 -06:00
parent ac57cdc794
commit 4d1a42e944
2 changed files with 5 additions and 2 deletions

View File

@ -86,7 +86,7 @@ class RouteTop1(torch.autograd.Function):
"""
SwitchNorm is meant to be applied against the Softmax output of an switching function across a large set of
SwitchNorm is meant to be applied against the Softmax output of a switching function across a large set of
switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude
of switch weights that are over-used and increasing the magnitude of under-used weights.
@ -154,7 +154,7 @@ class SwitchNorm(nn.Module):
norm = torch.ones(self.group_size, device=self.accumulator.device)
norm = norm.view(1,-1)
while len(x.shape) < len(norm.shape):
while len(x.shape) > len(norm.shape):
norm = norm.unsqueeze(-1)
x = x / norm

View File

@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from models.switched_conv.switched_conv_hard_routing import SwitchNorm
from utils.weight_scheduler import LinearDecayWeightScheduler
@ -14,6 +15,7 @@ class GumbelQuantizer(nn.Module):
self.straight_through = straight_through
self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000)
self.step = 0
self.norm = SwitchNorm(num_tokens)
def get_temperature(self, step):
self.step = step # VERY POOR DESIGN. WHEN WILL HE EVER LEARN???
@ -40,6 +42,7 @@ class GumbelQuantizer(nn.Module):
h = h.permute(0,2,1)
logits = self.to_logits(h)
logits = self.gumbel_softmax(logits, tau=self.temperature_scheduler.get_weight_for_step(self.step), dim=1, hard=self.straight_through)
logits = self.norm(logits)
codes = logits.argmax(dim=1).flatten(1)
sampled = einsum('b n l, n d -> b d l', logits, self.codebook.weight)
return sampled.permute(0,2,1), 0, codes