forked from mrq/DL-Art-School
Fix LRDVAE bug with quantizer integration
This commit is contained in:
parent
f04a7bdf63
commit
20586a8edc
|
@@ -63,7 +63,7 @@ class DiscreteVAE(nn.Module):
|
|||
self.num_tokens = num_tokens
|
||||
self.num_layers = num_layers
|
||||
self.straight_through = straight_through
|
||||
self.codebook = Quantize(num_tokens, codebook_dim)
|
||||
self.codebook = Quantize(codebook_dim, num_tokens)
|
||||
self.positional_dims = positional_dims
|
||||
|
||||
assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
|
||||
|
@@ -98,7 +98,7 @@ class DiscreteVAE(nn.Module):
|
|||
if num_resnet_blocks > 0:
|
||||
dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1))
|
||||
|
||||
enc_layers.append(conv(enc_chans[-1], num_tokens, 1))
|
||||
enc_layers.append(conv(enc_chans[-1], codebook_dim, 1))
|
||||
dec_layers.append(conv(dec_chans[-1], channels, 1))
|
||||
|
||||
self.encoder = nn.Sequential(*enc_layers)
|
||||
|
@@ -112,6 +112,7 @@ class DiscreteVAE(nn.Module):
|
|||
if record_codes:
|
||||
self.codes = torch.zeros((32768,), dtype=torch.long)
|
||||
self.code_ind = 0
|
||||
self.internal_step = 0
|
||||
|
||||
def norm(self, images):
|
||||
if not self.normalization is not None:
|
||||
|
@@ -171,7 +172,7 @@ class DiscreteVAE(nn.Module):
|
|||
recon_loss = self.loss_fn(img, out)
|
||||
|
||||
# This is so we can debug the distribution of codes being learned.
|
||||
if self.record_codes:
|
||||
if self.record_codes and self.internal_step % 50 == 0:
|
||||
codes = codes.flatten()
|
||||
l = codes.shape[0]
|
||||
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
|
||||
|
@@ -179,6 +180,7 @@ class DiscreteVAE(nn.Module):
|
|||
self.code_ind = self.code_ind + l
|
||||
if self.code_ind >= self.codes.shape[0]:
|
||||
self.code_ind = 0
|
||||
self.internal_step += 1
|
||||
|
||||
return recon_loss, commitment_loss, out
|
||||
|
||||
|
@@ -192,6 +194,6 @@ if __name__ == '__main__':
|
|||
#v = DiscreteVAE()
|
||||
#o=v(torch.randn(1,3,256,256))
|
||||
#print(o.shape)
|
||||
v = DiscreteVAE(channels=1, normalization=None, positional_dims=1)
|
||||
v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256)
|
||||
o=v(torch.randn(1,1,256))
|
||||
print(o[-1].shape)
|
||||
|
|
Loading…
Reference in New Issue
Block a user