Fix LRDVAE bug with quantizer integration

James Betker 2021-08-11 16:17:22 -06:00
parent f04a7bdf63
commit 20586a8edc

@@ -63,7 +63,7 @@ class DiscreteVAE(nn.Module):
         self.num_tokens = num_tokens
         self.num_layers = num_layers
         self.straight_through = straight_through
-        self.codebook = Quantize(num_tokens, codebook_dim)
+        self.codebook = Quantize(codebook_dim, num_tokens)
         self.positional_dims = positional_dims

         assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
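
For context, a minimal sketch of why the argument order matters. It assumes, as the corrected call implies, that Quantize takes the embedding dimension first and the codebook size second; the Quantize class below is a hypothetical stand-in written for illustration, not the module this file actually imports, and the sizes are borrowed from the __main__ test at the bottom of the diff.

import torch.nn as nn

class Quantize(nn.Module):
    # Hypothetical stand-in: signature (dim, n_embed), as the fixed call implies.
    def __init__(self, dim, n_embed):
        super().__init__()
        self.embed = nn.Embedding(n_embed, dim)  # n_embed codes, each dim wide

codebook_dim, num_tokens = 2048, 4096
q = Quantize(codebook_dim, num_tokens)  # corrected order: dim first, vocab size second
assert q.embed.weight.shape == (num_tokens, codebook_dim)
# The old call, Quantize(num_tokens, codebook_dim), silently swapped these roles,
# building codebook_dim entries that were each num_tokens wide.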
@@ -98,7 +98,7 @@ class DiscreteVAE(nn.Module):
         if num_resnet_blocks > 0:
             dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1))

-        enc_layers.append(conv(enc_chans[-1], num_tokens, 1))
+        enc_layers.append(conv(enc_chans[-1], codebook_dim, 1))
         dec_layers.append(conv(dec_chans[-1], channels, 1))

         self.encoder = nn.Sequential(*enc_layers)
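
A related shape check, sketched with illustrative values: the encoder's last 1x1 conv now has to emit codebook_dim channels rather than num_tokens, so every encoder position is a vector the quantizer can compare against codebook entries of the same width. Conv1d is used because the __main__ test exercises the 1d path; the 256 input channels and the sequence length are made up for the example.

import torch
import torch.nn as nn

codebook_dim, num_tokens = 2048, 4096
enc_tail = nn.Conv1d(256, codebook_dim, 1)        # stands in for conv(enc_chans[-1], codebook_dim, 1)
feats = enc_tail(torch.randn(1, 256, 64))         # (B, codebook_dim, L)
codebook = torch.randn(num_tokens, codebook_dim)  # what the quantizer matches against
dists = torch.cdist(feats.permute(0, 2, 1), codebook.unsqueeze(0))
print(dists.shape)                                # torch.Size([1, 64, 4096]): one distance per code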
@@ -112,6 +112,7 @@ class DiscreteVAE(nn.Module):
         if record_codes:
             self.codes = torch.zeros((32768,), dtype=torch.long)
             self.code_ind = 0
+            self.internal_step = 0

     def norm(self, images):
         if not self.normalization is not None:
@@ -171,7 +172,7 @@ class DiscreteVAE(nn.Module):
         recon_loss = self.loss_fn(img, out)

         # This is so we can debug the distribution of codes being learned.
-        if self.record_codes:
+        if self.record_codes and self.internal_step % 50 == 0:
             codes = codes.flatten()
             l = codes.shape[0]
             i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
@@ -179,6 +180,7 @@ class DiscreteVAE(nn.Module):
             self.code_ind = self.code_ind + l
             if self.code_ind >= self.codes.shape[0]:
                 self.code_ind = 0
+        self.internal_step += 1

         return recon_loss, commitment_loss, out
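
The remaining hunks throttle the code-distribution debugging so it only samples every 50th forward pass, driven by the new internal_step counter. Below is a simplified, self-contained version of that ring-buffer bookkeeping; the actual write into self.codes is assumed, since the diff only shows the index handling.

import torch

class CodeRecorder:
    # Illustrative stand-in for the record_codes bookkeeping in DiscreteVAE.
    def __init__(self, buffer_size=32768, every=50):
        self.codes = torch.zeros((buffer_size,), dtype=torch.long)
        self.code_ind = 0
        self.internal_step = 0
        self.every = every

    def record(self, codes):
        # Only sample codes every `every` steps so logging stays cheap.
        if self.internal_step % self.every == 0:
            codes = codes.flatten()
            l = codes.shape[0]
            # Clamp the start index so the write never runs past the buffer end.
            i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
            self.codes[i:i+l] = codes.cpu()  # assumed write, not shown in the hunk
            self.code_ind = self.code_ind + l
            if self.code_ind >= self.codes.shape[0]:
                self.code_ind = 0
        self.internal_step += 1  # advances every call, so the modulo gate actually fires

recorder = CodeRecorder()
recorder.record(torch.randint(0, 4096, (4, 32)))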
@@ -192,6 +194,6 @@ if __name__ == '__main__':
     #v = DiscreteVAE()
     #o=v(torch.randn(1,3,256,256))
     #print(o.shape)
-    v = DiscreteVAE(channels=1, normalization=None, positional_dims=1)
+    v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256)
     o=v(torch.randn(1,1,256))
     print(o[-1].shape)