James Betker 2022-07-20 10:19:02 -06:00
parent 2997a640b0
commit 15decfdb98
5 changed files with 39 additions and 43 deletions

View File

@ -4,6 +4,7 @@ from random import randrange
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
@ -38,7 +39,7 @@ class SubBlock(nn.Module):
self.mask_initialized = True
blk_enc = self.blk_emb_proj(blk_emb)
ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x.
ah = ah[:,:,blk_enc.shape[-1]:] # Strip off the blk_enc used for attention and re-align with x.
ah = F.gelu(self.attnorm(ah))
h = torch.cat([ah, x], dim=1)
hf = self.dropout(checkpoint(self.ff, h))
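A note on the fix above: the attention input is the projected block embedding prefixed onto x along the sequence axis, so the slice that strips the prefix afterwards must use blk_enc's length, not blk_emb's. A minimal sketch of that prepend-attend-strip pattern, assuming a [batch, channels, seq] layout and a generic attn callable (both are assumptions, not the repo's exact module):

import torch

def attend_with_prefix(attn, x, blk_enc, mask=None):
    joint = torch.cat([blk_enc, x], dim=-1)   # prefix the projected block embedding along the sequence axis
    ah = attn(joint, mask=mask)               # attend over [prefix | x]
    return ah[:, :, blk_enc.shape[-1]:]       # strip the prefix so the output re-aligns with x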
@ -168,25 +169,21 @@ class TransformerDiffusion(nn.Module):
}
return groups
def input_to_random_resolution_and_window(self, x, x_prior):
def input_to_random_resolution_and_window(self, x):
"""
This function MUST be applied to the target *before* noising. It returns the reduced, re-scoped target as well
as caches an internal prior for the rescoped target which will be useud in training.
as caches an internal prior for the rescoped target which will be used in training.
Args:
x: Diffusion target
x_prior: Prior input, which is generally just {x}
"""
assert x.shape == x_prior.shape, f'{x.shape} {x_prior.shape}'
resolution = randrange(1, self.resolution_steps)
resolution = randrange(0, self.resolution_steps)
resolution_scale = 2 ** resolution
s = F.interpolate(x, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
s_prior = F.interpolate(x_prior, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
s_diff = s.shape[-1] - self.max_window
if s_diff > 1:
start = randrange(0, s_diff)
s = s[:,:,start:start+self.max_window]
s_prior = x_prior[:,:,start:start+self.max_window]
s_prior = F.interpolate(s_prior, scale_factor=.25, mode='linear', align_corners=True)
s_prior = F.interpolate(s, scale_factor=.25, mode='linear', align_corners=True)
s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
self.preprocessed = (s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device))
return s
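For context, a hedged sketch of how a training step is expected to drive this hook now that it takes only the target: the call caches the blurred prior and resolution internally, and the subsequent forward() with resolution=None consumes them. The function below is illustrative; model, diffuser and the target shape are placeholders, not the repo's trainer code.

import torch

def training_step(model, diffuser, target):
    # target: un-noised diffusion target, e.g. [batch, channels, frames]
    scoped = model.input_to_random_resolution_and_window(target)   # re-scoped window; (prior, resolution) cached on the model
    t = torch.randint(0, diffuser.num_timesteps, (scoped.shape[0],), device=scoped.device)
    return diffuser.training_losses(model, scoped, t)              # forward(resolution=None) picks up the cached prior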
@ -196,16 +193,18 @@ class TransformerDiffusion(nn.Module):
h = x
if resolution is None:
# This is assumed to be training.
assert self.preprocessed is not None, 'Preprocessing function not called.'
h = x
assert x_prior is None, 'Provided prior will not be used, instead preprocessing output will be used.'
h_sub, resolution = self.preprocessed
self.preprocessed = None
else:
h_sub = F.interpolate(x_prior, scale_factor=4, mode='linear', align_corners=True)
assert h.shape == h_sub.shape, f'{h.shape} {h_sub.shape}'
assert h.shape[-1] > x_prior.shape[-1] * 3.9, f'{h.shape} {x_prior.shape}'
h_sub = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True)
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1)
h_sub = self.unconditioned_prior.repeat(x.shape[0], 1, x.shape[-1])
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
else:
MIN_COND_LEN = 200
MAX_COND_LEN = 1200
@ -227,8 +226,8 @@ class TransformerDiffusion(nn.Module):
time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
res_emb = self.resolution_embed(resolution)
blk_emb = torch.cat([time_emb.unsqueeze(-1), res_emb.unsqueeze(-1), code_emb], dim=-1)
h = torch.cat([h, h_sub], dim=1)
h = self.inp_block(h)
for layer in self.layers:
h = checkpoint(layer, h, blk_emb)
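The conditioning-free branch above is also fixed to tile its learned unconditioned tensors over the correct axes. A small self-contained check of the intended shapes, assuming the model's [batch, channels, seq] layout (the parameter shape here is illustrative):

import torch

batch, channels, seq = 2, 256, 1024
unconditioned_embedding = torch.nn.Parameter(torch.randn(1, channels, 1))  # illustrative shape
code_emb = unconditioned_embedding.repeat(batch, 1, seq)                   # repeat over batch and sequence, keep channels
assert code_emb.shape == (batch, channels, seq)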

View File

@ -24,7 +24,7 @@ class SubBlock(nn.Module):
self.ffnorm = nn.GroupNorm(8, contraction_dim)
if self.enable_attention_masking:
# All regions can attend to the first token, which will be the timestep embedding. Hence, fixed_region.
self.mask = build_local_attention_mask(n=2000, l=48, fixed_region=1)
self.mask = build_local_attention_mask(n=4000, l=48, fixed_region=1)
self.mask_initialized = False
else:
self.mask = None
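The mask length above is raised from 2000 to 4000 so it covers longer training windows. A hedged sketch of what a local attention mask with a fixed region can look like; the repo's build_local_attention_mask may differ in band shape and masking convention.

import torch

def local_attention_mask(n, l, fixed_region=0):
    i = torch.arange(n).unsqueeze(1)
    j = torch.arange(n).unsqueeze(0)
    mask = (i - j).abs() <= l            # each position attends within a +/- l neighborhood
    mask[:, :fixed_region] = True        # every position may also attend to the first fixed_region tokens
    return mask                          # True = allowed to attend, False = masked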

View File

@ -340,7 +340,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_instrument_quantizer.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_multilevel_sr.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)

View File

@ -1,28 +1,24 @@
import functools
import os
import os.path as osp
from glob import glob
from random import shuffle
from time import time
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import torchvision
from pytorch_fid.fid_score import calculate_frechet_distance
from torch import distributed
from tqdm import tqdm
import torch.nn.functional as F
import trainer.eval.evaluator as evaluator
from data.audio.unsupervised_audio_dataset import load_audio
from models.audio.mel2vec import ContrastiveTrainingWrapper
from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
from models.clip.contrastive_audio import ContrastiveAudio
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.respace import space_timesteps, SpacedDiffusion
from trainer.injectors.audio_injectors import denormalize_torch_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
normalize_mel, KmeansQuantizerInjector
KmeansQuantizerInjector, normalize_torch_mel
from utils.music_utils import get_music_codegen, get_mel2wav_model, get_cheater_decoder, get_cheater_encoder, \
get_mel2wav_v3_model, get_ar_prior
from utils.util import opt_get, load_model_from_config
@ -117,7 +113,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
model_kwargs={'codes': mel})
gen = pixel_shuffle_1d(gen, self.squeeze_ratio)
return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate
return gen, real_resampled, normalize_torch_mel(self.spec_fn({'in': gen})['out']), normalize_torch_mel(mel), sample_rate
def perform_diffusion_from_codes(self, audio, sample_rate=22050):
real_resampled = audio
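pixel_shuffle_1d, used above to fold the generator's channel groups back into the time axis, is the 1-D analog of sub-pixel upsampling. A hedged sketch of the operation (the repo's implementation may differ in ordering details):

import torch

def pixel_shuffle_1d_sketch(x, r):
    # [batch, channels * r, samples] -> [batch, channels, samples * r]
    b, cr, s = x.shape
    x = x.view(b, cr // r, r, s)
    return x.permute(0, 1, 3, 2).reshape(b, cr // r, s * r)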
@ -126,7 +122,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
mel = self.spec_fn({'in': audio})['out']
codegen = self.local_modules['codegen'].to(mel.device)
codes = codegen.get_codes(mel, project=True)
mel_norm = normalize_mel(mel)
mel_norm = normalize_torch_mel(mel)
gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
model_kwargs={'codes': codes, 'conditioning_input': torch.zeros_like(mel_norm[:,:,:390])})
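normalize_torch_mel / denormalize_torch_mel are assumed here to be an affine map between the log-mel dynamic range and the [-1, 1] range the diffusion model trains on; the bounds below are placeholders, not the repo's constants.

MEL_MIN, MEL_MAX = -12.0, 3.0            # placeholder bounds, not the repo's values

def normalize_torch_mel(mel):
    return 2 * (mel - MEL_MIN) / (MEL_MAX - MEL_MIN) - 1

def denormalize_torch_mel(norm):
    return (norm + 1) / 2 * (MEL_MAX - MEL_MIN) + MEL_MIN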
@ -140,27 +136,27 @@ class MusicDiffusionFid(evaluator.Evaluator):
return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate
def perform_diffusion_from_codes_quant(self, audio, sample_rate=22050):
real_resampled = audio
audio = audio.unsqueeze(0)
mel = self.spec_fn({'in': audio})['out']
mel_norm = normalize_mel(mel)
mel_norm = normalize_torch_mel(mel)
#def denoising_fn(x):
# q9 = torch.quantile(x, q=.95, dim=-1).unsqueeze(-1)
# s = q9.clamp(1, 9999999999)
# x = x.clamp(-s, s) / s
# return x
gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, #denoised_fn=denoising_fn, clip_denoised=False,
model_kwargs={'truth_mel': mel_norm})
sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
gen_mel = sampler(self.model, mel_norm.shape, model_kwargs={'truth_mel': mel_norm})
gen_mel_denorm = denormalize_torch_mel(gen_mel)
output_shape = (1,16,audio.shape[-1]//16)
self.spec_decoder = self.spec_decoder.to(audio.device)
gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape,
sampler = self.spectral_diffuser.ddim_sample_loop if self.ddim else self.spectral_diffuser.p_sample_loop
gen_wav = sampler(self.spec_decoder, output_shape,
model_kwargs={'aligned_conditioning': gen_mel_denorm})
gen_wav = pixel_shuffle_1d(gen_wav, 16)
real_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape,
real_wav = sampler(self.spec_decoder, output_shape,
model_kwargs={'aligned_conditioning': mel})
real_wav = pixel_shuffle_1d(real_wav, 16)
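The sampler toggle introduced above (exposed as use_ddim in the config at the bottom) trades ancestral p_sample_loop sampling for the faster, deterministic DDIM loop. A tiny wrapper mirroring that choice, assuming both loops share the guided-diffusion style call signature:

def sample(diffuser, model, shape, use_ddim=False, **model_kwargs):
    loop = diffuser.ddim_sample_loop if use_ddim else diffuser.p_sample_loop
    return loop(model, shape, model_kwargs=model_kwargs)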
@ -170,7 +166,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
audio = audio.unsqueeze(0)
mel = self.spec_fn({'in': audio})['out']
mel_norm = normalize_mel(mel)
mel_norm = normalize_torch_mel(mel)
cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm)
# 1. Generate the cheater latent using the input as a reference.
@ -203,7 +199,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
audio = audio.unsqueeze(0)
mel = self.spec_fn({'in': audio})['out']
mel_norm = normalize_mel(mel)
mel_norm = normalize_torch_mel(mel)
cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm)
cheater_codes = self.kmeans_inj({'in': cheater})['out']
ar_latent = self.local_modules['ar_prior'].to(audio.device)(cheater_codes, cheater, return_latent=True)
@ -233,16 +229,17 @@ class MusicDiffusionFid(evaluator.Evaluator):
def perform_chained_sr(self, audio, sample_rate=22050):
audio = audio.unsqueeze(0)
mel = self.spec_fn({'in': audio})['out']
mel_norm = normalize_mel(mel)
mel_norm = normalize_torch_mel(mel)
#mel_norm = mel_norm[:,:,:448*4] # restricts first stage to optimal training window.
conditioning = mel_norm[:,:,:1200]
downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='linear', align_corners=True)
sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
stage1_shape = (1, 256, downsampled.shape[-1]*4)
sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
# Chain super-sampling using 2 stages.
stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([2], device=audio.device),
stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device),
'x_prior': downsampled,
'conditioning_input': conditioning})
stage2 = sampler(self.model, audio.shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device),
stage2 = sampler(self.model, mel.shape, model_kwargs={'resolution': torch.tensor([0], device=audio.device),
'x_prior': stage1,
'conditioning_input': conditioning})
# Decode into waveform.
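The two stages above implement chained super-resolution: the 16x-downsampled mel is lifted 4x at resolution 1, and that output becomes the x_prior for the final lift to full length at resolution 0. A hedged sketch that generalizes the chain (names mirror the code above; the loop itself is illustrative):

import torch

def chained_sr(sampler, model, prior, conditioning, target_len, stages=(1, 0), up=4):
    x = prior                                                 # e.g. the 1/16-rate mel
    for res in stages:
        out_len = min(x.shape[-1] * up, target_len)
        shape = (x.shape[0], x.shape[1], out_len)
        x = sampler(model, shape, model_kwargs={'resolution': torch.tensor([res], device=x.device),
                                                'x_prior': x,
                                                'conditioning_input': conditioning})
    return x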
@ -328,17 +325,17 @@ class MusicDiffusionFid(evaluator.Evaluator):
if __name__ == '__main__':
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr.yml', 'generator',
also_load_savepoint=False, strict_load=False,
load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\12000_generator_fixed.pth'
load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\22000_generator.pth'
).cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
'diffusion_steps': 64, # basis: 192
'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': False,
'diffusion_steps': 128, # basis: 192
'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
'diffusion_schedule': 'linear', 'diffusion_type': 'chained_sr',
#'causal': True, 'causal_slope': 4,
#'partial_low': 128, 'partial_high': 192
}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 10, 'device': 'cuda', 'opt': {}}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 1, 'device': 'cuda', 'opt': {}}
eval = MusicDiffusionFid(diffusion, opt_eval, env)
fds = []
for i in range(2):

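For reference, the Fréchet distance the evaluator reports is computed from feature statistics of generated vs. real samples using the imported pytorch_fid helper; a hedged sketch (the audio feature extractor itself is not part of this diff):

import numpy as np
from pytorch_fid.fid_score import calculate_frechet_distance

def frechet(gen_feats, real_feats):
    # gen_feats, real_feats: [num_samples, feature_dim] numpy arrays from some audio feature projection
    mu_g, sigma_g = gen_feats.mean(axis=0), np.cov(gen_feats, rowvar=False)
    mu_r, sigma_r = real_feats.mean(axis=0), np.cov(real_feats, rowvar=False)
    return calculate_frechet_distance(mu_g, sigma_g, mu_r, sigma_r)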
View File

@ -80,7 +80,7 @@ class GaussianDiffusionInjector(Injector):
def forward(self, state):
gen = self.env['generators'][self.opt['generator']]
hq = state[self.input]
assert hq.max() < 1.000001 and hq.min() > -1.00001, f"Attempting to train gaussian diffusion on un-normalized inputs. This won't work, silly! {hq.min()} {hq.max()}"
with autocast(enabled=self.env['opt']['fp16']):
if not gen.training or (self.deterministic_timesteps_every != 0 and self.env['step'] % self.deterministic_timesteps_every == 0):
@ -90,7 +90,7 @@ class GaussianDiffusionInjector(Injector):
self.deterministic_sampler.reset() # Keep this reset whenever it is not being used, so it is ready to use automatically.
model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
if self.preprocess_fn is not None:
hq = getattr(gen.module, self.preprocess_fn)(hq, **model_inputs)
hq = getattr(gen.module, self.preprocess_fn)(hq)
t, weights = sampler.sample(hq.shape[0], hq.device)
if self.causal_mode: