asdf
This commit is contained in:
parent
3cd6c7f428
commit
7ea84f1ac3
|
@ -8,7 +8,6 @@ import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
|
||||||
from models.gpt_voice.lucidrains_dvae import DiscretizationLoss
|
|
||||||
from models.vqvae.vector_quantizer import VectorQuantize
|
from models.vqvae.vector_quantizer import VectorQuantize
|
||||||
from models.vqvae.vqvae import Quantize
|
from models.vqvae.vqvae import Quantize
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
|
@ -81,7 +80,6 @@ class DiscreteVAE(nn.Module):
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
self.straight_through = straight_through
|
self.straight_through = straight_through
|
||||||
self.positional_dims = positional_dims
|
self.positional_dims = positional_dims
|
||||||
self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps)
|
|
||||||
|
|
||||||
assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
|
assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
|
||||||
if positional_dims == 2:
|
if positional_dims == 2:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user