Add switchnorm to gumbel_quantizer
parent ac57cdc794
commit 4d1a42e944
@@ -86,7 +86,7 @@ class RouteTop1(torch.autograd.Function):


     """
-    SwitchNorm is meant to be applied against the Softmax output of an switching function across a large set of
+    SwitchNorm is meant to be applied against the Softmax output of a switching function across a large set of
     switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude
     of switch weights that are over-used and increasing the magnitude of under-used weights.

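For context, the docstring above describes a usage-balancing normalization: track how much probability mass each switch receives and divide each switch's weight by its relative usage, so frequently-chosen switches are damped and rarely-chosen ones are boosted. A minimal sketch of that idea (hypothetical class and buffer names; not the repository's actual SwitchNorm, which keeps its statistics in an accumulator as seen in the next hunk):

import torch
import torch.nn as nn

class UsageBalancingNorm(nn.Module):
    # Hypothetical sketch of the behaviour described in the SwitchNorm docstring:
    # scale softmax switch weights by the inverse of their recent relative usage.
    def __init__(self, group_size, momentum=0.99, eps=1e-6):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        # Running estimate of the probability mass each switch receives.
        self.register_buffer('usage', torch.ones(group_size) / group_size)

    def forward(self, x):
        # x: (batch, group_size, ...) -- softmax output of the switching function.
        with torch.no_grad():
            flat = x.transpose(0, 1).reshape(x.shape[1], -1)   # (group_size, samples)
            self.usage.mul_(self.momentum).add_(flat.mean(dim=1), alpha=1 - self.momentum)
        # usage * group_size is ~1 for an evenly used switch, >1 for over-used ones.
        norm = (self.usage * x.shape[1] + self.eps).view(1, -1, *([1] * (x.dim() - 2)))
        return x / norm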
@@ -154,7 +154,7 @@ class SwitchNorm(nn.Module):
         norm = torch.ones(self.group_size, device=self.accumulator.device)

         norm = norm.view(1,-1)
-        while len(x.shape) < len(norm.shape):
+        while len(x.shape) > len(norm.shape):
             norm = norm.unsqueeze(-1)
         x = x / norm

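The `<` → `>` fix matters for broadcasting: `norm` is reshaped to (1, group_size), while the tensor being normalized typically has more dimensions, so with `<` the loop body never ran and the division would misalign (or fail) against the trailing axis. With `>`, trailing singleton dimensions are appended until `norm` matches the input's rank. A standalone illustration, assuming a (batch, group, length)-shaped input:

import torch

x = torch.rand(2, 8, 100)           # e.g. (batch, group_size, sequence_length)
norm = torch.ones(8).view(1, -1)    # (1, group_size)

while len(x.shape) > len(norm.shape):
    norm = norm.unsqueeze(-1)       # (1, 8) -> (1, 8, 1)

print(norm.shape)                   # torch.Size([1, 8, 1]): broadcasts over batch and length
print((x / norm).shape)             # torch.Size([2, 8, 100])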
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import einsum

+from models.switched_conv.switched_conv_hard_routing import SwitchNorm
 from utils.weight_scheduler import LinearDecayWeightScheduler


@@ -14,6 +15,7 @@ class GumbelQuantizer(nn.Module):
         self.straight_through = straight_through
         self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000)
         self.step = 0
+        self.norm = SwitchNorm(num_tokens)

     def get_temperature(self, step):
         self.step = step # VERY POOR DESIGN. WHEN WILL HE EVER LEARN???
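The straight_through flag stored here controls hard one-hot sampling further down in forward(). As a reminder of the straight-through estimator (shown with PyTorch's built-in F.gumbel_softmax; the class may use its own implementation, and the shapes below are illustrative):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 16, 100)                             # (batch, num_tokens, length)
soft = F.gumbel_softmax(logits, tau=1.0, hard=False, dim=1)
hard = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=1)

# hard=True emits exact one-hot vectors along dim=1 in the forward pass while
# routing gradients through the underlying soft sample (straight-through).
print(soft[0, :, 0])    # dense probabilities summing to ~1
print(hard[0, :, 0])    # exactly one 1.0, rest zeros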
@@ -40,6 +42,7 @@ class GumbelQuantizer(nn.Module):
         h = h.permute(0,2,1)
         logits = self.to_logits(h)
         logits = self.gumbel_softmax(logits, tau=self.temperature_scheduler.get_weight_for_step(self.step), dim=1, hard=self.straight_through)
+        logits = self.norm(logits)
         codes = logits.argmax(dim=1).flatten(1)
         sampled = einsum('b n l, n d -> b d l', logits, self.codebook.weight)
         return sampled.permute(0,2,1), 0, codes
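With this change the normalization sits between the gumbel-softmax sample and the codebook lookup, so, per the SwitchNorm docstring, over-used codebook tokens are down-weighted before codes are picked. For reference, the einsum on the context line is just a per-position weighted sum over codebook entries; a standalone check with illustrative sizes:

import torch
from torch import einsum

b, n, d, l = 2, 16, 32, 100          # batch, num_tokens, codebook dim, sequence length
logits = torch.softmax(torch.randn(b, n, l), dim=1)   # stand-in for the (soft) token weights
codebook = torch.randn(n, d)

# 'b n l, n d -> b d l': for every (batch, position), mix the n codebook vectors
# using the token weights; equivalent to codebook.t() @ logits per batch element.
sampled = einsum('b n l, n d -> b d l', logits, codebook)
assert torch.allclose(sampled, codebook.t().unsqueeze(0) @ logits, atol=1e-5)
print(sampled.shape)                 # torch.Size([2, 32, 100])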