Add switchnorm to gumbel_quantizer
parent ac57cdc794
commit 4d1a42e944
@@ -86,7 +86,7 @@ class RouteTop1(torch.autograd.Function):
     """
-    SwitchNorm is meant to be applied against the Softmax output of an switching function across a large set of
+    SwitchNorm is meant to be applied against the Softmax output of a switching function across a large set of
     switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude
     of switch weights that are over-used and increasing the magnitude of under-used weights.
 
 
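For context, here is a minimal sketch of what this docstring describes: divide the softmax output by each switch's relative historical usage, so over-used switches are damped and under-used ones are boosted. The accumulator details and all names below are assumptions for illustration, not the repository's actual SwitchNorm:

import torch
import torch.nn as nn

class SwitchNormSketch(nn.Module):
    # Illustrative only: track a running mean of how often each switch fires
    # and divide new softmax outputs by that usage, relative to uniform.
    def __init__(self, group_size, momentum=0.99):
        super().__init__()
        self.group_size = group_size
        self.momentum = momentum
        # Running mean usage per switch; start uniform.
        self.register_buffer("accumulator", torch.ones(group_size) / group_size)

    def forward(self, x):
        # x: (batch, group_size, ...) softmax output of the switching function.
        if self.training:
            # Average usage of each switch over the batch and all positions.
            usage = x.detach().transpose(0, 1).reshape(self.group_size, -1).mean(dim=1)
            self.accumulator.mul_(self.momentum).add_(usage, alpha=1 - self.momentum)
        # norm == 1 for a uniformly used switch, > 1 when over-used.
        norm = (self.accumulator * self.group_size).view(1, -1)
        while len(x.shape) > len(norm.shape):
            norm = norm.unsqueeze(-1)
        return x / norm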
@@ -154,7 +154,7 @@ class SwitchNorm(nn.Module):
         norm = torch.ones(self.group_size, device=self.accumulator.device)
 
         norm = norm.view(1,-1)
-        while len(x.shape) < len(norm.shape):
+        while len(x.shape) > len(norm.shape):
             norm = norm.unsqueeze(-1)
         x = x / norm
 
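The one-character change above fixes the broadcast-alignment loop: norm starts at shape (1, group_size) and needs trailing singleton dimensions added until its rank matches x, so the loop must run while x has more dimensions than norm. With `<` the body could never execute, and the division would then fail for any input with extra trailing dimensions. A quick shape check of the corrected logic:

import torch

x = torch.rand(8, 16, 32)          # (batch, group_size, positions)
norm = torch.ones(16).view(1, -1)  # (1, group_size)
while len(x.shape) > len(norm.shape):
    norm = norm.unsqueeze(-1)      # (1, 16) -> (1, 16, 1)
print((x / norm).shape)            # torch.Size([8, 16, 32])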
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import einsum
 
+from models.switched_conv.switched_conv_hard_routing import SwitchNorm
 from utils.weight_scheduler import LinearDecayWeightScheduler
 
 
@@ -14,6 +15,7 @@ class GumbelQuantizer(nn.Module):
         self.straight_through = straight_through
         self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000)
         self.step = 0
+        self.norm = SwitchNorm(num_tokens)
 
     def get_temperature(self, step):
         self.step = step  # VERY POOR DESIGN. WHEN WILL HE EVER LEARN???
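For reference, a plausible shape for the temperature scheduler consumed here. The real LinearDecayWeightScheduler lives in utils.weight_scheduler; the argument meanings below are guesses from the call site LinearDecayWeightScheduler(10, 5000, .9, 2000):

class LinearDecayWeightSchedulerSketch:
    # Hypothetical: hold at `initial` for `hold_steps`, then decay linearly
    # down to `floor` over `decay_steps`.
    def __init__(self, initial, decay_steps, floor, hold_steps):
        self.initial = initial
        self.decay_steps = decay_steps
        self.floor = floor
        self.hold_steps = hold_steps

    def get_weight_for_step(self, step):
        if step < self.hold_steps:
            return self.initial
        progress = min((step - self.hold_steps) / self.decay_steps, 1.0)
        return self.initial + progress * (self.floor - self.initial)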
@@ -40,6 +42,7 @@ class GumbelQuantizer(nn.Module):
         h = h.permute(0,2,1)
         logits = self.to_logits(h)
         logits = self.gumbel_softmax(logits, tau=self.temperature_scheduler.get_weight_for_step(self.step), dim=1, hard=self.straight_through)
+        logits = self.norm(logits)
         codes = logits.argmax(dim=1).flatten(1)
         sampled = einsum('b n l, n d -> b d l', logits, self.codebook.weight)
         return sampled.permute(0,2,1), 0, codes
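Taken together, the quantizer's forward path after this commit is: project to per-token logits, Gumbel-softmax with the scheduled temperature, rebalance token usage with SwitchNorm, then mix codebook entries. A hedged shape walkthrough; the tensor sizes and the functional gumbel_softmax stand-in are assumptions, not repository code:

import torch

B, L, D, N = 4, 64, 256, 512  # batch, seq length, embed dim, num_tokens (assumed)

h = torch.randn(B, L, D).permute(0, 2, 1)  # (B, D, L), channels-first for to_logits
logits = torch.randn(B, N, L)              # stands in for self.to_logits(h)
# tau comes from the temperature scheduler in the real code.
soft = torch.nn.functional.gumbel_softmax(logits, tau=1.0, dim=1, hard=False)
# soft = self.norm(soft)                   # SwitchNorm rebalances token usage here
codes = soft.argmax(dim=1).flatten(1)      # (B, L) discrete code indices
codebook = torch.nn.Embedding(N, D)        # stands in for self.codebook
sampled = torch.einsum('b n l, n d -> b d l', soft, codebook.weight)
print(sampled.permute(0, 2, 1).shape, codes.shape)  # (B, L, D) and (B, L)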