diff --git a/codes/models/diffusion/diffusion_dvae.py b/codes/models/diffusion/diffusion_dvae.py index 1e90a47f..1f04de0e 100644 --- a/codes/models/diffusion/diffusion_dvae.py +++ b/codes/models/diffusion/diffusion_dvae.py @@ -1,5 +1,7 @@ from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16 +from models.diffusion.gaussian_diffusion import get_named_beta_schedule from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear +from models.diffusion.respace import SpacedDiffusion, space_timesteps from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \ Downsample, Upsample import torch @@ -253,7 +255,8 @@ class DiffusionDVAE(nn.Module): ) def get_debug_values(self, step, __): - return {'histogram_codes': self.codes} + # Note: this is very poor design, but quantizer.get_temperature not only retrieves the temperature, it also updates the step and thus it is extremely important that this function get called regularly. + return {'histogram_codes': self.codes, 'quantizer_temperature': self.quantizer.get_temperature(step)} @torch.no_grad() @eval_decorator diff --git a/codes/models/vqvae/gumbel_quantizer.py b/codes/models/vqvae/gumbel_quantizer.py index 21937cbf..1975ac19 100644 --- a/codes/models/vqvae/gumbel_quantizer.py +++ b/codes/models/vqvae/gumbel_quantizer.py @@ -3,22 +3,54 @@ import torch.nn as nn import torch.nn.functional as F from torch import einsum +from utils.weight_scheduler import LinearDecayWeightScheduler + class GumbelQuantizer(nn.Module): - def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False, temperature=.9): + def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False): super().__init__() self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1) self.codebook = nn.Embedding(num_tokens, codebook_dim) self.straight_through = straight_through - self.temperature = temperature + self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000) + self.step = 0 + + def get_temperature(self, step): + self.step = step # VERY POOR DESIGN. WHEN WILL HE EVER LEARN??? + return self.temperature_scheduler.get_weight_for_step(step) def embed_code(self, codes): return self.codebook(codes) + def gumbel_softmax(self, logits, tau, dim, hard): + gumbels = torch.rand_like(logits) + gumbels = -torch.log(-torch.log(gumbels + 1e-8) + 1e-8) + logits = (logits + gumbels) / tau # ~Gumbel(logits,tau) + y_soft = F.softmax(logits, dim=dim) + + if hard: + index = y_soft.max(dim, keepdim=True)[1] + y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) + ret = y_hard - y_soft.detach() + y_soft + else: + ret = y_soft + return ret + def forward(self, h): h = h.permute(0,2,1) logits = self.to_logits(h) - logits = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=self.straight_through) + logits = self.gumbel_softmax(logits, tau=self.temperature_scheduler.get_weight_for_step(self.step), dim=1, hard=self.straight_through) codes = logits.argmax(dim=1).flatten(1) sampled = einsum('b n l, n d -> b d l', logits, self.codebook.weight) - return sampled.permute(0,2,1), 0, codes \ No newline at end of file + return sampled.permute(0,2,1), 0, codes + +if __name__ == '__main__': + from models.diffusion.diffusion_dvae import DiscreteDecoder + j = torch.randn(8,40,1024) + m = GumbelQuantizer(1024, 1024, 4096) + m2 = DiscreteDecoder(1024, (512, 256), 2) + l=m2(m(j)[0].permute(0,2,1)) + mean = 0 + for ls in l: + mean = mean + ls.mean() + mean.backward() \ No newline at end of file diff --git a/codes/scripts/audio/preparation/save_mels_to_disk.py b/codes/scripts/audio/preparation/save_mels_to_disk.py index 37c534b0..c1911031 100644 --- a/codes/scripts/audio/preparation/save_mels_to_disk.py +++ b/codes/scripts/audio/preparation/save_mels_to_disk.py @@ -1,4 +1,5 @@ import argparse +import os import numpy import torch @@ -19,10 +20,12 @@ def main(): mel_inj = MelSpectrogramInjector({'in':'in', 'out':'out'}, {}) audio_loader = AudioAdapter.default() for e, wav_file in enumerate(tqdm(files)): - if e < 272583: + if e < 0: continue print(f"Processing {wav_file}..") outfile = f'{wav_file}.npz' + if os.path.exists(outfile): + continue try: wave, sample_rate = audio_loader.load(wav_file, sample_rate=22050) diff --git a/codes/train.py b/codes/train.py index f4de0a21..c9cc9d33 100644 --- a/codes/train.py +++ b/codes/train.py @@ -95,7 +95,7 @@ class Trainer: seed += self.rank # Different multiprocessing instances should behave differently. util.set_random_seed(seed) - torch.backends.cudnn.benchmark = True + torch.backends.cudnn.benchmark = opt_get(opt, ['cuda_benchmarking_enabled'], True) # torch.backends.cudnn.deterministic = True if opt_get(opt, ['anomaly_detection'], False): torch.autograd.set_detect_anomaly(True) diff --git a/codes/utils/weight_scheduler.py b/codes/utils/weight_scheduler.py index 7a87f58f..60a1b074 100644 --- a/codes/utils/weight_scheduler.py +++ b/codes/utils/weight_scheduler.py @@ -55,11 +55,11 @@ def get_scheduler_for_opt(opt): # Do some testing. if __name__ == "__main__": #sched = SinusoidalWeightScheduler(1, .1, 50, 10) - sched = LinearDecayWeightScheduler(1, 150, .1, 20) + sched = LinearDecayWeightScheduler(10, 5000, .9, 2000) x = [] y = [] - for s in range(200): + for s in range(8000): x.append(s) y.append(sched.get_weight_for_step(s)) plt.plot(x, y)