Add scheduling to quantizer, enable cudnn_benchmarking to be disabled

This commit is contained in:
James Betker 2021-09-24 17:01:36 -06:00
parent 3e64e847c2
commit ac57cdc794
5 changed files with 47 additions and 9 deletions

View File

@ -1,5 +1,7 @@
from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.respace import SpacedDiffusion, space_timesteps
from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \
Downsample, Upsample
import torch
@ -253,7 +255,8 @@ class DiffusionDVAE(nn.Module):
)
def get_debug_values(self, step, __):
return {'histogram_codes': self.codes}
# Note: this is very poor design, but quantizer.get_temperature not only retrieves the temperature, it also updates the step and thus it is extremely important that this function get called regularly.
return {'histogram_codes': self.codes, 'quantizer_temperature': self.quantizer.get_temperature(step)}
@torch.no_grad()
@eval_decorator

View File

@ -3,22 +3,54 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from utils.weight_scheduler import LinearDecayWeightScheduler
class GumbelQuantizer(nn.Module):
def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False, temperature=.9):
def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False):
super().__init__()
self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1)
self.codebook = nn.Embedding(num_tokens, codebook_dim)
self.straight_through = straight_through
self.temperature = temperature
self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000)
self.step = 0
def get_temperature(self, step):
self.step = step # VERY POOR DESIGN. WHEN WILL HE EVER LEARN???
return self.temperature_scheduler.get_weight_for_step(step)
def embed_code(self, codes):
return self.codebook(codes)
def gumbel_softmax(self, logits, tau, dim, hard):
gumbels = torch.rand_like(logits)
gumbels = -torch.log(-torch.log(gumbels + 1e-8) + 1e-8)
logits = (logits + gumbels) / tau # ~Gumbel(logits,tau)
y_soft = F.softmax(logits, dim=dim)
if hard:
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
else:
ret = y_soft
return ret
def forward(self, h):
h = h.permute(0,2,1)
logits = self.to_logits(h)
logits = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=self.straight_through)
logits = self.gumbel_softmax(logits, tau=self.temperature_scheduler.get_weight_for_step(self.step), dim=1, hard=self.straight_through)
codes = logits.argmax(dim=1).flatten(1)
sampled = einsum('b n l, n d -> b d l', logits, self.codebook.weight)
return sampled.permute(0,2,1), 0, codes
return sampled.permute(0,2,1), 0, codes
if __name__ == '__main__':
from models.diffusion.diffusion_dvae import DiscreteDecoder
j = torch.randn(8,40,1024)
m = GumbelQuantizer(1024, 1024, 4096)
m2 = DiscreteDecoder(1024, (512, 256), 2)
l=m2(m(j)[0].permute(0,2,1))
mean = 0
for ls in l:
mean = mean + ls.mean()
mean.backward()

View File

@ -1,4 +1,5 @@
import argparse
import os
import numpy
import torch
@ -19,10 +20,12 @@ def main():
mel_inj = MelSpectrogramInjector({'in':'in', 'out':'out'}, {})
audio_loader = AudioAdapter.default()
for e, wav_file in enumerate(tqdm(files)):
if e < 272583:
if e < 0:
continue
print(f"Processing {wav_file}..")
outfile = f'{wav_file}.npz'
if os.path.exists(outfile):
continue
try:
wave, sample_rate = audio_loader.load(wav_file, sample_rate=22050)

View File

@ -95,7 +95,7 @@ class Trainer:
seed += self.rank # Different multiprocessing instances should behave differently.
util.set_random_seed(seed)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.benchmark = opt_get(opt, ['cuda_benchmarking_enabled'], True)
# torch.backends.cudnn.deterministic = True
if opt_get(opt, ['anomaly_detection'], False):
torch.autograd.set_detect_anomaly(True)

View File

@ -55,11 +55,11 @@ def get_scheduler_for_opt(opt):
# Do some testing.
if __name__ == "__main__":
#sched = SinusoidalWeightScheduler(1, .1, 50, 10)
sched = LinearDecayWeightScheduler(1, 150, .1, 20)
sched = LinearDecayWeightScheduler(10, 5000, .9, 2000)
x = []
y = []
for s in range(200):
for s in range(8000):
x.append(s)
y.append(sched.get_weight_for_step(s))
plt.plot(x, y)