Add scheduling to quantizer, enable cudnn_benchmarking to be disabled
This commit is contained in:
parent
3e64e847c2
commit
ac57cdc794
|
@ -1,5 +1,7 @@
|
|||
from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16
|
||||
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.respace import SpacedDiffusion, space_timesteps
|
||||
from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \
|
||||
Downsample, Upsample
|
||||
import torch
|
||||
|
@ -253,7 +255,8 @@ class DiffusionDVAE(nn.Module):
|
|||
)
|
||||
|
||||
def get_debug_values(self, step, __):
|
||||
return {'histogram_codes': self.codes}
|
||||
# Note: this is very poor design, but quantizer.get_temperature not only retrieves the temperature, it also updates the step and thus it is extremely important that this function get called regularly.
|
||||
return {'histogram_codes': self.codes, 'quantizer_temperature': self.quantizer.get_temperature(step)}
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
|
|
|
@ -3,22 +3,54 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch import einsum
|
||||
|
||||
from utils.weight_scheduler import LinearDecayWeightScheduler
|
||||
|
||||
|
||||
class GumbelQuantizer(nn.Module):
|
||||
def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False, temperature=.9):
|
||||
def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False):
|
||||
super().__init__()
|
||||
self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1)
|
||||
self.codebook = nn.Embedding(num_tokens, codebook_dim)
|
||||
self.straight_through = straight_through
|
||||
self.temperature = temperature
|
||||
self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000)
|
||||
self.step = 0
|
||||
|
||||
def get_temperature(self, step):
|
||||
self.step = step # VERY POOR DESIGN. WHEN WILL HE EVER LEARN???
|
||||
return self.temperature_scheduler.get_weight_for_step(step)
|
||||
|
||||
def embed_code(self, codes):
|
||||
return self.codebook(codes)
|
||||
|
||||
def gumbel_softmax(self, logits, tau, dim, hard):
|
||||
gumbels = torch.rand_like(logits)
|
||||
gumbels = -torch.log(-torch.log(gumbels + 1e-8) + 1e-8)
|
||||
logits = (logits + gumbels) / tau # ~Gumbel(logits,tau)
|
||||
y_soft = F.softmax(logits, dim=dim)
|
||||
|
||||
if hard:
|
||||
index = y_soft.max(dim, keepdim=True)[1]
|
||||
y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
|
||||
ret = y_hard - y_soft.detach() + y_soft
|
||||
else:
|
||||
ret = y_soft
|
||||
return ret
|
||||
|
||||
def forward(self, h):
|
||||
h = h.permute(0,2,1)
|
||||
logits = self.to_logits(h)
|
||||
logits = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=self.straight_through)
|
||||
logits = self.gumbel_softmax(logits, tau=self.temperature_scheduler.get_weight_for_step(self.step), dim=1, hard=self.straight_through)
|
||||
codes = logits.argmax(dim=1).flatten(1)
|
||||
sampled = einsum('b n l, n d -> b d l', logits, self.codebook.weight)
|
||||
return sampled.permute(0,2,1), 0, codes
|
||||
return sampled.permute(0,2,1), 0, codes
|
||||
|
||||
if __name__ == '__main__':
|
||||
from models.diffusion.diffusion_dvae import DiscreteDecoder
|
||||
j = torch.randn(8,40,1024)
|
||||
m = GumbelQuantizer(1024, 1024, 4096)
|
||||
m2 = DiscreteDecoder(1024, (512, 256), 2)
|
||||
l=m2(m(j)[0].permute(0,2,1))
|
||||
mean = 0
|
||||
for ls in l:
|
||||
mean = mean + ls.mean()
|
||||
mean.backward()
|
|
@ -1,4 +1,5 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
@ -19,10 +20,12 @@ def main():
|
|||
mel_inj = MelSpectrogramInjector({'in':'in', 'out':'out'}, {})
|
||||
audio_loader = AudioAdapter.default()
|
||||
for e, wav_file in enumerate(tqdm(files)):
|
||||
if e < 272583:
|
||||
if e < 0:
|
||||
continue
|
||||
print(f"Processing {wav_file}..")
|
||||
outfile = f'{wav_file}.npz'
|
||||
if os.path.exists(outfile):
|
||||
continue
|
||||
|
||||
try:
|
||||
wave, sample_rate = audio_loader.load(wav_file, sample_rate=22050)
|
||||
|
|
|
@ -95,7 +95,7 @@ class Trainer:
|
|||
seed += self.rank # Different multiprocessing instances should behave differently.
|
||||
util.set_random_seed(seed)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.benchmark = opt_get(opt, ['cuda_benchmarking_enabled'], True)
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
if opt_get(opt, ['anomaly_detection'], False):
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
|
|
|
@ -55,11 +55,11 @@ def get_scheduler_for_opt(opt):
|
|||
# Do some testing.
|
||||
if __name__ == "__main__":
|
||||
#sched = SinusoidalWeightScheduler(1, .1, 50, 10)
|
||||
sched = LinearDecayWeightScheduler(1, 150, .1, 20)
|
||||
sched = LinearDecayWeightScheduler(10, 5000, .9, 2000)
|
||||
|
||||
x = []
|
||||
y = []
|
||||
for s in range(200):
|
||||
for s in range(8000):
|
||||
x.append(s)
|
||||
y.append(sched.get_weight_for_step(s))
|
||||
plt.plot(x, y)
|
||||
|
|
Loading…
Reference in New Issue
Block a user