forked from mrq/DL-Art-School
Integrate with lr_quantizer
This commit is contained in:
parent
82d0e7720e
commit
d9747fe623
|
@ -7,6 +7,7 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
from vector_quantize_pytorch import VectorQuantize
|
||||||
|
|
||||||
from models.gpt_voice.dvae_arch_playground.discretization_loss import DiscretizationLoss
|
from models.gpt_voice.dvae_arch_playground.discretization_loss import DiscretizationLoss
|
||||||
from models.vqvae.vqvae import Quantize
|
from models.vqvae.vqvae import Quantize
|
||||||
|
@ -76,8 +77,8 @@ class DiscreteVAE(nn.Module):
|
||||||
normalization = None, # ((0.5,) * 3, (0.5,) * 3),
|
normalization = None, # ((0.5,) * 3, (0.5,) * 3),
|
||||||
record_codes = False,
|
record_codes = False,
|
||||||
discretization_loss_averaging_steps = 100,
|
discretization_loss_averaging_steps = 100,
|
||||||
encoder_choke=False,
|
use_lr_quantizer = False,
|
||||||
choke_dim=128,
|
lr_quantizer_args = {},
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
has_resblocks = num_resnet_blocks > 0
|
has_resblocks = num_resnet_blocks > 0
|
||||||
|
@ -85,7 +86,6 @@ class DiscreteVAE(nn.Module):
|
||||||
self.num_tokens = num_tokens
|
self.num_tokens = num_tokens
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
self.straight_through = straight_through
|
self.straight_through = straight_through
|
||||||
self.codebook = Quantize(codebook_dim, num_tokens)
|
|
||||||
self.positional_dims = positional_dims
|
self.positional_dims = positional_dims
|
||||||
self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps)
|
self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps)
|
||||||
|
|
||||||
|
@ -134,7 +134,6 @@ class DiscreteVAE(nn.Module):
|
||||||
dec_out_chans = hidden_dim
|
dec_out_chans = hidden_dim
|
||||||
innermost_dim = hidden_dim
|
innermost_dim = hidden_dim
|
||||||
|
|
||||||
|
|
||||||
for _ in range(num_resnet_blocks):
|
for _ in range(num_resnet_blocks):
|
||||||
dec_layers.insert(0, ResBlock(innermost_dim, conv, act))
|
dec_layers.insert(0, ResBlock(innermost_dim, conv, act))
|
||||||
enc_layers.append(ResBlock(innermost_dim, conv, act))
|
enc_layers.append(ResBlock(innermost_dim, conv, act))
|
||||||
|
@ -142,9 +141,6 @@ class DiscreteVAE(nn.Module):
|
||||||
if num_resnet_blocks > 0:
|
if num_resnet_blocks > 0:
|
||||||
dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))
|
dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))
|
||||||
|
|
||||||
if encoder_choke:
|
|
||||||
enc_layers.append(conv(innermost_dim, choke_dim, 1))
|
|
||||||
innermost_dim = choke_dim
|
|
||||||
|
|
||||||
enc_layers.append(conv(innermost_dim, codebook_dim, 1))
|
enc_layers.append(conv(innermost_dim, codebook_dim, 1))
|
||||||
dec_layers.append(conv(dec_out_chans, channels, 1))
|
dec_layers.append(conv(dec_out_chans, channels, 1))
|
||||||
|
@ -154,6 +150,11 @@ class DiscreteVAE(nn.Module):
|
||||||
|
|
||||||
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
|
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
|
||||||
|
|
||||||
|
if use_lr_quantizer:
|
||||||
|
self.codebook = VectorQuantize(dim=codebook_dim, codebook_size=num_tokens, **lr_quantizer_args)
|
||||||
|
else:
|
||||||
|
self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
|
||||||
|
|
||||||
# take care of normalization within class
|
# take care of normalization within class
|
||||||
self.normalization = normalization
|
self.normalization = normalization
|
||||||
self.record_codes = record_codes
|
self.record_codes = record_codes
|
||||||
|
@ -186,7 +187,7 @@ class DiscreteVAE(nn.Module):
|
||||||
def get_codebook_indices(self, images):
|
def get_codebook_indices(self, images):
|
||||||
img = self.norm(images)
|
img = self.norm(images)
|
||||||
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
||||||
sampled, commitment_loss, codes = self.codebook(logits)
|
sampled, codes, _ = self.codebook(logits)
|
||||||
return codes
|
return codes
|
||||||
|
|
||||||
def decode(
|
def decode(
|
||||||
|
@ -213,7 +214,7 @@ class DiscreteVAE(nn.Module):
|
||||||
def infer(self, img):
|
def infer(self, img):
|
||||||
img = self.norm(img)
|
img = self.norm(img)
|
||||||
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
||||||
sampled, commitment_loss, codes = self.codebook(logits)
|
sampled, codes, commitment_loss = self.codebook(logits)
|
||||||
return self.decode(codes)
|
return self.decode(codes)
|
||||||
|
|
||||||
# Note: This module is not meant to be run in forward() except while training. It has special logic which performs
|
# Note: This module is not meant to be run in forward() except while training. It has special logic which performs
|
||||||
|
@ -225,7 +226,7 @@ class DiscreteVAE(nn.Module):
|
||||||
):
|
):
|
||||||
img = self.norm(img)
|
img = self.norm(img)
|
||||||
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
||||||
sampled, commitment_loss, codes, soft_codes = self.codebook(logits, return_soft_codes=True)
|
sampled, codes, commitment_loss = self.codebook(logits)
|
||||||
sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
|
sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
|
||||||
|
|
||||||
if self.training:
|
if self.training:
|
||||||
|
@ -239,13 +240,10 @@ class DiscreteVAE(nn.Module):
|
||||||
# reconstruction loss
|
# reconstruction loss
|
||||||
recon_loss = self.loss_fn(img, out, reduction='none')
|
recon_loss = self.loss_fn(img, out, reduction='none')
|
||||||
|
|
||||||
# discretization loss
|
|
||||||
disc_loss = self.discrete_loss(soft_codes)
|
|
||||||
|
|
||||||
# This is so we can debug the distribution of codes being learned.
|
# This is so we can debug the distribution of codes being learned.
|
||||||
self.log_codes(codes)
|
self.log_codes(codes)
|
||||||
|
|
||||||
return recon_loss, commitment_loss, disc_loss, out
|
return recon_loss, commitment_loss, out
|
||||||
|
|
||||||
def log_codes(self, codes):
|
def log_codes(self, codes):
|
||||||
# This is so we can debug the distribution of codes being learned.
|
# This is so we can debug the distribution of codes being learned.
|
||||||
|
|
|
@ -29,7 +29,7 @@ from utils.util import checkpoint, opt_get
|
||||||
|
|
||||||
|
|
||||||
class Quantize(nn.Module):
|
class Quantize(nn.Module):
|
||||||
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False):
|
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
@ -41,6 +41,7 @@ class Quantize(nn.Module):
|
||||||
self.codes = None
|
self.codes = None
|
||||||
self.max_codes = 64000
|
self.max_codes = 64000
|
||||||
self.codes_full = False
|
self.codes_full = False
|
||||||
|
self.new_return_order = new_return_order
|
||||||
|
|
||||||
embed = torch.randn(dim, n_embed)
|
embed = torch.randn(dim, n_embed)
|
||||||
self.register_buffer("embed", embed)
|
self.register_buffer("embed", embed)
|
||||||
|
@ -107,6 +108,8 @@ class Quantize(nn.Module):
|
||||||
|
|
||||||
if return_soft_codes:
|
if return_soft_codes:
|
||||||
return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,))
|
return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,))
|
||||||
|
elif self.new_return_order:
|
||||||
|
return quantize, embed_ind, diff
|
||||||
else:
|
else:
|
||||||
return quantize, diff, embed_ind
|
return quantize, diff, embed_ind
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ linear_attention_transformer
|
||||||
orjson
|
orjson
|
||||||
einops
|
einops
|
||||||
lambda-networks
|
lambda-networks
|
||||||
|
vector-quantize-pytorch
|
||||||
|
|
||||||
# For image generation stuff
|
# For image generation stuff
|
||||||
opencv-python
|
opencv-python
|
||||||
|
|
Loading…
Reference in New Issue
Block a user