forked from mrq/DL-Art-School
Integrate with lr_quantizer
parent 82d0e7720e
commit d9747fe623
@@ -7,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 from torch import einsum
+from vector_quantize_pytorch import VectorQuantize
 
 from models.gpt_voice.dvae_arch_playground.discretization_loss import DiscretizationLoss
 from models.vqvae.vqvae import Quantize
@@ -76,8 +77,8 @@ class DiscreteVAE(nn.Module):
         normalization = None, # ((0.5,) * 3, (0.5,) * 3),
         record_codes = False,
         discretization_loss_averaging_steps = 100,
-        encoder_choke=False,
-        choke_dim=128,
+        use_lr_quantizer = False,
+        lr_quantizer_args = {},
     ):
         super().__init__()
         has_resblocks = num_resnet_blocks > 0
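Note: the encoder_choke/choke_dim options are dropped here in favor of the two lr_quantizer arguments (the choke logic itself is deleted in a later hunk). A minimal sketch of opting in at construction time; the numeric values are illustrative, and decay/commitment_weight are kwargs documented by vector-quantize-pytorch, not settings taken from this commit:

    dvae = DiscreteVAE(
        num_tokens=8192,       # illustrative codebook size
        codebook_dim=512,      # illustrative code dimension
        use_lr_quantizer=True,
        lr_quantizer_args={'decay': 0.8, 'commitment_weight': 1.0},  # forwarded to VectorQuantize
    )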
@@ -85,7 +86,6 @@ class DiscreteVAE(nn.Module):
         self.num_tokens = num_tokens
         self.num_layers = num_layers
         self.straight_through = straight_through
-        self.codebook = Quantize(codebook_dim, num_tokens)
         self.positional_dims = positional_dims
         self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps)
 
@@ -134,7 +134,6 @@ class DiscreteVAE(nn.Module):
         dec_out_chans = hidden_dim
         innermost_dim = hidden_dim
 
-
         for _ in range(num_resnet_blocks):
             dec_layers.insert(0, ResBlock(innermost_dim, conv, act))
             enc_layers.append(ResBlock(innermost_dim, conv, act))
@@ -142,9 +141,6 @@ class DiscreteVAE(nn.Module):
         if num_resnet_blocks > 0:
             dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))
 
-        if encoder_choke:
-            enc_layers.append(conv(innermost_dim, choke_dim, 1))
-            innermost_dim = choke_dim
 
         enc_layers.append(conv(innermost_dim, codebook_dim, 1))
         dec_layers.append(conv(dec_out_chans, channels, 1))
@@ -154,6 +150,11 @@ class DiscreteVAE(nn.Module):
 
         self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
 
+        if use_lr_quantizer:
+            self.codebook = VectorQuantize(dim=codebook_dim, codebook_size=num_tokens, **lr_quantizer_args)
+        else:
+            self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
+
         # take care of normalization within class
         self.normalization = normalization
         self.record_codes = record_codes
@@ -186,7 +187,7 @@ class DiscreteVAE(nn.Module):
     def get_codebook_indices(self, images):
         img = self.norm(images)
         logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
-        sampled, commitment_loss, codes = self.codebook(logits)
+        sampled, codes, _ = self.codebook(logits)
         return codes
 
     def decode(
@@ -213,7 +214,7 @@ class DiscreteVAE(nn.Module):
     def infer(self, img):
         img = self.norm(img)
         logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
-        sampled, commitment_loss, codes = self.codebook(logits)
+        sampled, codes, commitment_loss = self.codebook(logits)
         return self.decode(codes)
 
     # Note: This module is not meant to be run in forward() except while training. It has special logic which performs
@@ -225,7 +226,7 @@ class DiscreteVAE(nn.Module):
     ):
         img = self.norm(img)
         logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
-        sampled, commitment_loss, codes, soft_codes = self.codebook(logits, return_soft_codes=True)
+        sampled, codes, commitment_loss = self.codebook(logits)
         sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
 
         if self.training:
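Note: all three call sites above now unpack the same (quantized, codes, commitment_loss) order, matching what vector-quantize-pytorch's VectorQuantize returns; the new_return_order flag added to Quantize below makes the legacy path line up. A sketch of the contract change, assuming logits of shape (batch, sequence, codebook_dim):

    # old Quantize order:
    # sampled, commitment_loss, codes = self.codebook(logits)
    # unified order (VectorQuantize, or Quantize with new_return_order=True):
    sampled, codes, commitment_loss = self.codebook(logits)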
@@ -239,13 +240,10 @@ class DiscreteVAE(nn.Module):
             # reconstruction loss
             recon_loss = self.loss_fn(img, out, reduction='none')
 
-            # discretization loss
-            disc_loss = self.discrete_loss(soft_codes)
-
             # This is so we can debug the distribution of codes being learned.
             self.log_codes(codes)
 
-            return recon_loss, commitment_loss, disc_loss, out
+            return recon_loss, commitment_loss, out
 
     def log_codes(self, codes):
         # This is so we can debug the distribution of codes being learned.
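Note: the discretization loss depended on the soft codes that only the legacy Quantize exposes, so it is removed and forward() now returns three values while training instead of four. Any training step that unpacked the old 4-tuple needs a matching change; a hedged sketch (the loss combination is illustrative, not from this commit):

    # before: recon_loss, commitment_loss, disc_loss, out = dvae(img)
    recon_loss, commitment_loss, out = dvae(img)
    loss = recon_loss.mean() + commitment_loss  # illustrative weighting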
@@ -29,7 +29,7 @@ from utils.util import checkpoint, opt_get
 
 
 class Quantize(nn.Module):
-    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False):
+    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False):
         super().__init__()
 
         self.dim = dim
@@ -41,6 +41,7 @@ class Quantize(nn.Module):
         self.codes = None
         self.max_codes = 64000
         self.codes_full = False
+        self.new_return_order = new_return_order
 
         embed = torch.randn(dim, n_embed)
         self.register_buffer("embed", embed)
@@ -107,6 +108,8 @@ class Quantize(nn.Module):
 
         if return_soft_codes:
             return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,))
+        elif self.new_return_order:
+            return quantize, embed_ind, diff
         else:
             return quantize, diff, embed_ind
 
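Note: Quantize.forward in models/vqvae/vqvae.py now has three possible return orders, and new_return_order exists solely to match VectorQuantize's (quantized, indices, loss) convention. A sketch of the three cases, with illustrative dimensions:

    q = Quantize(512, 8192, new_return_order=True)
    quantize, embed_ind, diff = q(x)                 # matches VectorQuantize's order
    q_legacy = Quantize(512, 8192)
    quantize, diff, embed_ind = q_legacy(x)          # legacy default order
    quantize, diff, embed_ind, soft = q_legacy(x, return_soft_codes=True)  # soft codes appended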
@@ -15,6 +15,7 @@ linear_attention_transformer
 orjson
 einops
 lambda-networks
+vector-quantize-pytorch
 
 # For image generation stuff
 opencv-python
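Note: vector-quantize-pytorch is the new dependency backing use_lr_quantizer. A minimal standalone check, adapted from that library's README (kwarg names reflect current releases and may differ in older versions):

    import torch
    from vector_quantize_pytorch import VectorQuantize

    vq = VectorQuantize(dim=256, codebook_size=512, decay=0.8, commitment_weight=1.0)
    x = torch.randn(1, 1024, 256)
    quantized, indices, commit_loss = vq(x)  # (1, 1024, 256), (1, 1024), commitment loss tensor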