Make lrdvae use quantized mode in eval()
This commit is contained in:
parent
c28f657ab8
commit
98057b6516
|
@ -158,6 +158,9 @@ class DiscreteVAE(nn.Module):
|
||||||
images = self.decoder(image_embeds)
|
images = self.decoder(image_embeds)
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
# Note: This module is not meant to be run in forward() except while training. It has special logic which performs
|
||||||
|
# evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially
|
||||||
|
# more lossy (but useful for determining network performance).
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
img
|
img
|
||||||
|
@ -166,7 +169,12 @@ class DiscreteVAE(nn.Module):
|
||||||
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
||||||
sampled, commitment_loss, codes = self.codebook(logits)
|
sampled, commitment_loss, codes = self.codebook(logits)
|
||||||
sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
|
sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
|
||||||
out = self.decoder(sampled)
|
|
||||||
|
if self.training:
|
||||||
|
out = self.decoder(sampled)
|
||||||
|
else:
|
||||||
|
# This is non-differentiable, but gives a better idea of how the network is actually performing.
|
||||||
|
out = self.decode(codes)
|
||||||
|
|
||||||
# reconstruction loss
|
# reconstruction loss
|
||||||
recon_loss = self.loss_fn(img, out)
|
recon_loss = self.loss_fn(img, out)
|
||||||
|
@ -195,5 +203,6 @@ if __name__ == '__main__':
|
||||||
#o=v(torch.randn(1,3,256,256))
|
#o=v(torch.randn(1,3,256,256))
|
||||||
#print(o.shape)
|
#print(o.shape)
|
||||||
v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256)
|
v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256)
|
||||||
|
v.eval()
|
||||||
o=v(torch.randn(1,1,256))
|
o=v(torch.randn(1,1,256))
|
||||||
print(o[-1].shape)
|
print(o[-1].shape)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user