diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py
index d297fa00..1dda5fdd 100644
--- a/codes/models/gpt_voice/lucidrains_dvae.py
+++ b/codes/models/gpt_voice/lucidrains_dvae.py
@@ -63,7 +63,7 @@ class DiscreteVAE(nn.Module):
         self.num_tokens = num_tokens
         self.num_layers = num_layers
         self.straight_through = straight_through
-        self.codebook = Quantize(num_tokens, codebook_dim)
+        self.codebook = Quantize(codebook_dim, num_tokens)
         self.positional_dims = positional_dims
 
         assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
@@ -98,7 +98,7 @@ class DiscreteVAE(nn.Module):
         if num_resnet_blocks > 0:
             dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1))
 
-        enc_layers.append(conv(enc_chans[-1], num_tokens, 1))
+        enc_layers.append(conv(enc_chans[-1], codebook_dim, 1))
         dec_layers.append(conv(dec_chans[-1], channels, 1))
 
         self.encoder = nn.Sequential(*enc_layers)
@@ -112,6 +112,7 @@ class DiscreteVAE(nn.Module):
         if record_codes:
             self.codes = torch.zeros((32768,), dtype=torch.long)
             self.code_ind = 0
+        self.internal_step = 0
 
     def norm(self, images):
         if not self.normalization is not None:
@@ -171,7 +172,7 @@ class DiscreteVAE(nn.Module):
         recon_loss = self.loss_fn(img, out)
 
         # This is so we can debug the distribution of codes being learned.
-        if self.record_codes:
+        if self.record_codes and self.internal_step % 50 == 0:
             codes = codes.flatten()
             l = codes.shape[0]
             i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
@@ -179,6 +180,7 @@
             self.code_ind = self.code_ind + l
             if self.code_ind >= self.codes.shape[0]:
                 self.code_ind = 0
+        self.internal_step += 1
 
         return recon_loss, commitment_loss, out
 
@@ -192,6 +194,6 @@ if __name__ == '__main__':
     #v = DiscreteVAE()
     #o=v(torch.randn(1,3,256,256))
    #print(o.shape)
-    v = DiscreteVAE(channels=1, normalization=None, positional_dims=1)
+    v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256)
     o=v(torch.randn(1,1,256))
     print(o[-1].shape)
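
For context on the first two hunks: `Quantize` here appears to be the rosinality-style VQ-VAE quantizer, whose signature puts the embedding width first and the codebook size second, i.e. `Quantize(dim, n_embed)`. The old code passed the arguments swapped (a 4096-wide codebook with 2048 entries instead of the reverse), and the encoder's final 1x1 conv emitted `num_tokens` channels (per-code logits, as in the gumbel-softmax DVAE this file was adapted from) rather than `codebook_dim`-wide vectors that can be matched against codebook entries. A minimal sketch of the assumed interface, with a toy nearest-neighbour quantizer standing in for the real module:

```python
import torch
import torch.nn as nn

class Quantize(nn.Module):
    # Toy stand-in with the assumed (dim, n_embed) signature; the real module
    # also handles EMA codebook updates and the commitment loss, omitted here.
    def __init__(self, dim, n_embed):
        super().__init__()
        self.dim = dim
        self.n_embed = n_embed
        self.embed = nn.Parameter(torch.randn(dim, n_embed))

    def forward(self, z):
        # z: (..., dim) -> nearest codebook entry per position
        flat = z.reshape(-1, self.dim)                       # (N, dim)
        dist = (flat.pow(2).sum(1, keepdim=True)
                - 2 * flat @ self.embed
                + self.embed.pow(2).sum(0, keepdim=True))    # (N, n_embed)
        ids = dist.argmin(1)
        quantized = self.embed.t()[ids].reshape(z.shape)
        return quantized, ids.reshape(z.shape[:-1])

codebook_dim, num_tokens = 2048, 4096
quantizer = Quantize(codebook_dim, num_tokens)   # dim first, n_embed second
z = torch.randn(1, 64, codebook_dim)             # encoder output, channels last
quantized, codes = quantizer(z)
assert quantized.shape == z.shape
assert codes.max().item() < num_tokens
```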
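
The new `internal_step` counter throttles the code-recording debug path to every 50th forward pass, presumably so the 32768-slot ring buffer spans a longer training window instead of being overwritten within a few steps (the per-step `.cpu()` copy is also not free). A hypothetical way to inspect codebook utilization after this change, assuming the module imports as in the diff header and that `record_codes` is a constructor kwarg (the `if record_codes:` branch above suggests it is):

```python
import torch
from codes.models.gpt_voice.lucidrains_dvae import DiscreteVAE  # path per diff header

v = DiscreteVAE(channels=1, normalization=None, positional_dims=1,
                num_tokens=4096, codebook_dim=2048, hidden_dim=256,
                record_codes=True)  # record_codes assumed to be a ctor kwarg
with torch.no_grad():
    for _ in range(5):
        v(torch.randn(1, 1, 256))  # only passes with internal_step % 50 == 0 record
# code_ind marks the ring buffer's write position; slice off the unwritten tail.
print(torch.unique(v.codes[:v.code_ind]).numel(), 'distinct codes observed')
```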