asdf
This commit is contained in:
parent
3cd6c7f428
commit
7ea84f1ac3
|
@ -8,7 +8,6 @@ import torch.nn.functional as F
|
|||
from einops import rearrange
|
||||
from torch import einsum
|
||||
|
||||
from models.gpt_voice.lucidrains_dvae import DiscretizationLoss
|
||||
from models.vqvae.vector_quantizer import VectorQuantize
|
||||
from models.vqvae.vqvae import Quantize
|
||||
from trainer.networks import register_model
|
||||
|
@ -81,7 +80,6 @@ class DiscreteVAE(nn.Module):
|
|||
self.num_layers = num_layers
|
||||
self.straight_through = straight_through
|
||||
self.positional_dims = positional_dims
|
||||
self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps)
|
||||
|
||||
assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
|
||||
if positional_dims == 2:
|
||||
|
|
Loading…
Reference in New Issue
Block a user