forked from mrq/DL-Art-School

misc

This commit is contained in:
  parent 2997a640b0
  commit 15decfdb98
@@ -4,6 +4,7 @@ from random import randrange

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.utils

 from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
@@ -38,7 +39,7 @@ class SubBlock(nn.Module):
             self.mask_initialized = True
         blk_enc = self.blk_emb_proj(blk_emb)
         ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
-        ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x.
+        ah = ah[:,:,blk_enc.shape[-1]:] # Strip off the blk_enc used for attention and re-align with x.
         ah = F.gelu(self.attnorm(ah))
         h = torch.cat([ah, x], dim=1)
         hf = self.dropout(checkpoint(self.ff, h))
@@ -168,25 +169,21 @@ class TransformerDiffusion(nn.Module):
         }
         return groups

-    def input_to_random_resolution_and_window(self, x, x_prior):
+    def input_to_random_resolution_and_window(self, x):
         """
         This function MUST be applied to the target *before* noising. It returns the reduced, re-scoped target as well
-        as caches an internal prior for the rescoped target which will be useud in training.
+        as caches an internal prior for the rescoped target which will be used in training.
         Args:
             x: Diffusion target
-            x_prior: Prior input, which is generally just {x}
         """
-        assert x.shape == x_prior.shape, f'{x.shape} {x_prior.shape}'
-        resolution = randrange(1, self.resolution_steps)
+        resolution = randrange(0, self.resolution_steps)
         resolution_scale = 2 ** resolution
         s = F.interpolate(x, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
-        s_prior = F.interpolate(x_prior, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
         s_diff = s.shape[-1] - self.max_window
         if s_diff > 1:
             start = randrange(0, s_diff)
             s = s[:,:,start:start+self.max_window]
-            s_prior = x_prior[:,:,start:start+self.max_window]
-        s_prior = F.interpolate(s_prior, scale_factor=.25, mode='linear', align_corners=True)
+        s_prior = F.interpolate(s, scale_factor=.25, mode='linear', align_corners=True)
         s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
         self.preprocessed = (s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device))
         return s
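A minimal standalone sketch of the preprocessing step above, for illustration only: it mirrors the committed input_to_random_resolution_and_window logic on a toy tensor. The default values for resolution_steps and max_window are hypothetical, not taken from the repo's configuration.

    # Hedged sketch, assuming (B, C, T) targets and made-up resolution_steps/max_window.
    from random import randrange

    import torch
    import torch.nn.functional as F

    def random_resolution_and_window(x, resolution_steps=8, max_window=2000):
        # x: diffusion target, *before* noise is added.
        resolution = randrange(0, resolution_steps)
        scale = 2 ** resolution
        s = F.interpolate(x, scale_factor=1 / scale, mode='linear', align_corners=True)
        s_diff = s.shape[-1] - max_window
        if s_diff > 1:
            start = randrange(0, s_diff)
            s = s[:, :, start:start + max_window]
        # The cached prior is a 4x-downsampled then re-upsampled copy of the re-scoped target.
        s_prior = F.interpolate(s, scale_factor=.25, mode='linear', align_corners=True)
        s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
        res = torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device)
        return s, (s_prior, res)

    x = torch.randn(2, 256, 4000)
    target, (prior, res) = random_resolution_and_window(x)
    print(target.shape, prior.shape, res)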
@@ -196,16 +193,18 @@ class TransformerDiffusion(nn.Module):

+        h = x
         if resolution is None:
             # This is assumed to be training.
             assert self.preprocessed is not None, 'Preprocessing function not called.'
-            h = x
+            assert x_prior is None, 'Provided prior will not be used, instead preprocessing output will be used.'
             h_sub, resolution = self.preprocessed
             self.preprocessed = None
         else:
-            h_sub = F.interpolate(x_prior, scale_factor=4, mode='linear', align_corners=True)
-            assert h.shape == h_sub.shape, f'{h.shape} {h_sub.shape}'
+            assert h.shape[-1] > x_prior.shape[-1] * 3.9, f'{h.shape} {x_prior.shape}'
+            h_sub = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True)

         if conditioning_free:
-            code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1)
+            h_sub = self.unconditioned_prior.repeat(x.shape[0], 1, x.shape[-1])
+            code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
         else:
             MIN_COND_LEN = 200
             MAX_COND_LEN = 1200
@@ -227,8 +226,8 @@ class TransformerDiffusion(nn.Module):
         time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
         res_emb = self.resolution_embed(resolution)
         blk_emb = torch.cat([time_emb.unsqueeze(-1), res_emb.unsqueeze(-1), code_emb], dim=-1)
-        h = torch.cat([h, h_sub], dim=1)

+        h = torch.cat([h, h_sub], dim=1)
         h = self.inp_block(h)
         for layer in self.layers:
             h = checkpoint(layer, h, blk_emb)
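A hedged sketch of how the per-block conditioning tensor (blk_emb) assembled above could look. The sinusoidal embedding follows the standard guided-diffusion formulation; the dimensions (dim=512, n_resolutions=8) are illustrative assumptions, not values from this repo.

    # Sketch only: global timestep/resolution embeddings prepended to per-position codes.
    import math

    import torch
    import torch.nn as nn

    def sinusoidal_timestep_embedding(timesteps, dim, max_period=10000):
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half).to(timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    dim, n_resolutions = 512, 8
    time_embed = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))
    resolution_embed = nn.Embedding(n_resolutions, dim)

    timesteps = torch.randint(0, 1000, (2,))
    resolution = torch.tensor([1, 1])
    code_emb = torch.randn(2, dim, 120)  # per-position conditioning codes, shape (B, C, T)

    time_emb = time_embed(sinusoidal_timestep_embedding(timesteps, dim))
    res_emb = resolution_embed(resolution)
    # Concatenate the two global embeddings as extra positions along the sequence axis,
    # mirroring blk_emb = torch.cat([time_emb.unsqueeze(-1), res_emb.unsqueeze(-1), code_emb], dim=-1).
    blk_emb = torch.cat([time_emb.unsqueeze(-1), res_emb.unsqueeze(-1), code_emb], dim=-1)
    print(blk_emb.shape)  # torch.Size([2, 512, 122])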
@@ -24,7 +24,7 @@ class SubBlock(nn.Module):
         self.ffnorm = nn.GroupNorm(8, contraction_dim)
         if self.enable_attention_masking:
             # All regions can attend to the first token, which will be the timestep embedding. Hence, fixed_region.
-            self.mask = build_local_attention_mask(n=2000, l=48, fixed_region=1)
+            self.mask = build_local_attention_mask(n=4000, l=48, fixed_region=1)
             self.mask_initialized = False
         else:
             self.mask = None
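For illustration, one plausible construction of a banded ("local") attention mask with a small fixed region that every position may attend to, matching the intent described in the comment above (the fixed region covers the timestep-embedding token). The real build_local_attention_mask in models.arch_util may differ in details; this is only a sketch.

    import torch

    def local_attention_mask(n, l, fixed_region=1):
        # True marks allowed attention: position i may attend within a window of +/- l,
        # and every position may attend to the first `fixed_region` tokens.
        idx = torch.arange(n)
        band = (idx[None, :] - idx[:, None]).abs() <= l
        band[:, :fixed_region] = True
        return band

    mask = local_attention_mask(n=4000, l=48, fixed_region=1)
    print(mask.shape, mask[0, :3])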
@@ -340,7 +340,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_instrument_quantizer.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_multilevel_sr.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
@@ -1,28 +1,24 @@
 import functools
 import os
-import os
 import os.path as osp
 from glob import glob
 from random import shuffle
 from time import time

 import numpy as np
 import torch
 import torch.nn.functional as F
 import torchaudio
 import torchvision
 from pytorch_fid.fid_score import calculate_frechet_distance
 from torch import distributed
 from tqdm import tqdm
-import torch.nn.functional as F

 import trainer.eval.evaluator as evaluator
 from data.audio.unsupervised_audio_dataset import load_audio
 from models.audio.mel2vec import ContrastiveTrainingWrapper
 from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
 from models.clip.contrastive_audio import ContrastiveAudio
 from models.diffusion.gaussian_diffusion import get_named_beta_schedule
 from models.diffusion.respace import space_timesteps, SpacedDiffusion
 from trainer.injectors.audio_injectors import denormalize_torch_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
-    normalize_mel, KmeansQuantizerInjector
+    KmeansQuantizerInjector, normalize_torch_mel
 from utils.music_utils import get_music_codegen, get_mel2wav_model, get_cheater_decoder, get_cheater_encoder, \
     get_mel2wav_v3_model, get_ar_prior
 from utils.util import opt_get, load_model_from_config
@@ -117,7 +113,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
                                               model_kwargs={'codes': mel})
         gen = pixel_shuffle_1d(gen, self.squeeze_ratio)

-        return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate
+        return gen, real_resampled, normalize_torch_mel(self.spec_fn({'in': gen})['out']), normalize_torch_mel(mel), sample_rate

     def perform_diffusion_from_codes(self, audio, sample_rate=22050):
         real_resampled = audio
@@ -126,7 +122,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         mel = self.spec_fn({'in': audio})['out']
         codegen = self.local_modules['codegen'].to(mel.device)
         codes = codegen.get_codes(mel, project=True)
-        mel_norm = normalize_mel(mel)
+        mel_norm = normalize_torch_mel(mel)
         gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
                                               model_kwargs={'codes': codes, 'conditioning_input': torch.zeros_like(mel_norm[:,:,:390])})

@@ -140,27 +136,27 @@ class MusicDiffusionFid(evaluator.Evaluator):
         return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate

     def perform_diffusion_from_codes_quant(self, audio, sample_rate=22050):
         real_resampled = audio
         audio = audio.unsqueeze(0)

         mel = self.spec_fn({'in': audio})['out']
-        mel_norm = normalize_mel(mel)
+        mel_norm = normalize_torch_mel(mel)
         #def denoising_fn(x):
         #    q9 = torch.quantile(x, q=.95, dim=-1).unsqueeze(-1)
         #    s = q9.clamp(1, 9999999999)
         #    x = x.clamp(-s, s) / s
         #    return x
-        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, #denoised_fn=denoising_fn, clip_denoised=False,
-                                              model_kwargs={'truth_mel': mel_norm})
+        sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
+        gen_mel = sampler(self.model, mel_norm.shape, model_kwargs={'truth_mel': mel_norm})

         gen_mel_denorm = denormalize_torch_mel(gen_mel)
         output_shape = (1,16,audio.shape[-1]//16)
         self.spec_decoder = self.spec_decoder.to(audio.device)
-        gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape,
+        sampler = self.spectral_diffuser.ddim_sample_loop if self.ddim else self.spectral_diffuser.p_sample_loop
+        gen_wav = sampler(self.spec_decoder, output_shape,
                                                        model_kwargs={'aligned_conditioning': gen_mel_denorm})
         gen_wav = pixel_shuffle_1d(gen_wav, 16)

-        real_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape,
+        real_wav = sampler(self.spec_decoder, output_shape,
                                                         model_kwargs={'aligned_conditioning': mel})
         real_wav = pixel_shuffle_1d(real_wav, 16)

@@ -170,7 +166,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         audio = audio.unsqueeze(0)

         mel = self.spec_fn({'in': audio})['out']
-        mel_norm = normalize_mel(mel)
+        mel_norm = normalize_torch_mel(mel)
         cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm)

         # 1. Generate the cheater latent using the input as a reference.
@@ -203,7 +199,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         audio = audio.unsqueeze(0)

         mel = self.spec_fn({'in': audio})['out']
-        mel_norm = normalize_mel(mel)
+        mel_norm = normalize_torch_mel(mel)
         cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm)
         cheater_codes = self.kmeans_inj({'in': cheater})['out']
         ar_latent = self.local_modules['ar_prior'].to(audio.device)(cheater_codes, cheater, return_latent=True)
@@ -233,16 +229,17 @@ class MusicDiffusionFid(evaluator.Evaluator):
     def perform_chained_sr(self, audio, sample_rate=22050):
         audio = audio.unsqueeze(0)
         mel = self.spec_fn({'in': audio})['out']
-        mel_norm = normalize_mel(mel)
+        mel_norm = normalize_torch_mel(mel)
+        #mel_norm = mel_norm[:,:,:448*4] # restricts first stage to optimal training window.
         conditioning = mel_norm[:,:,:1200]
         downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='linear', align_corners=True)
-        sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
         stage1_shape = (1, 256, downsampled.shape[-1]*4)
+        sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
         # Chain super-sampling using 2 stages.
-        stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([2], device=audio.device),
+        stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device),
                                                                   'x_prior': downsampled,
                                                                   'conditioning_input': conditioning})
-        stage2 = sampler(self.model, audio.shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device),
+        stage2 = sampler(self.model, mel.shape, model_kwargs={'resolution': torch.tensor([0], device=audio.device),
                                                                 'x_prior': stage1,
                                                                 'conditioning_input': conditioning})
         # Decode into waveform.
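A hedged sketch of the chained super-resolution pattern above, generalized to an arbitrary number of stages: each stage conditions on the previous stage's output as x_prior and counts the resolution index down to 0. Here sample_fn and model are placeholders standing in for the repo's SpacedDiffusion sampler and multilevel SR generator, and the fake sampler only demonstrates control flow.

    import torch
    import torch.nn.functional as F

    def chained_sr(sample_fn, model, low_res, target_len, n_stages=2, conditioning=None):
        prior = low_res
        for stage in reversed(range(n_stages)):
            # Each stage quadruples the sequence length until the final stage hits target_len.
            out_len = target_len if stage == 0 else prior.shape[-1] * 4
            shape = (prior.shape[0], prior.shape[1], out_len)
            prior = sample_fn(model, shape,
                              model_kwargs={'resolution': torch.tensor([stage], device=prior.device),
                                            'x_prior': prior,
                                            'conditioning_input': conditioning})
        return prior

    def fake_sampler(model, shape, model_kwargs=None):
        # Stand-in for p_sample_loop/ddim_sample_loop: just resizes the prior.
        return F.interpolate(model_kwargs['x_prior'], size=shape[-1], mode='linear', align_corners=True)

    mel = torch.randn(1, 256, 3200)
    low = F.interpolate(mel, scale_factor=1/16, mode='linear', align_corners=True)
    out = chained_sr(fake_sampler, None, low, target_len=mel.shape[-1])
    print(out.shape)  # torch.Size([1, 256, 3200])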
@@ -328,17 +325,17 @@ class MusicDiffusionFid(evaluator.Evaluator):
 if __name__ == '__main__':
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr.yml', 'generator',
                                        also_load_savepoint=False, strict_load=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\12000_generator_fixed.pth'
+                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\22000_generator.pth'
                                        ).cuda()
     opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
                 #'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
-                'diffusion_steps': 64, # basis: 192
-                'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': False,
+                'diffusion_steps': 128, # basis: 192
+                'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'chained_sr',
                 #'causal': True, 'causal_slope': 4,
                 #'partial_low': 128, 'partial_high': 192
                 }
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 10, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 1, 'device': 'cuda', 'opt': {}}
     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     fds = []
     for i in range(2):
@@ -80,7 +80,7 @@ class GaussianDiffusionInjector(Injector):
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
         hq = state[self.input]
-        assert hq.max() < 1.000001 or hq.min() > -1.00001, f"Attempting to train gaussian diffusion on un-normalized inputs. This won't work, silly! {hq.min()} {hq.max()}"
+        assert hq.max() < 1.000001 or hq.min() > -1.00001, f"Attempting to train gaussian diffusion on un-normalized inputs. This won't work, silly! {hq.min()} {hq.max()}"

         with autocast(enabled=self.env['opt']['fp16']):
             if not gen.training or (self.deterministic_timesteps_every != 0 and self.env['step'] % self.deterministic_timesteps_every == 0):
@@ -90,7 +90,7 @@ class GaussianDiffusionInjector(Injector):
                 self.deterministic_sampler.reset() # Keep this reset whenever it is not being used, so it is ready to use automatically.
             model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
             if self.preprocess_fn is not None:
-                hq = getattr(gen.module, self.preprocess_fn)(hq, **model_inputs)
+                hq = getattr(gen.module, self.preprocess_fn)(hq)

             t, weights = sampler.sample(hq.shape[0], hq.device)
             if self.causal_mode:
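A hedged sketch of the preprocess_fn hook pattern in the hunk above: the injector looks up a method by name on the wrapped generator and applies it to the clean target before timesteps are sampled and noise is added. All names here (ToyGenerator, Wrapper) are illustrative stand-ins, not the repo's actual classes.

    import torch
    import torch.nn as nn

    class ToyGenerator(nn.Module):
        def input_to_random_resolution_and_window(self, x):
            # Stand-in for the model's preprocessing; the real method also caches a prior.
            return x[:, :, :x.shape[-1] // 2]

    class Wrapper:
        def __init__(self, module):
            self.module = module  # mimics gen.module on a DataParallel/DDP-wrapped model

    gen = Wrapper(ToyGenerator())
    preprocess_fn = 'input_to_random_resolution_and_window'

    hq = torch.randn(4, 256, 1000)
    hq = getattr(gen.module, preprocess_fn)(hq)
    print(hq.shape)  # torch.Size([4, 256, 500])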