From 9c0d7288ea9452bc10999196c51e63f0a80045b7 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 4 Oct 2021 20:59:21 -0600
Subject: [PATCH] Discretization loss attempt

---
 .../discretization_loss.py                | 31 +++++++++++++++++++
 codes/models/gpt_voice/lucidrains_dvae.py | 10 ++++--
 codes/models/vqvae/vqvae.py               | 10 ++++--
 3 files changed, 46 insertions(+), 5 deletions(-)
 create mode 100644 codes/models/gpt_voice/dvae_arch_playground/discretization_loss.py

diff --git a/codes/models/gpt_voice/dvae_arch_playground/discretization_loss.py b/codes/models/gpt_voice/dvae_arch_playground/discretization_loss.py
new file mode 100644
index 00000000..40b0cb7a
--- /dev/null
+++ b/codes/models/gpt_voice/dvae_arch_playground/discretization_loss.py
@@ -0,0 +1,31 @@
+import random
+from math import prod
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# Fits a soft-discretized input to a normal PDF across the specified dimension.
+# In other words, attempts to force the discretization function to have equal mean utilization across all discrete
+# values, with the specified expected variance.
+class DiscretizationLoss(nn.Module):
+    def __init__(self, dim, expected_variance):
+        super().__init__()
+        self.dim = dim
+        self.dist = torch.distributions.Normal(0, scale=expected_variance)
+
+    def forward(self, x):
+        other_dims = set(range(len(x.shape)))-set([self.dim])
+        averaged = x.sum(dim=tuple(other_dims)) / x.sum()
+        averaged = averaged - averaged.mean()
+        return torch.sum(-self.dist.log_prob(averaged))
+
+
+if __name__ == '__main__':
+    d = DiscretizationLoss(1, 1e-6)
+    v = torch.randn(16, 8192, 500)
+    #for k in range(5):
+    #    v[:, random.randint(0,8192), :] += random.random()*100
+    v = F.softmax(v, 1)
+    print(d(v))
\ No newline at end of file
diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py
index 812c45a6..75ce5aba 100644
--- a/codes/models/gpt_voice/lucidrains_dvae.py
+++ b/codes/models/gpt_voice/lucidrains_dvae.py
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from torch import einsum
 
+from models.gpt_voice.dvae_arch_playground.discretization_loss import DiscretizationLoss
 from models.vqvae.vqvae import Quantize
 from trainer.networks import register_model
 from utils.util import opt_get
@@ -84,6 +85,7 @@ class DiscreteVAE(nn.Module):
         self.straight_through = straight_through
         self.codebook = Quantize(codebook_dim, num_tokens)
         self.positional_dims = positional_dims
+        self.discrete_loss = DiscretizationLoss(2, 1 / (num_tokens*2))
 
         assert positional_dims > 0 and positional_dims < 3  # This VAE only supports 1d and 2d inputs for now.
         if positional_dims == 2:
@@ -205,7 +207,7 @@ class DiscreteVAE(nn.Module):
     ):
         img = self.norm(img)
         logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
-        sampled, commitment_loss, codes = self.codebook(logits)
+        sampled, commitment_loss, codes, soft_codes = self.codebook(logits, return_soft_codes=True)
         sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
 
         if self.training:
@@ -219,6 +221,10 @@ class DiscreteVAE(nn.Module):
 
         # reconstruction loss
         recon_loss = self.loss_fn(img, out, reduction='none')
+
+        # discretization loss
+        disc_loss = self.discrete_loss(soft_codes)
+
         # This is so we can debug the distribution of codes being learned.
         if self.record_codes and self.internal_step % 50 == 0:
             codes = codes.flatten()
@@ -230,7 +236,7 @@ class DiscreteVAE(nn.Module):
                 self.code_ind = 0
         self.internal_step += 1
 
-        return recon_loss, commitment_loss, out
+        return recon_loss, commitment_loss, disc_loss, out
 
 
 @register_model
diff --git a/codes/models/vqvae/vqvae.py b/codes/models/vqvae/vqvae.py
index 103d8d8b..e90a7edb 100644
--- a/codes/models/vqvae/vqvae.py
+++ b/codes/models/vqvae/vqvae.py
@@ -47,7 +47,7 @@ class Quantize(nn.Module):
         self.register_buffer("cluster_size", torch.zeros(n_embed))
         self.register_buffer("embed_avg", embed.clone())
 
-    def forward(self, input):
+    def forward(self, input, return_soft_codes=False):
         if self.balancing_heuristic and self.codes_full:
             h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes)
             mask = torch.logical_or(h > .9, h < .01).unsqueeze(1)
@@ -68,7 +68,8 @@ class Quantize(nn.Module):
             - 2 * flatten @ self.embed
             + self.embed.pow(2).sum(0, keepdim=True)
         )
-        _, embed_ind = (-dist).max(1)
+        soft_codes = -dist
+        _, embed_ind = soft_codes.max(1)
         embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
         embed_ind = embed_ind.view(*input.shape[:-1])
         quantize = self.embed_code(embed_ind)
@@ -104,7 +105,10 @@ class Quantize(nn.Module):
         diff = (quantize.detach() - input).pow(2).mean()
         quantize = input + (quantize - input).detach()
 
-        return quantize, diff, embed_ind
+        if return_soft_codes:
+            return quantize, diff, embed_ind, soft_codes.view(*input.shape[:-1], self.n_embed)
+        else:
+            return quantize, diff, embed_ind
 
     def embed_code(self, embed_id):
         return F.embedding(embed_id, self.embed.transpose(0, 1))
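
Not part of the patch: a small sanity check, in the spirit of the __main__ block in discretization_loss.py, showing what DiscretizationLoss rewards. The tensor sizes, the 256-entry codebook, and the expected_variance value (mirroring the patch's 1 / (num_tokens*2)) are illustrative assumptions, not values taken from this repository's configs.

    import torch
    import torch.nn.functional as F

    from models.gpt_voice.dvae_arch_playground.discretization_loss import DiscretizationLoss

    # Codes live on dim=1 here: (batch, num_codes, time). Values are illustrative.
    loss = DiscretizationLoss(dim=1, expected_variance=1 / (256 * 2))

    # Near-uniform utilization: every code receives roughly the same soft mass,
    # so per-code deviations from the mean are tiny and the loss is small.
    uniform = F.softmax(torch.randn(4, 256, 100), dim=1)

    # Collapsed utilization: a handful of codes dominate the softmax,
    # producing large deviations and a much larger loss.
    collapsed_logits = torch.randn(4, 256, 100)
    collapsed_logits[:, :4, :] += 20.0
    collapsed = F.softmax(collapsed_logits, dim=1)

    print(loss(uniform))    # comparatively small; the minimum is at perfectly uniform utilization
    print(loss(collapsed))  # much larger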
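
Also not part of the patch: DiscreteVAE.forward now returns four values (recon_loss, commitment_loss, disc_loss, out), so whatever training code consumes it must be updated as well. Below is a minimal sketch of how the terms might be combined into a scalar objective; the function name and loss weights are hypothetical, and the repository's actual trainer presumably wires these terms up through its own loss configuration rather than a hard-coded sum.

    def dvae_training_step(model, mel_batch, commitment_weight=1.0, disc_weight=1e-3):
        # Per the patch, forward() returns (recon_loss, commitment_loss, disc_loss, out).
        recon_loss, commitment_loss, disc_loss, out = model(mel_batch)
        # recon_loss is computed with reduction='none', so reduce it to a scalar here.
        total = recon_loss.mean() + commitment_weight * commitment_loss + disc_weight * disc_loss
        total.backward()
        return total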