forked from mrq/DL-Art-School

Discretization loss attempt

This commit is contained in:
parent 66f99a159c
commit 9c0d7288ea
models/gpt_voice/dvae_arch_playground/discretization_loss.py (new file):
@@ -0,0 +1,31 @@
+import random
+from math import prod
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# Fits a soft-discretized input to a normal PDF across the specified dimension.
+# In other words, attempts to force the discretization function to have equal mean utilization across all discrete
+# values, with the specified expected variance.
+class DiscretizationLoss(nn.Module):
+    def __init__(self, dim, expected_variance):
+        super().__init__()
+        self.dim = dim
+        self.dist = torch.distributions.Normal(0, scale=expected_variance)
+
+    def forward(self, x):
+        other_dims = set(range(len(x.shape))) - set([self.dim])
+        averaged = x.sum(dim=tuple(other_dims)) / x.sum()
+        averaged = averaged - averaged.mean()
+        return torch.sum(-self.dist.log_prob(averaged))
+
+
+if __name__ == '__main__':
+    d = DiscretizationLoss(1, 1e-6)
+    v = torch.randn(16, 8192, 500)
+    #for k in range(5):
+    #    v[:, random.randint(0,8192), :] += random.random()*100
+    v = F.softmax(v, 1)
+    print(d(v))
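
As a sanity check of the intent here (my own example, not part of the commit; it assumes the DiscretizationLoss class above is in scope), a uniform soft assignment across codes should score at the loss's minimum, while concentrating mass on one code should be penalized heavily. Note also that expected_variance is passed to Normal's scale parameter, which is a standard deviation, so the name is slightly misleading:

    import torch

    loss = DiscretizationLoss(dim=1, expected_variance=1e-2)

    # (batch, codes, positions); every one of the 4 codes gets equal soft mass.
    uniform = torch.full((2, 4, 3), 0.25)

    # Same tensor, but code 0 soaks up extra mass at every position.
    skewed = uniform.clone()
    skewed[:, 0] += 0.5

    print(loss(uniform))   # the minimum: centered utilization is exactly zero for every code
    print(loss(skewed))    # far larger: code 0's utilization deviates from the mean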

DiscreteVAE module (file path not shown in this view):
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from torch import einsum

+from models.gpt_voice.dvae_arch_playground.discretization_loss import DiscretizationLoss
 from models.vqvae.vqvae import Quantize
 from trainer.networks import register_model
 from utils.util import opt_get
@@ -84,6 +85,7 @@ class DiscreteVAE(nn.Module):
         self.straight_through = straight_through
         self.codebook = Quantize(codebook_dim, num_tokens)
         self.positional_dims = positional_dims
+        self.discrete_loss = DiscretizationLoss(2, 1 / (num_tokens*2))

         assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
         if positional_dims == 2:
@@ -205,7 +207,7 @@ class DiscreteVAE(nn.Module):
     ):
         img = self.norm(img)
         logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
-        sampled, commitment_loss, codes = self.codebook(logits)
+        sampled, commitment_loss, codes, soft_codes = self.codebook(logits, return_soft_codes=True)
         sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))

         if self.training:
@@ -219,6 +221,10 @@ class DiscreteVAE(nn.Module):
         # reconstruction loss
         recon_loss = self.loss_fn(img, out, reduction='none')

+        # discretization loss
+        disc_loss = self.discrete_loss(soft_codes)
+
+
         # This is so we can debug the distribution of codes being learned.
         if self.record_codes and self.internal_step % 50 == 0:
             codes = codes.flatten()
@@ -230,7 +236,7 @@ class DiscreteVAE(nn.Module):
                 self.code_ind = 0
         self.internal_step += 1

-        return recon_loss, commitment_loss, out
+        return recon_loss, commitment_loss, disc_loss, out


 @register_model
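
For orientation, a shape walk-through of the new loss path in the 1d case; this is my reading of the permutes above, not something stated in the commit, and the sizes are hypothetical:

    # Inferred tensor shapes along the new discretization-loss path (1d inputs):
    #   img:        (b, channels, seq)     input to the encoder
    #   logits:     (b, seq, codebook_dim) encoder output after permute((0, 2, 1))
    #   soft_codes: (b, seq, num_tokens)   negated squared distance to every codebook entry
    # DiscretizationLoss(2, 1 / (num_tokens*2)) treats axis 2 as the code axis: it sums
    # soft_codes over batch and sequence, normalizes to a per-code utilization, centers it,
    # and scores the result under Normal(0, scale=1/(num_tokens*2)), so the tolerated
    # deviation from uniform code usage shrinks as the codebook grows.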

models/vqvae/vqvae.py:
@@ -47,7 +47,7 @@ class Quantize(nn.Module):
         self.register_buffer("cluster_size", torch.zeros(n_embed))
         self.register_buffer("embed_avg", embed.clone())

-    def forward(self, input):
+    def forward(self, input, return_soft_codes=False):
         if self.balancing_heuristic and self.codes_full:
             h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes)
             mask = torch.logical_or(h > .9, h < .01).unsqueeze(1)
@@ -68,7 +68,8 @@
             - 2 * flatten @ self.embed
             + self.embed.pow(2).sum(0, keepdim=True)
         )
-        _, embed_ind = (-dist).max(1)
+        soft_codes = -dist
+        _, embed_ind = soft_codes.max(1)
         embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
         embed_ind = embed_ind.view(*input.shape[:-1])
         quantize = self.embed_code(embed_ind)
@@ -104,7 +105,10 @@ class Quantize(nn.Module):
         diff = (quantize.detach() - input).pow(2).mean()
         quantize = input + (quantize - input).detach()

-        return quantize, diff, embed_ind
+        if return_soft_codes:
+            return quantize, diff, embed_ind, soft_codes.view(input.shape)
+        else:
+            return quantize, diff, embed_ind

     def embed_code(self, embed_id):
         return F.embedding(embed_id, self.embed.transpose(0, 1))
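
Finally, a minimal usage sketch of the extended Quantize API. The positional constructor arguments (dim, n_embed) follow the Quantize(codebook_dim, num_tokens) call in the DVAE diff above, but the sizes are assumptions of mine. One caveat worth flagging: soft_codes has shape (N, n_embed) after flattening, so soft_codes.view(input.shape) only has the right element count when dim == n_embed; otherwise the target shape would presumably need to be input.shape[:-1] + (n_embed,).

    import torch

    # Hypothetical sizes; dim == n_embed here so soft_codes.view(input.shape) reshapes cleanly.
    quantizer = Quantize(64, 64)
    quantizer.eval()               # skip the EMA codebook update for this standalone check

    x = torch.randn(16, 500, 64)   # (batch, positions, codebook_dim)

    # Existing call sites are unaffected:
    quantize, diff, embed_ind = quantizer(x)

    # New call site, as used by DiscreteVAE.forward above:
    quantize, diff, embed_ind, soft_codes = quantizer(x, return_soft_codes=True)
    print(soft_codes.shape)        # (16, 500, 64): one score per codebook entry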