forked from mrq/DL-Art-School
Make lrdvae use quantized mode in eval()
This commit is contained in:
parent
c28f657ab8
commit
98057b6516
|
@ -158,6 +158,9 @@ class DiscreteVAE(nn.Module):
|
|||
images = self.decoder(image_embeds)
|
||||
return images
|
||||
|
||||
# Note: This module is not meant to be run in forward() except while training. It has special logic which performs
|
||||
# evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially
|
||||
# more lossy (but useful for determining network performance).
|
||||
def forward(
|
||||
self,
|
||||
img
|
||||
|
@ -166,7 +169,12 @@ class DiscreteVAE(nn.Module):
|
|||
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
||||
sampled, commitment_loss, codes = self.codebook(logits)
|
||||
sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
|
||||
out = self.decoder(sampled)
|
||||
|
||||
if self.training:
|
||||
out = self.decoder(sampled)
|
||||
else:
|
||||
# This is non-differentiable, but gives a better idea of how the network is actually performing.
|
||||
out = self.decode(codes)
|
||||
|
||||
# reconstruction loss
|
||||
recon_loss = self.loss_fn(img, out)
|
||||
|
@ -195,5 +203,6 @@ if __name__ == '__main__':
|
|||
#o=v(torch.randn(1,3,256,256))
|
||||
#print(o.shape)
|
||||
v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256)
|
||||
v.eval()
|
||||
o=v(torch.randn(1,1,256))
|
||||
print(o[-1].shape)
|
||||
|
|
Loading…
Reference in New Issue
Block a user