James Betker 2022-07-20 10:19:02 -06:00
parent 2997a640b0
commit 15decfdb98
5 changed files with 39 additions and 43 deletions

View File

@@ -4,6 +4,7 @@ from random import randrange
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torchvision.utils
 from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
@@ -38,7 +39,7 @@ class SubBlock(nn.Module):
             self.mask_initialized = True
         blk_enc = self.blk_emb_proj(blk_emb)
         ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
-        ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x.
+        ah = ah[:,:,blk_enc.shape[-1]:] # Strip off the blk_enc used for attention and re-align with x.
         ah = F.gelu(self.attnorm(ah))
         h = torch.cat([ah, x], dim=1)
         hf = self.dropout(checkpoint(self.ff, h))
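For context on the slice fix in this hunk: the block embedding is prepended along the sequence axis for attention and then stripped back off, so the offset must be the length of the projected blk_enc rather than the raw blk_emb. A minimal sketch, with made-up shapes and an identity stand-in for self.attn (none of this is the repo's code):

import torch

B, C = 2, 8
blk_emb = torch.randn(B, C, 6)    # hypothetical raw block embedding, 6 positions
blk_enc = torch.randn(B, C, 3)    # hypothetical projected embedding, 3 positions
x = torch.randn(B, C, 100)        # the sequence being diffused

attn_in = torch.cat([blk_enc, x], dim=-1)   # 3 + 100 positions
ah = attn_in                                # stand-in for self.attn(attn_in, mask=...)
ah = ah[:, :, blk_enc.shape[-1]:]           # strip the prepended positions
assert ah.shape[-1] == x.shape[-1]          # realigned with x; slicing by blk_emb.shape[-1]
                                            # (6 here) would also drop real sequence positions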
@@ -168,25 +169,21 @@ class TransformerDiffusion(nn.Module):
         }
         return groups
 
-    def input_to_random_resolution_and_window(self, x, x_prior):
+    def input_to_random_resolution_and_window(self, x):
         """
         This function MUST be applied to the target *before* noising. It returns the reduced, re-scoped target as well
-        as caches an internal prior for the rescoped target which will be useud in training.
+        as caches an internal prior for the rescoped target which will be used in training.
         Args:
             x: Diffusion target
-            x_prior: Prior input, which is generally just {x}
         """
-        assert x.shape == x_prior.shape, f'{x.shape} {x_prior.shape}'
-        resolution = randrange(1, self.resolution_steps)
+        resolution = randrange(0, self.resolution_steps)
         resolution_scale = 2 ** resolution
         s = F.interpolate(x, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
-        s_prior = F.interpolate(x_prior, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
         s_diff = s.shape[-1] - self.max_window
         if s_diff > 1:
             start = randrange(0, s_diff)
             s = s[:,:,start:start+self.max_window]
-            s_prior = x_prior[:,:,start:start+self.max_window]
-        s_prior = F.interpolate(s_prior, scale_factor=.25, mode='linear', align_corners=True)
+        s_prior = F.interpolate(s, scale_factor=.25, mode='linear', align_corners=True)
         s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
         self.preprocessed = (s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device))
         return s
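A standalone sketch of the revised preprocessing, for reference. The resolution_steps and max_window values here are assumptions, and the real method caches the prior and resolution tensor in self.preprocessed instead of returning them. Note that randrange(0, ...) now includes resolution 0, i.e. the unscaled target:

from random import randrange

import torch
import torch.nn.functional as F

def random_resolution_and_window(x, resolution_steps=3, max_window=512):
    resolution = randrange(0, resolution_steps)           # 0 now means full resolution
    scale = 2 ** resolution
    s = F.interpolate(x, scale_factor=1/scale, mode='linear', align_corners=True)
    s_diff = s.shape[-1] - max_window
    if s_diff > 1:
        start = randrange(0, s_diff)
        s = s[:, :, start:start+max_window]
    # The prior is now derived from the windowed target itself: 4x downsample,
    # then upsample back to the window length.
    prior = F.interpolate(s, scale_factor=.25, mode='linear', align_corners=True)
    prior = F.interpolate(prior, size=(s.shape[-1],), mode='linear', align_corners=True)
    return s, prior, resolution

s, prior, resolution = random_resolution_and_window(torch.randn(1, 256, 4096))
assert s.shape[-1] == 512 and prior.shape[-1] == 512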
@@ -196,16 +193,18 @@ class TransformerDiffusion(nn.Module):
         h = x
         if resolution is None:
+            # This is assumed to be training.
             assert self.preprocessed is not None, 'Preprocessing function not called.'
-            h = x
+            assert x_prior is None, 'Provided prior will not be used, instead preprocessing output will be used.'
             h_sub, resolution = self.preprocessed
             self.preprocessed = None
         else:
-            h_sub = F.interpolate(x_prior, scale_factor=4, mode='linear', align_corners=True)
-            assert h.shape == h_sub.shape, f'{h.shape} {h_sub.shape}'
+            assert h.shape[-1] > x_prior.shape[-1] * 3.9, f'{h.shape} {x_prior.shape}'
+            h_sub = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True)
 
         if conditioning_free:
-            code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1)
+            h_sub = self.unconditioned_prior.repeat(x.shape[0], 1, x.shape[-1])
+            code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
         else:
             MIN_COND_LEN = 200
             MAX_COND_LEN = 1200
@@ -227,8 +226,8 @@ class TransformerDiffusion(nn.Module):
         time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
         res_emb = self.resolution_embed(resolution)
         blk_emb = torch.cat([time_emb.unsqueeze(-1), res_emb.unsqueeze(-1), code_emb], dim=-1)
         h = torch.cat([h, h_sub], dim=1)
 
         h = self.inp_block(h)
         for layer in self.layers:
             h = checkpoint(layer, h, blk_emb)
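A quick check of the repeat() ordering used in the conditioning-free branch above. It assumes the unconditioned embedding/prior parameters are stored as (1, channels, 1) tensors, which the new ordering implies but this diff does not show:

import torch

B, C, T = 2, 256, 400
x = torch.randn(B, C, T)
unconditioned_embedding = torch.zeros(1, C, 1)   # assumed parameter shape

code_emb = unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
assert code_emb.shape == (B, C, T)    # channels-first, lines up with h for concatenation

wrong = unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1)
assert wrong.shape == (B, C * T, 1)   # the pre-fix ordering inflates the channel dim instead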

View File

@@ -24,7 +24,7 @@ class SubBlock(nn.Module):
         self.ffnorm = nn.GroupNorm(8, contraction_dim)
         if self.enable_attention_masking:
             # All regions can attend to the first token, which will be the timestep embedding. Hence, fixed_region.
-            self.mask = build_local_attention_mask(n=2000, l=48, fixed_region=1)
+            self.mask = build_local_attention_mask(n=4000, l=48, fixed_region=1)
             self.mask_initialized = False
         else:
             self.mask = None
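An illustrative stand-in for what a mask like build_local_attention_mask(n=4000, l=48, fixed_region=1) plausibly encodes: each position may attend within a 48-wide local window plus the first fixed_region tokens (the timestep embedding). This is a sketch of the idea, not the repo's implementation; the bump from n=2000 to n=4000 matters because the precomputed mask has to cover the longest sequence the block will see:

import torch

def local_attention_mask_sketch(n, l, fixed_region=1):
    idx = torch.arange(n)
    dist = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs()  # |query - key| distance matrix
    mask = (dist <= l).float()                          # allow keys inside the local window
    mask[:, :fixed_region] = 1.0                        # every query can see the fixed region
    return mask

m = local_attention_mask_sketch(n=512, l=48, fixed_region=1)
assert m.shape == (512, 512)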

View File

@@ -340,7 +340,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_instrument_quantizer.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_multilevel_sr.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)

View File

@@ -1,28 +1,24 @@
-import functools
 import os
 import os.path as osp
 from glob import glob
-from random import shuffle
-from time import time
 
 import numpy as np
 import torch
-import torch.nn.functional as F
 import torchaudio
 import torchvision
 from pytorch_fid.fid_score import calculate_frechet_distance
 from torch import distributed
 from tqdm import tqdm
+import torch.nn.functional as F
 
 import trainer.eval.evaluator as evaluator
 from data.audio.unsupervised_audio_dataset import load_audio
-from models.audio.mel2vec import ContrastiveTrainingWrapper
-from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
 from models.clip.contrastive_audio import ContrastiveAudio
 from models.diffusion.gaussian_diffusion import get_named_beta_schedule
 from models.diffusion.respace import space_timesteps, SpacedDiffusion
 from trainer.injectors.audio_injectors import denormalize_torch_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
-    normalize_mel, KmeansQuantizerInjector
+    KmeansQuantizerInjector, normalize_torch_mel
 from utils.music_utils import get_music_codegen, get_mel2wav_model, get_cheater_decoder, get_cheater_encoder, \
     get_mel2wav_v3_model, get_ar_prior
 from utils.util import opt_get, load_model_from_config
@@ -117,7 +113,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
                                            model_kwargs={'codes': mel})
         gen = pixel_shuffle_1d(gen, self.squeeze_ratio)
-        return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate
+        return gen, real_resampled, normalize_torch_mel(self.spec_fn({'in': gen})['out']), normalize_torch_mel(mel), sample_rate
 
     def perform_diffusion_from_codes(self, audio, sample_rate=22050):
         real_resampled = audio
@@ -126,7 +122,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         mel = self.spec_fn({'in': audio})['out']
         codegen = self.local_modules['codegen'].to(mel.device)
         codes = codegen.get_codes(mel, project=True)
-        mel_norm = normalize_mel(mel)
+        mel_norm = normalize_torch_mel(mel)
         gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
                                               model_kwargs={'codes': codes, 'conditioning_input': torch.zeros_like(mel_norm[:,:,:390])})
@@ -140,27 +136,27 @@ class MusicDiffusionFid(evaluator.Evaluator):
         return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate
 
     def perform_diffusion_from_codes_quant(self, audio, sample_rate=22050):
-        real_resampled = audio
         audio = audio.unsqueeze(0)
         mel = self.spec_fn({'in': audio})['out']
-        mel_norm = normalize_mel(mel)
+        mel_norm = normalize_torch_mel(mel)
         #def denoising_fn(x):
         # q9 = torch.quantile(x, q=.95, dim=-1).unsqueeze(-1)
         # s = q9.clamp(1, 9999999999)
         # x = x.clamp(-s, s) / s
         # return x
-        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, #denoised_fn=denoising_fn, clip_denoised=False,
-                                              model_kwargs={'truth_mel': mel_norm})
+        sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
+        gen_mel = sampler(self.model, mel_norm.shape, model_kwargs={'truth_mel': mel_norm})
         gen_mel_denorm = denormalize_torch_mel(gen_mel)
         output_shape = (1,16,audio.shape[-1]//16)
         self.spec_decoder = self.spec_decoder.to(audio.device)
-        gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape,
+        sampler = self.spectral_diffuser.ddim_sample_loop if self.ddim else self.spectral_diffuser.p_sample_loop
+        gen_wav = sampler(self.spec_decoder, output_shape,
                           model_kwargs={'aligned_conditioning': gen_mel_denorm})
         gen_wav = pixel_shuffle_1d(gen_wav, 16)
-        real_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape,
+        real_wav = sampler(self.spec_decoder, output_shape,
                           model_kwargs={'aligned_conditioning': mel})
         real_wav = pixel_shuffle_1d(real_wav, 16)
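The edits above replace hard-coded p_sample_loop calls with a toggle. The pattern relies on p_sample_loop and ddim_sample_loop sharing a call signature (model, shape, model_kwargs=...), as they do in guided-diffusion-style codebases such as the SpacedDiffusion imported here; a self-contained sketch, with the diffuser object assumed:

def run_sampler(diffuser, model, shape, model_kwargs, use_ddim=False):
    # Pick the loop once, then call it exactly as p_sample_loop was called before.
    sampler = diffuser.ddim_sample_loop if use_ddim else diffuser.p_sample_loop
    return sampler(model, shape, model_kwargs=model_kwargs)

When DDIM is used, the timestep schedule is typically respaced as well (e.g. space_timesteps(4000, 'ddim64')) so the shorter loop stays consistent with the trained betas.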
@@ -170,7 +166,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         audio = audio.unsqueeze(0)
         mel = self.spec_fn({'in': audio})['out']
-        mel_norm = normalize_mel(mel)
+        mel_norm = normalize_torch_mel(mel)
         cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm)
 
         # 1. Generate the cheater latent using the input as a reference.
@@ -203,7 +199,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         audio = audio.unsqueeze(0)
         mel = self.spec_fn({'in': audio})['out']
-        mel_norm = normalize_mel(mel)
+        mel_norm = normalize_torch_mel(mel)
         cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm)
         cheater_codes = self.kmeans_inj({'in': cheater})['out']
         ar_latent = self.local_modules['ar_prior'].to(audio.device)(cheater_codes, cheater, return_latent=True)
@@ -233,16 +229,17 @@ class MusicDiffusionFid(evaluator.Evaluator):
     def perform_chained_sr(self, audio, sample_rate=22050):
         audio = audio.unsqueeze(0)
         mel = self.spec_fn({'in': audio})['out']
-        mel_norm = normalize_mel(mel)
+        mel_norm = normalize_torch_mel(mel)
+        #mel_norm = mel_norm[:,:,:448*4] # restricts first stage to optimal training window.
         conditioning = mel_norm[:,:,:1200]
         downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='linear', align_corners=True)
-        sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
         stage1_shape = (1, 256, downsampled.shape[-1]*4)
+        sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
         # Chain super-sampling using 2 stages.
-        stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([2], device=audio.device),
+        stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device),
                                                                  'x_prior': downsampled,
                                                                  'conditioning_input': conditioning})
-        stage2 = sampler(self.model, audio.shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device),
+        stage2 = sampler(self.model, mel.shape, model_kwargs={'resolution': torch.tensor([0], device=audio.device),
                                                                 'x_prior': stage1,
                                                                 'conditioning_input': conditioning})
         # Decode into waveform.
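A worked shape check for the two-stage chain above, using an assumed 256 x 1600 mel. The resolution indices shift down by one (2 -> 1 for stage 1, 1 -> 0 for stage 2), consistent with the preprocessing change that now treats resolution 0 as the native scale:

import torch
import torch.nn.functional as F

mel_norm = torch.randn(1, 256, 1600)                 # assumed mel: 256 bins, 1600 frames
downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='linear', align_corners=True)
assert downsampled.shape[-1] == 100                  # 1/16th-rate prior fed to stage 1

stage1_shape = (1, 256, downsampled.shape[-1] * 4)   # stage 1 recovers 4x -> 400 frames
assert stage1_shape[-1] * 4 == mel_norm.shape[-1]    # stage 2 recovers the remaining 4x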
@@ -328,17 +325,17 @@ class MusicDiffusionFid(evaluator.Evaluator):
 if __name__ == '__main__':
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr.yml', 'generator',
                                        also_load_savepoint=False, strict_load=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\12000_generator_fixed.pth'
+                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\22000_generator.pth'
                                        ).cuda()
     opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
                 #'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
-                'diffusion_steps': 64, # basis: 192
-                'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': False,
+                'diffusion_steps': 128, # basis: 192
+                'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'chained_sr',
                 #'causal': True, 'causal_slope': 4,
                 #'partial_low': 128, 'partial_high': 192
                 }
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 10, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 1, 'device': 'cuda', 'opt': {}}
     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     fds = []
     for i in range(2):

View File

@@ -80,7 +80,7 @@ class GaussianDiffusionInjector(Injector):
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
         hq = state[self.input]
         assert hq.max() < 1.000001 or hq.min() > -1.00001, f"Attempting to train gaussian diffusion on un-normalized inputs. This won't work, silly! {hq.min()} {hq.max()}"
         with autocast(enabled=self.env['opt']['fp16']):
             if not gen.training or (self.deterministic_timesteps_every != 0 and self.env['step'] % self.deterministic_timesteps_every == 0):
@@ -90,7 +90,7 @@ class GaussianDiffusionInjector(Injector):
                 self.deterministic_sampler.reset() # Keep this reset whenever it is not being used, so it is ready to use automatically.
             model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
             if self.preprocess_fn is not None:
-                hq = getattr(gen.module, self.preprocess_fn)(hq, **model_inputs)
+                hq = getattr(gen.module, self.preprocess_fn)(hq)
             t, weights = sampler.sample(hq.shape[0], hq.device)
             if self.causal_mode:
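The injector change above matches the new single-argument preprocessing signature. A self-contained sketch of the hook pattern with a stand-in generator (the real Injector class, its options, and the .module wrapper are not reproduced here):

import torch

class FakeGenerator:
    # Stand-in mirroring the new input_to_random_resolution_and_window(self, x) signature.
    def input_to_random_resolution_and_window(self, x):
        self.preprocessed = ('prior', 'resolution')   # placeholder for the cached tuple
        return x[..., :16]                            # placeholder for the rescoped target

gen = FakeGenerator()
hq = torch.randn(1, 256, 64)
preprocess_fn = 'input_to_random_resolution_and_window'   # as configured on the injector
hq = getattr(gen, preprocess_fn)(hq)                      # no **model_inputs forwarded anymore
assert hq.shape[-1] == 16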