From 4d1a42e944c2959ce7eb5005d0c07eaf55ae9d1c Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 24 Sep 2021 18:49:25 -0600 Subject: [PATCH] Add switchnorm to gumbel_quantizer --- codes/models/switched_conv/switched_conv_hard_routing.py | 4 ++-- codes/models/vqvae/gumbel_quantizer.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/codes/models/switched_conv/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py index daad9022..28342f93 100644 --- a/codes/models/switched_conv/switched_conv_hard_routing.py +++ b/codes/models/switched_conv/switched_conv_hard_routing.py @@ -86,7 +86,7 @@ class RouteTop1(torch.autograd.Function): """ -SwitchNorm is meant to be applied against the Softmax output of an switching function across a large set of +SwitchNorm is meant to be applied against the Softmax output of a switching function across a large set of switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude of switch weights that are over-used and increasing the magnitude of under-used weights. @@ -154,7 +154,7 @@ class SwitchNorm(nn.Module): norm = torch.ones(self.group_size, device=self.accumulator.device) norm = norm.view(1,-1) - while len(x.shape) < len(norm.shape): + while len(x.shape) > len(norm.shape): norm = norm.unsqueeze(-1) x = x / norm diff --git a/codes/models/vqvae/gumbel_quantizer.py b/codes/models/vqvae/gumbel_quantizer.py index 1975ac19..ae90316c 100644 --- a/codes/models/vqvae/gumbel_quantizer.py +++ b/codes/models/vqvae/gumbel_quantizer.py @@ -3,6 +3,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import einsum +from models.switched_conv.switched_conv_hard_routing import SwitchNorm from utils.weight_scheduler import LinearDecayWeightScheduler @@ -14,6 +15,7 @@ class GumbelQuantizer(nn.Module): self.straight_through = straight_through self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000) self.step = 0 + self.norm = SwitchNorm(num_tokens) def get_temperature(self, step): self.step = step # VERY POOR DESIGN. WHEN WILL HE EVER LEARN??? @@ -40,6 +42,7 @@ class GumbelQuantizer(nn.Module): h = h.permute(0,2,1) logits = self.to_logits(h) logits = self.gumbel_softmax(logits, tau=self.temperature_scheduler.get_weight_for_step(self.step), dim=1, hard=self.straight_through) + logits = self.norm(logits) codes = logits.argmax(dim=1).flatten(1) sampled = einsum('b n l, n d -> b d l', logits, self.codebook.weight) return sampled.permute(0,2,1), 0, codes