Integrate DiscreteVAE with lr_quantizer (optional VectorQuantize codebook from vector-quantize-pytorch)

This commit is contained in:
James Betker 2021-11-23 19:48:22 -07:00
parent 82d0e7720e
commit d9747fe623
3 changed files with 17 additions and 15 deletions

View File

@ -7,6 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum
from vector_quantize_pytorch import VectorQuantize
from models.gpt_voice.dvae_arch_playground.discretization_loss import DiscretizationLoss
from models.vqvae.vqvae import Quantize
@ -76,8 +77,8 @@ class DiscreteVAE(nn.Module):
normalization = None, # ((0.5,) * 3, (0.5,) * 3),
record_codes = False,
discretization_loss_averaging_steps = 100,
encoder_choke=False,
choke_dim=128,
use_lr_quantizer = False,
lr_quantizer_args = {},
):
super().__init__()
has_resblocks = num_resnet_blocks > 0
@ -85,7 +86,6 @@ class DiscreteVAE(nn.Module):
self.num_tokens = num_tokens
self.num_layers = num_layers
self.straight_through = straight_through
self.codebook = Quantize(codebook_dim, num_tokens)
self.positional_dims = positional_dims
self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps)
@ -134,7 +134,6 @@ class DiscreteVAE(nn.Module):
dec_out_chans = hidden_dim
innermost_dim = hidden_dim
for _ in range(num_resnet_blocks):
dec_layers.insert(0, ResBlock(innermost_dim, conv, act))
enc_layers.append(ResBlock(innermost_dim, conv, act))
@ -142,9 +141,6 @@ class DiscreteVAE(nn.Module):
if num_resnet_blocks > 0:
dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))
if encoder_choke:
enc_layers.append(conv(innermost_dim, choke_dim, 1))
innermost_dim = choke_dim
enc_layers.append(conv(innermost_dim, codebook_dim, 1))
dec_layers.append(conv(dec_out_chans, channels, 1))
@ -154,6 +150,11 @@ class DiscreteVAE(nn.Module):
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
if use_lr_quantizer:
self.codebook = VectorQuantize(dim=codebook_dim, codebook_size=num_tokens, **lr_quantizer_args)
else:
self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
# Take care of normalization within the class.
self.normalization = normalization
self.record_codes = record_codes
@ -186,7 +187,7 @@ class DiscreteVAE(nn.Module):
def get_codebook_indices(self, images):
img = self.norm(images)
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
sampled, commitment_loss, codes = self.codebook(logits)
sampled, codes, _ = self.codebook(logits)
return codes
def decode(
@ -213,7 +214,7 @@ class DiscreteVAE(nn.Module):
def infer(self, img):
img = self.norm(img)
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
sampled, commitment_loss, codes = self.codebook(logits)
sampled, codes, commitment_loss = self.codebook(logits)
return self.decode(codes)
# Note: This module is not meant to be run in forward() except while training. It has special logic which performs
@ -225,7 +226,7 @@ class DiscreteVAE(nn.Module):
):
img = self.norm(img)
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
sampled, commitment_loss, codes, soft_codes = self.codebook(logits, return_soft_codes=True)
sampled, codes, commitment_loss = self.codebook(logits)
sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
if self.training:
@ -239,13 +240,10 @@ class DiscreteVAE(nn.Module):
# reconstruction loss
recon_loss = self.loss_fn(img, out, reduction='none')
# discretization loss
disc_loss = self.discrete_loss(soft_codes)
# This is so we can debug the distribution of codes being learned.
self.log_codes(codes)
return recon_loss, commitment_loss, disc_loss, out
return recon_loss, commitment_loss, out
def log_codes(self, codes):
# This is so we can debug the distribution of codes being learned.

View File

@ -29,7 +29,7 @@ from utils.util import checkpoint, opt_get
class Quantize(nn.Module):
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False):
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False):
super().__init__()
self.dim = dim
@ -41,6 +41,7 @@ class Quantize(nn.Module):
self.codes = None
self.max_codes = 64000
self.codes_full = False
self.new_return_order = new_return_order
embed = torch.randn(dim, n_embed)
self.register_buffer("embed", embed)
@ -107,6 +108,8 @@ class Quantize(nn.Module):
if return_soft_codes:
return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,))
elif self.new_return_order:
return quantize, embed_ind, diff
else:
return quantize, diff, embed_ind

View File

@ -15,6 +15,7 @@ linear_attention_transformer
orjson
einops
lambda-networks
vector-quantize-pytorch
# For image generation stuff
opencv-python