Discretization loss attempt

James Betker 2021-10-04 20:59:21 -06:00
parent 66f99a159c
commit 9c0d7288ea
3 changed files with 46 additions and 5 deletions

models/gpt_voice/dvae_arch_playground/discretization_loss.py (new file)

@@ -0,0 +1,31 @@
+import random
+from math import prod
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# Fits a soft-discretized input to a normal PDF across the specified dimension.
+# In other words, attempts to force the discretization function to, on average, utilize all
+# discrete values equally, with the specified expected variance.
+class DiscretizationLoss(nn.Module):
+    def __init__(self, dim, expected_variance):
+        super().__init__()
+        self.dim = dim
+        self.dist = torch.distributions.Normal(0, scale=expected_variance)
+
+    def forward(self, x):
+        other_dims = set(range(len(x.shape))) - set([self.dim])
+        averaged = x.sum(dim=tuple(other_dims)) / x.sum()
+        averaged = averaged - averaged.mean()
+        return torch.sum(-self.dist.log_prob(averaged))
+
+
+if __name__ == '__main__':
+    d = DiscretizationLoss(1, 1e-6)
+    v = torch.randn(16, 8192, 500)
+    #for k in range(5):
+    #    v[:, random.randint(0,8192), :] += random.random()*100
+    v = F.softmax(v, 1)
+    print(d(v))
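For intuition about what this loss measures: summing a softmax-normalized tensor over every dimension except dim yields the fraction of total mass assigned to each discrete value, and the centered fractions are scored under a zero-mean Gaussian, so uniform utilization of the codebook is the optimum. A minimal sketch (an editor's illustration with made-up sizes, not part of the commit):

    import torch
    from models.gpt_voice.dvae_arch_playground.discretization_loss import DiscretizationLoss

    num_codes = 4
    loss = DiscretizationLoss(dim=1, expected_variance=1 / (num_codes * 2))

    # Every code receives an equal share of the mass...
    uniform = torch.full((8, num_codes, 10), 1 / num_codes)
    # ...versus a single code receiving all of it.
    peaked = torch.zeros(8, num_codes, 10)
    peaked[:, 0] = 1.0

    print(loss(uniform) < loss(peaked))  # tensor(True): equal utilization scores lower (better)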

DiscreteVAE model file (filename not shown in this view)

@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from torch import einsum
 
+from models.gpt_voice.dvae_arch_playground.discretization_loss import DiscretizationLoss
 from models.vqvae.vqvae import Quantize
 from trainer.networks import register_model
 from utils.util import opt_get
@@ -84,6 +85,7 @@ class DiscreteVAE(nn.Module):
         self.straight_through = straight_through
         self.codebook = Quantize(codebook_dim, num_tokens)
         self.positional_dims = positional_dims
+        self.discrete_loss = DiscretizationLoss(2, 1 / (num_tokens*2))
 
         assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
         if positional_dims == 2:
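A note on the two constructor arguments: for 1d inputs the soft codes come back as (batch, sequence, num_tokens), so dim=2 selects the codebook axis, and a perfectly uniform code distribution puts mass 1/num_tokens on each entry, against which expected_variance = 1/(num_tokens*2) is a tight target. Note also that the value is passed to Normal as scale (a standard deviation) despite the parameter's name. A quick shape check (editor's sketch, sizes assumed):

    import torch

    batch, seq, num_tokens = 4, 100, 8192
    soft_codes = torch.randn(batch, seq, num_tokens)  # per-position code scores

    # DiscretizationLoss(2, ...) reduces over batch and sequence, leaving one
    # utilization figure per codebook entry:
    util = soft_codes.sum(dim=(0, 1)) / soft_codes.sum()
    print(util.shape)  # torch.Size([8192])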
@@ -205,7 +207,7 @@ class DiscreteVAE(nn.Module):
     ):
         img = self.norm(img)
         logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
-        sampled, commitment_loss, codes = self.codebook(logits)
+        sampled, commitment_loss, codes, soft_codes = self.codebook(logits, return_soft_codes=True)
         sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
 
         if self.training:
@@ -219,6 +221,10 @@ class DiscreteVAE(nn.Module):
         # reconstruction loss
         recon_loss = self.loss_fn(img, out, reduction='none')
 
+        # discretization loss
+        disc_loss = self.discrete_loss(soft_codes)
+
         # This is so we can debug the distribution of codes being learned.
         if self.record_codes and self.internal_step % 50 == 0:
             codes = codes.flatten()
@@ -230,7 +236,7 @@ class DiscreteVAE(nn.Module):
                     self.code_ind = 0
             self.internal_step += 1
 
-        return recon_loss, commitment_loss, out
+        return recon_loss, commitment_loss, disc_loss, out
 
 
 @register_model
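Callers of DiscreteVAE.forward now receive four values instead of three. A sketch of how a training step might combine them (an editor's illustration; the function name and weights are hypothetical, not taken from this repo's configs):

    def dvae_step(dvae, img, commitment_weight=0.25, disc_weight=1e-4):
        # recon_loss comes back unreduced (reduction='none'); the other two are scalars.
        recon_loss, commitment_loss, disc_loss, out = dvae(img)
        total = recon_loss.mean() + commitment_weight * commitment_loss + disc_weight * disc_loss
        total.backward()
        return total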

models/vqvae/vqvae.py

@@ -47,7 +47,7 @@ class Quantize(nn.Module):
         self.register_buffer("cluster_size", torch.zeros(n_embed))
         self.register_buffer("embed_avg", embed.clone())
 
-    def forward(self, input):
+    def forward(self, input, return_soft_codes=False):
         if self.balancing_heuristic and self.codes_full:
             h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes)
             mask = torch.logical_or(h > .9, h < .01).unsqueeze(1)
@@ -68,7 +68,8 @@ class Quantize(nn.Module):
             - 2 * flatten @ self.embed
             + self.embed.pow(2).sum(0, keepdim=True)
         )
-        _, embed_ind = (-dist).max(1)
+        soft_codes = -dist
+        _, embed_ind = soft_codes.max(1)
         embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
         embed_ind = embed_ind.view(*input.shape[:-1])
         quantize = self.embed_code(embed_ind)
@@ -104,6 +105,9 @@ class Quantize(nn.Module):
         diff = (quantize.detach() - input).pow(2).mean()
         quantize = input + (quantize - input).detach()
 
-        return quantize, diff, embed_ind
+        if return_soft_codes:
+            return quantize, diff, embed_ind, soft_codes.view(*input.shape[:-1], self.n_embed)
+        else:
+            return quantize, diff, embed_ind
 
     def embed_code(self, embed_id):
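Because return_soft_codes defaults to False, existing Quantize call sites keep working unchanged; only the DVAE opts in. The soft codes are negated squared distances to the codebook entries (larger means closer), reshaped to the input's leading dimensions plus a codebook axis of size n_embed; a plain view(input.shape) would only be valid when the input's code dimension happens to equal n_embed, which is why the reshape above spells the target shape out. Both call shapes, as an editor's sketch with assumed sizes:

    import torch
    from models.vqvae.vqvae import Quantize

    quantizer = Quantize(512, 8192)  # (codebook_dim, num_tokens), as in DiscreteVAE above
    quantizer.eval()                 # skip the EMA codebook update for this demo
    x = torch.randn(4, 100, 512)     # (batch, positions, codebook_dim)

    quantized, diff, codes = quantizer(x)  # old three-tuple path, unchanged
    quantized, diff, codes, soft = quantizer(x, return_soft_codes=True)
    print(soft.shape)  # torch.Size([4, 100, 8192]): one score per codebook entry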