From d9747fe6238aa2042c8ff2057e80c30a8843e731 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 23 Nov 2021 19:48:22 -0700 Subject: [PATCH] Integrate with lr_quantizer --- codes/models/gpt_voice/lucidrains_dvae.py | 26 +++++++++++------------ codes/models/vqvae/vqvae.py | 5 ++++- codes/requirements.txt | 1 + 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index a8cc36c7..1d21a3bc 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -7,6 +7,7 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange from torch import einsum +from vector_quantize_pytorch import VectorQuantize from models.gpt_voice.dvae_arch_playground.discretization_loss import DiscretizationLoss from models.vqvae.vqvae import Quantize @@ -76,8 +77,8 @@ class DiscreteVAE(nn.Module): normalization = None, # ((0.5,) * 3, (0.5,) * 3), record_codes = False, discretization_loss_averaging_steps = 100, - encoder_choke=False, - choke_dim=128, + use_lr_quantizer = False, + lr_quantizer_args = {}, ): super().__init__() has_resblocks = num_resnet_blocks > 0 @@ -85,7 +86,6 @@ class DiscreteVAE(nn.Module): self.num_tokens = num_tokens self.num_layers = num_layers self.straight_through = straight_through - self.codebook = Quantize(codebook_dim, num_tokens) self.positional_dims = positional_dims self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps) @@ -134,7 +134,6 @@ class DiscreteVAE(nn.Module): dec_out_chans = hidden_dim innermost_dim = hidden_dim - for _ in range(num_resnet_blocks): dec_layers.insert(0, ResBlock(innermost_dim, conv, act)) enc_layers.append(ResBlock(innermost_dim, conv, act)) @@ -142,9 +141,6 @@ class DiscreteVAE(nn.Module): if num_resnet_blocks > 0: dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1)) - if encoder_choke: - enc_layers.append(conv(innermost_dim, choke_dim, 1)) - innermost_dim = choke_dim enc_layers.append(conv(innermost_dim, codebook_dim, 1)) dec_layers.append(conv(dec_out_chans, channels, 1)) @@ -154,6 +150,11 @@ class DiscreteVAE(nn.Module): self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss + if use_lr_quantizer: + self.codebook = VectorQuantize(dim=codebook_dim, codebook_size=num_tokens, **lr_quantizer_args) + else: + self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True) + # take care of normalization within class self.normalization = normalization self.record_codes = record_codes @@ -186,7 +187,7 @@ class DiscreteVAE(nn.Module): def get_codebook_indices(self, images): img = self.norm(images) logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) - sampled, commitment_loss, codes = self.codebook(logits) + sampled, codes, _ = self.codebook(logits) return codes def decode( @@ -213,7 +214,7 @@ class DiscreteVAE(nn.Module): def infer(self, img): img = self.norm(img) logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) - sampled, commitment_loss, codes = self.codebook(logits) + sampled, codes, commitment_loss = self.codebook(logits) return self.decode(codes) # Note: This module is not meant to be run in forward() except while training. It has special logic which performs @@ -225,7 +226,7 @@ class DiscreteVAE(nn.Module): ): img = self.norm(img) logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) - sampled, commitment_loss, codes, soft_codes = self.codebook(logits, return_soft_codes=True) + sampled, codes, commitment_loss = self.codebook(logits) sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1)) if self.training: @@ -239,13 +240,10 @@ class DiscreteVAE(nn.Module): # reconstruction loss recon_loss = self.loss_fn(img, out, reduction='none') - # discretization loss - disc_loss = self.discrete_loss(soft_codes) - # This is so we can debug the distribution of codes being learned. self.log_codes(codes) - return recon_loss, commitment_loss, disc_loss, out + return recon_loss, commitment_loss, out def log_codes(self, codes): # This is so we can debug the distribution of codes being learned. diff --git a/codes/models/vqvae/vqvae.py b/codes/models/vqvae/vqvae.py index 7d735f9e..55dfdbbd 100644 --- a/codes/models/vqvae/vqvae.py +++ b/codes/models/vqvae/vqvae.py @@ -29,7 +29,7 @@ from utils.util import checkpoint, opt_get class Quantize(nn.Module): - def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False): + def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False): super().__init__() self.dim = dim @@ -41,6 +41,7 @@ class Quantize(nn.Module): self.codes = None self.max_codes = 64000 self.codes_full = False + self.new_return_order = new_return_order embed = torch.randn(dim, n_embed) self.register_buffer("embed", embed) @@ -107,6 +108,8 @@ class Quantize(nn.Module): if return_soft_codes: return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,)) + elif self.new_return_order: + return quantize, embed_ind, diff else: return quantize, diff, embed_ind diff --git a/codes/requirements.txt b/codes/requirements.txt index adeff1d9..1b1a6951 100644 --- a/codes/requirements.txt +++ b/codes/requirements.txt @@ -15,6 +15,7 @@ linear_attention_transformer orjson einops lambda-networks +vector-quantize-pytorch # For image generation stuff opencv-python