DL-Art-School/dlas/models/vqvae/gumbel_quantizer.py

64 lines
2.2 KiB
Python
Raw Normal View History

2021-09-24 05:32:03 +00:00
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
import dlas.torch_intermediary as ml
from dlas.utils.weight_scheduler import LinearDecayWeightScheduler
2021-09-24 05:32:03 +00:00
class GumbelQuantizer(nn.Module):
    """Quantizer that selects codebook vectors through Gumbel-softmax sampling.

    Features are projected to one logit per codebook entry with a 1x1
    convolution, sampled with a Gumbel-softmax (optionally straight-through),
    passed through ``self.norm`` and finally used to mix the rows of a learned
    codebook.

    Args:
        inp_dim: channel size of the incoming features.
        codebook_dim: size of each codebook vector.
        num_tokens: number of codebook entries.
        straight_through: if True, the forward pass emits hard one-hot
            selections while gradients flow through the soft probabilities.
    """

    def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False):
        super().__init__()
        # SwitchNorm is required below but was never brought into scope in this
        # module (construction raised NameError). Importing it here keeps the
        # module importable even when only other parts of it are needed.
        from dlas.models.switched_conv.switched_conv_hard_routing import SwitchNorm

        self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1)
        # Embedding from dlas.torch_intermediary, used in place of nn.Embedding.
        self.codebook = ml.Embedding(num_tokens, codebook_dim)
        self.straight_through = straight_through
        # Gumbel temperature schedule. The constants are passed straight to
        # LinearDecayWeightScheduler; see that class for their meaning.
        self.temperature_scheduler = LinearDecayWeightScheduler(
            10, 5000, .9, 2000)
        # Step used by forward() to look up the temperature. Within this class
        # it is only updated by get_temperature().
        self.step = 0
        self.norm = SwitchNorm(num_tokens)

    def get_temperature(self, step):
        """Return the scheduled Gumbel temperature for ``step``.

        Side effect: stores ``step`` in ``self.step``, which forward() reads to
        pick its sampling temperature. Nothing else in this class advances
        ``self.step``, so the training loop presumably calls this every step.
        """
        self.step = step  # VERY POOR DESIGN. WHEN WILL HE EVER LEARN???
        return self.temperature_scheduler.get_weight_for_step(step)

    def embed_code(self, codes):
        """Return the codebook vectors for the integer indices ``codes``."""
        return self.codebook(codes)

    def gumbel_softmax(self, logits, tau, dim, hard):
        """Draw a Gumbel-softmax sample from ``logits`` along ``dim``.

        Args:
            logits: unnormalized scores.
            tau: temperature; logits plus noise are divided by it.
            dim: dimension over which the softmax is taken.
            hard: if True, return a one-hot tensor in the forward pass whose
                gradient is that of the soft sample (straight-through).

        Returns:
            A tensor with the same shape as ``logits``.
        """
        uniform = torch.rand_like(logits)
        # Gumbel(0, 1) noise; the 1e-8 terms keep both logs finite.
        gumbels = -torch.log(-torch.log(uniform + 1e-8) + 1e-8)
        logits = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
        y_soft = F.softmax(logits, dim=dim)

        if not hard:
            return y_soft

        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(
            logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        # Forward value equals y_hard; gradient equals that of y_soft.
        return y_hard - y_soft.detach() + y_soft

    def forward(self, h):
        """Quantize ``h``.

        Args:
            h: features laid out as (batch, length, inp_dim). They are permuted
                to (batch, inp_dim, length) for the Conv1d projection.

        Returns:
            A tuple ``(sampled, 0, codes)``:
            - ``sampled``: (batch, length, codebook_dim) mixture of codebook rows.
            - ``0``: a constant in the second slot. Presumably a placeholder
              for a loss term other quantizers return; verify against callers.
            - ``codes``: (batch, length) index of the selected entry per position.
        """
        h = h.permute(0, 2, 1)
        logits = self.to_logits(h)
        tau = self.temperature_scheduler.get_weight_for_step(self.step)
        logits = self.gumbel_softmax(
            logits, tau=tau, dim=1, hard=self.straight_through)
        logits = self.norm(logits)
        codes = logits.argmax(dim=1).flatten(1)
        sampled = einsum('b n l, n d -> b d l', logits, self.codebook.weight)
        return sampled.permute(0, 2, 1), 0, codes
if __name__ == '__main__':
    # Smoke test: quantize a random (batch, length, inp_dim) batch, check the
    # output shapes, and make sure a scalar loss back-propagates into the
    # quantizer. Kept self-contained: the decoder it previously relied on was
    # not defined or imported in this module, so the script raised NameError.
    j = torch.randn(8, 40, 1024)
    m = GumbelQuantizer(1024, 1024, 4096)
    sampled, _, codes = m(j)
    assert sampled.shape == (8, 40, 1024)
    assert codes.shape == (8, 40)
    sampled.mean().backward()
    assert m.to_logits.weight.grad is not None