forked from mrq/DL-Art-School
Add scheduling to quantizer, enable cudnn_benchmarking to be disabled
This commit is contained in:
parent 3e64e847c2
commit ac57cdc794
@@ -1,5 +1,7 @@
 from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16
+from models.diffusion.gaussian_diffusion import get_named_beta_schedule
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
+from models.diffusion.respace import SpacedDiffusion, space_timesteps
 from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \
     Downsample, Upsample
 import torch
@@ -253,7 +255,8 @@ class DiffusionDVAE(nn.Module):
         )
 
     def get_debug_values(self, step, __):
-        return {'histogram_codes': self.codes}
+        # Note: this is very poor design, but quantizer.get_temperature not only retrieves the temperature, it also updates the step and thus it is extremely important that this function get called regularly.
+        return {'histogram_codes': self.codes, 'quantizer_temperature': self.quantizer.get_temperature(step)}
 
     @torch.no_grad()
     @eval_decorator
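The new return value couples debug logging to training behavior: the quantizer's schedule only advances when get_debug_values(step, __) runs, because get_temperature records the step on the quantizer and forward() later reads it. A toy sketch of that side effect (names and numbers are illustrative, not from the diff):

    class ToyQuantizer:
        def __init__(self):
            self.step = 0  # forward() would read this to pick its temperature

        def get_temperature(self, step):
            self.step = step  # the side effect the in-code note warns about
            return 10.0 if step < 2000 else max(.9, 10.0 - (step - 2000) * (10.0 - .9) / 5000)

    q = ToyQuantizer()
    q.get_temperature(0)     # returns 10.0; q.step == 0
    q.get_temperature(4500)  # returns 5.45; q.step == 4500
    # If get_debug_values were never called, q.step would stay 0 and the
    # temperature used by forward() would never anneal.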
@@ -3,22 +3,54 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import einsum
 
+from utils.weight_scheduler import LinearDecayWeightScheduler
 
 class GumbelQuantizer(nn.Module):
-    def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False, temperature=.9):
+    def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False):
         super().__init__()
         self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1)
         self.codebook = nn.Embedding(num_tokens, codebook_dim)
         self.straight_through = straight_through
-        self.temperature = temperature
+        self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000)
+        self.step = 0
+
+    def get_temperature(self, step):
+        self.step = step  # VERY POOR DESIGN. WHEN WILL HE EVER LEARN???
+        return self.temperature_scheduler.get_weight_for_step(step)
 
     def embed_code(self, codes):
        return self.codebook(codes)
 
+    def gumbel_softmax(self, logits, tau, dim, hard):
+        gumbels = torch.rand_like(logits)
+        gumbels = -torch.log(-torch.log(gumbels + 1e-8) + 1e-8)
+        logits = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
+        y_soft = F.softmax(logits, dim=dim)
+
+        if hard:
+            index = y_soft.max(dim, keepdim=True)[1]
+            y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+            ret = y_hard - y_soft.detach() + y_soft
+        else:
+            ret = y_soft
+        return ret
+
     def forward(self, h):
         h = h.permute(0,2,1)
         logits = self.to_logits(h)
-        logits = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=self.straight_through)
+        logits = self.gumbel_softmax(logits, tau=self.temperature_scheduler.get_weight_for_step(self.step), dim=1, hard=self.straight_through)
         codes = logits.argmax(dim=1).flatten(1)
         sampled = einsum('b n l, n d -> b d l', logits, self.codebook.weight)
         return sampled.permute(0,2,1), 0, codes
+
+
+if __name__ == '__main__':
+    from models.diffusion.diffusion_dvae import DiscreteDecoder
+    j = torch.randn(8,40,1024)
+    m = GumbelQuantizer(1024, 1024, 4096)
+    m2 = DiscreteDecoder(1024, (512, 256), 2)
+    l=m2(m(j)[0].permute(0,2,1))
+    mean = 0
+    for ls in l:
+        mean = mean + ls.mean()
+    mean.backward()
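The hand-rolled gumbel_softmax above mirrors torch.nn.functional.gumbel_softmax: uniform noise u is turned into Gumbel noise via g = -log(-log(u + eps) + eps), the perturbed logits are softened at the (now scheduled) temperature tau, and hard=True returns a one-hot sample while routing gradients through the soft sample, i.e. the straight-through estimator. A small standalone check of the trick (shapes and the toy loss weights are illustrative):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 8, requires_grad=True)
    u = torch.rand_like(logits)
    g = -torch.log(-torch.log(u + 1e-8) + 1e-8)     # uniform -> Gumbel noise
    y_soft = F.softmax((logits + g) / 2.0, dim=-1)  # tau = 2.0
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    y = y_hard - y_soft.detach() + y_soft           # one-hot values, soft gradients

    loss = (y * torch.arange(8.)).sum()             # arbitrary toy loss
    loss.backward()
    assert logits.grad is not None                  # gradients reached the logits

Lower tau pushes y_soft toward one-hot, which is exactly what the linear decay schedule does over the course of training.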
@@ -1,4 +1,5 @@
 import argparse
+import os
 
 import numpy
 import torch
@@ -19,10 +20,12 @@ def main():
     mel_inj = MelSpectrogramInjector({'in':'in', 'out':'out'}, {})
     audio_loader = AudioAdapter.default()
     for e, wav_file in enumerate(tqdm(files)):
-        if e < 272583:
+        if e < 0:
             continue
         print(f"Processing {wav_file}..")
         outfile = f'{wav_file}.npz'
+        if os.path.exists(outfile):
+            continue
 
         try:
             wave, sample_rate = audio_loader.load(wav_file, sample_rate=22050)
@@ -95,7 +95,7 @@ class Trainer:
         seed += self.rank  # Different multiprocessing instances should behave differently.
         util.set_random_seed(seed)
 
-        torch.backends.cudnn.benchmark = True
+        torch.backends.cudnn.benchmark = opt_get(opt, ['cuda_benchmarking_enabled'], True)
         # torch.backends.cudnn.deterministic = True
         if opt_get(opt, ['anomaly_detection'], False):
             torch.autograd.set_detect_anomaly(True)
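cudnn benchmarking autotunes convolution kernels for the input shapes it sees; it helps when shapes are fixed, but can cost time and memory when shapes vary (as with variable-length audio) and makes runs less reproducible, hence the new off switch. Because opt_get falls back to True, existing configs keep the old behavior; to disable it, set the key in the training options file. A hypothetical excerpt (only the new key comes from this commit; the surrounding layout is assumed):

    name: my_experiment                # hypothetical
    cuda_benchmarking_enabled: false   # new flag; defaults to true when omitted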
@@ -55,11 +55,11 @@ def get_scheduler_for_opt(opt):
 # Do some testing.
 if __name__ == "__main__":
     #sched = SinusoidalWeightScheduler(1, .1, 50, 10)
-    sched = LinearDecayWeightScheduler(1, 150, .1, 20)
+    sched = LinearDecayWeightScheduler(10, 5000, .9, 2000)
 
     x = []
     y = []
-    for s in range(200):
+    for s in range(8000):
         x.append(s)
         y.append(sched.get_weight_for_step(s))
     plt.plot(x, y)
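The test now exercises the same schedule the quantizer constructs. Given those arguments and the 8000-step plot range, the implied behavior is: hold the weight at 10 until step 2000, then decay linearly to a floor of .9 over the next 5000 steps. A minimal sketch of that behavior, assuming the constructor order is (initial_weight, steps_to_decay, lower_bound, start_step); the real class lives in utils/weight_scheduler.py:

    class LinearDecaySketch:
        def __init__(self, initial_weight, steps_to_decay, lower_bound, start_step):
            self.initial_weight = initial_weight
            self.lower_bound = lower_bound
            self.start_step = start_step
            self.slope = (initial_weight - lower_bound) / steps_to_decay

        def get_weight_for_step(self, step):
            step = step - self.start_step
            if step <= 0:
                return self.initial_weight  # hold phase before start_step
            return max(self.lower_bound, self.initial_weight - self.slope * step)

    sched = LinearDecaySketch(10, 5000, .9, 2000)
    print([round(sched.get_weight_for_step(s), 2) for s in (0, 2000, 4500, 7000, 8000)])
    # [10, 10, 5.45, 0.9, 0.9] -- the Gumbel temperature anneals from 10 to .9 by step 7000.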