From 0799d95af5e0f6a2b3b9d313bfe8f7b00fbea6a0 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 6 Aug 2021 14:06:26 -0600
Subject: [PATCH] Use quantizer from rosinality/vqvae with openai dvae

---
 codes/models/gpt_voice/lucidrains_dvae.py | 44 ++++++-----------------
 1 file changed, 10 insertions(+), 34 deletions(-)

diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py
index 90f32779..16e6b148 100644
--- a/codes/models/gpt_voice/lucidrains_dvae.py
+++ b/codes/models/gpt_voice/lucidrains_dvae.py
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from torch import einsum
 
+from models.vqvae.vqvae import Quantize
 from trainer.networks import register_model
 from utils.util import opt_get
 
@@ -51,9 +52,6 @@ class DiscreteVAE(nn.Module):
         hidden_dim = 64,
         channels = 3,
         smooth_l1_loss = False,
-        starting_temperature = 0.5,
-        temperature_annealing_rate = 0,
-        min_temperature = .5,
         straight_through = False,
         normalization = None, # ((0.5,) * 3, (0.5,) * 3),
         record_codes = False,
@@ -64,13 +62,9 @@ class DiscreteVAE(nn.Module):
 
         self.num_tokens = num_tokens
         self.num_layers = num_layers
-        self.starting_temperature = starting_temperature
-        self.current_temperature = starting_temperature
         self.straight_through = straight_through
-        self.codebook = nn.Embedding(num_tokens, codebook_dim)
+        self.codebook = Quantize(num_tokens, codebook_dim)
         self.positional_dims = positional_dims
-        self.temperature_annealing_rate = temperature_annealing_rate
-        self.min_temperature = min_temperature
 
         assert positional_dims > 0 and positional_dims < 3  # This VAE only supports 1d and 2d inputs for now.
         if positional_dims == 2:
@@ -130,14 +124,9 @@ class DiscreteVAE(nn.Module):
             images.sub_(means).div_(stds)
         return images
 
-    def update_for_step(self, step, __):
-        # Run the annealing schedule
-        if self.temperature_annealing_rate != 0:
-            self.current_temperature = max(self.starting_temperature * math.exp(-self.temperature_annealing_rate * step), self.min_temperature)
-
     def get_debug_values(self, step, __):
         # Report annealing schedule
-        return {'current_annealing_temperature': self.current_temperature, 'histogram_codes': self.codes}
+        return {'histogram_codes': self.codes}
 
     @torch.no_grad()
     @eval_decorator
@@ -150,7 +139,7 @@ class DiscreteVAE(nn.Module):
         self,
         img_seq
     ):
-        image_embeds = self.codebook(img_seq)
+        image_embeds = self.codebook.embed_code(img_seq)
         b, n, d = image_embeds.shape
 
         kwargs = {}
@@ -168,31 +157,18 @@ class DiscreteVAE(nn.Module):
         self,
         img
     ):
-        device, num_tokens = img.device, self.num_tokens
         img = self.norm(img)
-        logits = self.encoder(img)
-        soft_one_hot = F.gumbel_softmax(logits, tau = self.current_temperature, dim = 1, hard = self.straight_through)
-
-        if self.positional_dims == 1:
-            arrange = 'b n s, n d -> b d s'
-        else:
-            arrange = 'b n h w, n d -> b d h w'
-        sampled = einsum(arrange, soft_one_hot, self.codebook.weight)
+        logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
+        sampled, commitment_loss, codes = self.codebook(logits)
+        sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
         out = self.decoder(sampled)
 
         # reconstruction loss
         recon_loss = self.loss_fn(img, out)
 
-        # kl divergence
-        arrange = 'b n h w -> b (h w) n' if self.positional_dims == 2 else 'b n s -> b s n'
-        logits = rearrange(logits, arrange)
-        log_qy = F.log_softmax(logits, dim = -1)
-        log_uniform = torch.log(torch.tensor([1. / num_tokens], device = device))
-        kl_div = F.kl_div(log_uniform, log_qy, None, None, 'batchmean', log_target = True)
-
         # This is so we can debug the distribution of codes being learned.
         if self.record_codes:
-            codes = logits.argmax(dim = 2).flatten()
+            codes = codes.flatten()
             l = codes.shape[0]
             i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
             self.codes[i:i+l] = codes.cpu()
@@ -200,7 +176,7 @@ class DiscreteVAE(nn.Module):
         if self.code_ind >= self.codes.shape[0]:
             self.code_ind = 0
 
-        return recon_loss, kl_div, out
+        return recon_loss, commitment_loss, out
 
 
 @register_model
@@ -214,4 +190,4 @@ if __name__ == '__main__':
     #print(o.shape)
     v = DiscreteVAE(channels=1, normalization=None, positional_dims=1)
     o=v(torch.randn(1,1,256))
-    print(o.shape)
+    print(o[-1].shape)
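
Note: this patch depends on the vendored models/vqvae/vqvae.Quantize module (a port of
rosinality/vq-vae-2-pytorch) whose forward() takes a channels-last input and returns a
(quantized, commitment_loss, codes) triple, and which exposes embed_code() for the
index-to-vector lookup used by decode(). The sketch below is a minimal, hypothetical
illustration of that contract only, assuming rosinality-style (dim, n_embed) parameters;
the real module additionally performs EMA codebook updates during training, which are
omitted here.

import torch
import torch.nn as nn
import torch.nn.functional as F

class QuantizeSketch(nn.Module):
    # Hypothetical stand-in showing the interface the patch relies on.
    def __init__(self, dim, n_embed):
        super().__init__()
        self.dim = dim
        self.n_embed = n_embed
        # Codebook stored as (dim, n_embed), matching rosinality's layout.
        self.register_buffer('embed', torch.randn(dim, n_embed))

    def forward(self, input):
        # input is channels-last (..., dim); the patch permutes the encoder
        # output into this layout before calling the codebook.
        flatten = input.reshape(-1, self.dim)
        # Squared L2 distance from every input vector to every code.
        dist = (flatten.pow(2).sum(1, keepdim=True)
                - 2 * flatten @ self.embed
                + self.embed.pow(2).sum(0, keepdim=True))
        codes = (-dist).max(1).indices.view(*input.shape[:-1])
        quantize = self.embed_code(codes)
        # Commitment loss pulls encoder outputs toward their assigned codes;
        # it is returned in place of the old gumbel-softmax KL term.
        commitment_loss = (quantize.detach() - input).pow(2).mean()
        # Straight-through estimator: gradients bypass the discrete lookup.
        quantize = input + (quantize - input).detach()
        return quantize, commitment_loss, codes

    def embed_code(self, embed_id):
        # decode() uses this to map integer code indices back to code vectors.
        return F.embedding(embed_id, self.embed.transpose(0, 1))

Because the quantizer expects channels last, the patched forward() permutes the encoder
output before the codebook call and permutes the quantized tensor back before decoding.
This is also why the smoke test at the bottom now prints o[-1].shape: the module returns
a (recon_loss, commitment_loss, out) tuple rather than a single tensor.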