Add dvae balancing heuristic

This commit is contained in:
James Betker 2021-09-23 21:19:36 -06:00
parent e24c619387
commit c5297ccec6
3 changed files with 31 additions and 4 deletions

View File

@ -106,7 +106,7 @@ class DiffusionDVAE(nn.Module):
self.scale_steps = scale_steps
self.encoder = DiscreteEncoder(spectrogram_channels, model_channels*4, quantize_dim, dropout, scale_steps)
self.quantizer = Quantize(quantize_dim, num_discrete_codes)
self.quantizer = Quantize(quantize_dim, num_discrete_codes, balancing_heuristic=True)
# For recording codebook usage.
self.codes = torch.zeros((131072,), dtype=torch.long)
self.code_ind = 0

View File

@ -15,7 +15,7 @@
# Borrowed from https://github.com/rosinality/vq-vae-2-pytorch
# Which was itself orrowed from https://github.com/deepmind/sonnet
# Which was itself borrowed from https://github.com/deepmind/sonnet
import torch
@ -29,7 +29,7 @@ from utils.util import checkpoint, opt_get
class Quantize(nn.Module):
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False):
super().__init__()
self.dim = dim
@ -37,12 +37,31 @@ class Quantize(nn.Module):
self.decay = decay
self.eps = eps
self.balancing_heuristic = balancing_heuristic
self.codes = None
self.max_codes = 64000
self.codes_full = False
embed = torch.randn(dim, n_embed)
self.register_buffer("embed", embed)
self.register_buffer("cluster_size", torch.zeros(n_embed))
self.register_buffer("embed_avg", embed.clone())
def forward(self, input):
if self.codes_full:
h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes)
mask = torch.logical_or(h > .9, h < .01).unsqueeze(1)
ep = self.embed.permute(1,0)
ea = self.embed_avg.permute(1,0)
rand_embed = torch.randn_like(ep) * mask
self.embed = (ep * ~mask + rand_embed).permute(1,0)
self.embed_avg = (ea * ~mask + rand_embed).permute(1,0)
self.cluster_size = self.cluster_size * ~mask.squeeze()
if torch.any(mask):
print(f"Reset {torch.sum(mask)} embedding codes.")
self.codes = None
self.codes_full = False
flatten = input.reshape(-1, self.dim)
dist = (
flatten.pow(2).sum(1, keepdim=True)
@ -54,6 +73,14 @@ class Quantize(nn.Module):
embed_ind = embed_ind.view(*input.shape[:-1])
quantize = self.embed_code(embed_ind)
if self.codes is None:
self.codes = embed_ind.flatten()
else:
self.codes = torch.cat([self.codes, embed_ind.flatten()])
if len(self.codes) > self.max_codes:
self.codes = self.codes[-self.max_codes:]
self.codes_full = True
if self.training:
embed_onehot_sum = embed_onehot.sum(0)
embed_sum = flatten.transpose(0, 1) @ embed_onehot

View File

@ -284,7 +284,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_clips.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_dvae_clips.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()