Discretization loss attempt
This commit is contained in:
parent
66f99a159c
commit
9c0d7288ea
|
@ -0,0 +1,31 @@
|
|||
import random
|
||||
from math import prod
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# Fits a soft-discretized input to a normal-PDF across the specified dimension.
|
||||
# In other words, attempts to force the discretization function to have a mean equal utilization across all discrete
|
||||
# values with the specified expected variance.
|
||||
class DiscretizationLoss(nn.Module):
|
||||
def __init__(self, dim, expected_variance):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dist = torch.distributions.Normal(0, scale=expected_variance)
|
||||
|
||||
def forward(self, x):
|
||||
other_dims = set(range(len(x.shape)))-set([self.dim])
|
||||
averaged = x.sum(dim=tuple(other_dims)) / x.sum()
|
||||
averaged = averaged - averaged.mean()
|
||||
return torch.sum(-self.dist.log_prob(averaged))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
d = DiscretizationLoss(1, 1e-6)
|
||||
v = torch.randn(16, 8192, 500)
|
||||
#for k in range(5):
|
||||
# v[:, random.randint(0,8192), :] += random.random()*100
|
||||
v = F.softmax(v, 1)
|
||||
print(d(v))
|
|
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|||
from einops import rearrange
|
||||
from torch import einsum
|
||||
|
||||
from models.gpt_voice.dvae_arch_playground.discretization_loss import DiscretizationLoss
|
||||
from models.vqvae.vqvae import Quantize
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
@ -84,6 +85,7 @@ class DiscreteVAE(nn.Module):
|
|||
self.straight_through = straight_through
|
||||
self.codebook = Quantize(codebook_dim, num_tokens)
|
||||
self.positional_dims = positional_dims
|
||||
self.discrete_loss = DiscretizationLoss(2, 1 / (num_tokens*2))
|
||||
|
||||
assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
|
||||
if positional_dims == 2:
|
||||
|
@ -205,7 +207,7 @@ class DiscreteVAE(nn.Module):
|
|||
):
|
||||
img = self.norm(img)
|
||||
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
||||
sampled, commitment_loss, codes = self.codebook(logits)
|
||||
sampled, commitment_loss, codes, soft_codes = self.codebook(logits, return_soft_codes=True)
|
||||
sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
|
||||
|
||||
if self.training:
|
||||
|
@ -219,6 +221,10 @@ class DiscreteVAE(nn.Module):
|
|||
# reconstruction loss
|
||||
recon_loss = self.loss_fn(img, out, reduction='none')
|
||||
|
||||
# discretization loss
|
||||
disc_loss = self.discrete_loss(soft_codes)
|
||||
|
||||
|
||||
# This is so we can debug the distribution of codes being learned.
|
||||
if self.record_codes and self.internal_step % 50 == 0:
|
||||
codes = codes.flatten()
|
||||
|
@ -230,7 +236,7 @@ class DiscreteVAE(nn.Module):
|
|||
self.code_ind = 0
|
||||
self.internal_step += 1
|
||||
|
||||
return recon_loss, commitment_loss, out
|
||||
return recon_loss, commitment_loss, disc_loss, out
|
||||
|
||||
|
||||
@register_model
|
||||
|
|
|
@ -47,7 +47,7 @@ class Quantize(nn.Module):
|
|||
self.register_buffer("cluster_size", torch.zeros(n_embed))
|
||||
self.register_buffer("embed_avg", embed.clone())
|
||||
|
||||
def forward(self, input):
|
||||
def forward(self, input, return_soft_codes=False):
|
||||
if self.balancing_heuristic and self.codes_full:
|
||||
h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes)
|
||||
mask = torch.logical_or(h > .9, h < .01).unsqueeze(1)
|
||||
|
@ -68,7 +68,8 @@ class Quantize(nn.Module):
|
|||
- 2 * flatten @ self.embed
|
||||
+ self.embed.pow(2).sum(0, keepdim=True)
|
||||
)
|
||||
_, embed_ind = (-dist).max(1)
|
||||
soft_codes = -dist
|
||||
_, embed_ind = soft_codes.max(1)
|
||||
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
|
||||
embed_ind = embed_ind.view(*input.shape[:-1])
|
||||
quantize = self.embed_code(embed_ind)
|
||||
|
@ -104,6 +105,9 @@ class Quantize(nn.Module):
|
|||
diff = (quantize.detach() - input).pow(2).mean()
|
||||
quantize = input + (quantize - input).detach()
|
||||
|
||||
if return_soft_codes:
|
||||
return quantize, diff, embed_ind, soft_codes.view(input.shape)
|
||||
else:
|
||||
return quantize, diff, embed_ind
|
||||
|
||||
def embed_code(self, embed_id):
|
||||
|
|
Loading…
Reference in New Issue
Block a user