Make lrdvae use quantized mode in eval()

James Betker 2021-08-14 23:43:01 -06:00
parent c28f657ab8
commit 98057b6516


@@ -158,6 +158,9 @@ class DiscreteVAE(nn.Module):
         images = self.decoder(image_embeds)
         return images
 
+    # Note: This module is not meant to be run in forward() except while training. It has special logic which performs
+    # evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially
+    # more lossy (but useful for determining network performance).
     def forward(
         self,
         img
@@ -166,7 +169,12 @@ class DiscreteVAE(nn.Module):
         logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
         sampled, commitment_loss, codes = self.codebook(logits)
         sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
-        out = self.decoder(sampled)
+
+        if self.training:
+            out = self.decoder(sampled)
+        else:
+            # This is non-differentiable, but gives a better idea of how the network is actually performing.
+            out = self.decode(codes)
 
         # reconstruction loss
         recon_loss = self.loss_fn(img, out)
@@ -195,5 +203,6 @@ if __name__ == '__main__':
     #o=v(torch.randn(1,3,256,256))
     #print(o.shape)
     v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256)
+    v.eval()
     o=v(torch.randn(1,1,256))
     print(o[-1].shape)
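
For reference, a minimal usage sketch of the behavior this commit introduces, based only on the diff above. The import path and the position of the reconstruction in the returned tuple are assumptions, not verified against the rest of the repository:

    import torch

    # Hypothetical import path; adjust to wherever DiscreteVAE lives in this repo.
    from models.lucidrains_dvae import DiscreteVAE

    # Same 1D configuration as the __main__ block in the diff.
    v = DiscreteVAE(channels=1, normalization=None, positional_dims=1,
                    num_tokens=4096, codebook_dim=2048, hidden_dim=256)
    x = torch.randn(1, 1, 256)

    # Training mode: forward() decodes the soft codebook samples (differentiable path).
    v.train()
    train_out = v(x)

    # Eval mode: forward() now decodes the hard, quantized codes instead.
    # Non-differentiable and lossier, but closer to real inference quality.
    v.eval()
    with torch.no_grad():
        eval_out = v(x)

    # Assumes the reconstruction is the last element of the returned tuple,
    # as implied by the diff's __main__ block.
    print(train_out[-1].shape)
    print(eval_out[-1].shape)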