This commit is contained in:
James Betker 2022-03-03 13:43:44 -07:00
parent 3cd6c7f428
commit 7ea84f1ac3

View File

@ -8,7 +8,6 @@ import torch.nn.functional as F
from einops import rearrange
from torch import einsum
from models.gpt_voice.lucidrains_dvae import DiscretizationLoss
from models.vqvae.vector_quantizer import VectorQuantize
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
@ -81,7 +80,6 @@ class DiscreteVAE(nn.Module):
self.num_layers = num_layers
self.straight_through = straight_through
self.positional_dims = positional_dims
self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps)
assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
if positional_dims == 2: