some fixes to mdf (MusicDiffusionFid) to support new archs

James Betker 2022-07-21 10:55:50 -06:00
parent 13c263e9fb
commit 76464ca063
2 changed files with 18 additions and 22 deletions

@@ -232,7 +232,7 @@ class TransformerDiffusion(nn.Module):
assert torch.all(timesteps - prior_timesteps >= 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}'
if conditioning_free:
- code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
+ code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1)
else:
MIN_COND_LEN = 200
MAX_COND_LEN = 1200
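For the new architectures, the unconditioned embedding is evidently a flat 2D parameter tiled across the batch, where the older models kept a sequence axis. A minimal sketch of the assumed shapes (sizes are illustrative, not taken from any config):

import torch
import torch.nn as nn

dim, batch = 512, 4                                  # illustrative sizes
x = torch.randn(batch, 100, dim)                     # stand-in for the diffusion input; only x.shape[0] matters here

# Older arch: (1, 1, dim) parameter, tiled along batch with repeat(B, 1, 1) -> (B, 1, dim)
old_uncond = nn.Parameter(torch.zeros(1, 1, dim))
old_code_emb = old_uncond.repeat(x.shape[0], 1, 1)

# New archs: flat (1, dim) parameter, tiled with repeat(B, 1) -> (B, dim)
new_uncond = nn.Parameter(torch.zeros(1, dim))
new_code_emb = new_uncond.repeat(x.shape[0], 1)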

@@ -19,7 +19,7 @@ from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.respace import space_timesteps, SpacedDiffusion
from trainer.injectors.audio_injectors import denormalize_torch_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
KmeansQuantizerInjector, normalize_torch_mel
- from utils.music_utils import get_music_codegen, get_mel2wav_model, get_cheater_decoder, get_cheater_encoder, \
+ from utils.music_utils import get_music_codegen, get_cheater_decoder, get_cheater_encoder, \
get_mel2wav_v3_model, get_ar_prior
from utils.util import opt_get, load_model_from_config
@@ -58,8 +58,9 @@ class MusicDiffusionFid(evaluator.Evaluator):
self.projector = ContrastiveAudio(model_dim=512, transformer_heads=8, dropout=0, encoder_depth=8, mel_channels=256)
self.projector.load_state_dict(torch.load('../experiments/music_eval_projector.pth', map_location=torch.device('cpu')))
- self.local_modules = {'projector': self.projector}
+ self.spec_decoder = get_mel2wav_v3_model()
+ self.local_modules = {'projector': self.projector, 'spec_decoder': self.spec_decoder}
if mode == 'spec_decode':
self.diffusion_fn = self.perform_diffusion_spec_decode
self.squeeze_ratio = opt_eval['squeeze_ratio']
@@ -78,7 +79,6 @@ class MusicDiffusionFid(evaluator.Evaluator):
self.spectral_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [16]), model_mean_type='epsilon',
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
conditioning_free=False, conditioning_free_k=1)
- self.local_modules['spec_decoder'] = self.spec_decoder
elif 'from_ar_prior' == mode:
self.diffusion_fn = self.perform_diffusion_from_codes_ar_prior
self.local_modules['cheater_encoder'] = get_cheater_encoder()
@@ -88,12 +88,8 @@ class MusicDiffusionFid(evaluator.Evaluator):
conditioning_free=True, conditioning_free_k=1)
self.kmeans_inj = KmeansQuantizerInjector({'centroids': '../experiments/music_k_means_centroids.pth', 'in': 'in', 'out': 'out'}, {})
self.local_modules['ar_prior'] = get_ar_prior()
- self.local_modules['spec_decoder'] = self.spec_decoder
elif 'chained_sr' == mode:
self.diffusion_fn = self.perform_chained_sr
- self.local_modules['spec_decoder'] = self.spec_decoder
- self.spec_decoder = get_mel2wav_v3_model()
- self.local_modules['spec_decoder'] = self.spec_decoder
self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
'normalize': True, 'in': 'in', 'out': 'out'}, {})
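Taken together, the deletions in these constructor hunks move the spec decoder to a single up-front setup shared by every eval mode, so each branch only picks its diffusion_fn plus any mode-specific helpers. A rough sketch of the resulting structure, reassembled from the hunks above (not a verbatim copy of the file):

self.spec_decoder = get_mel2wav_v3_model()
self.local_modules = {'projector': self.projector, 'spec_decoder': self.spec_decoder}

if mode == 'spec_decode':
    self.diffusion_fn = self.perform_diffusion_spec_decode
elif mode == 'from_ar_prior':
    self.diffusion_fn = self.perform_diffusion_from_codes_ar_prior
    self.local_modules['cheater_encoder'] = get_cheater_encoder()
    self.local_modules['ar_prior'] = get_ar_prior()
elif mode == 'chained_sr':
    self.diffusion_fn = self.perform_chained_sr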
@@ -149,11 +145,11 @@ class MusicDiffusionFid(evaluator.Evaluator):
self.spec_decoder = self.spec_decoder.to(audio.device)
sampler = self.spectral_diffuser.ddim_sample_loop if self.ddim else self.spectral_diffuser.p_sample_loop
gen_wav = sampler(self.spec_decoder, output_shape,
- model_kwargs={'aligned_conditioning': gen_mel_denorm})
+ model_kwargs={'codes': gen_mel_denorm})
gen_wav = pixel_shuffle_1d(gen_wav, 16)
real_wav = sampler(self.spec_decoder, output_shape,
- model_kwargs={'aligned_conditioning': mel})
+ model_kwargs={'codes': mel})
real_wav = pixel_shuffle_1d(real_wav, 16)
return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate
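The substantive change in this hunk is the conditioning keyword: the v3 mel-to-wav decoder is evidently conditioned via 'codes', where the older decoder expected 'aligned_conditioning'. For orientation on the surrounding shapes, here is a minimal sketch of what pixel_shuffle_1d(gen_wav, 16) is assumed to do: a 1D sub-pixel rearrangement that unfolds the sampler's (1, 16, N/16) output back into a (1, 1, N) waveform (the exact semantics of the helper in trainer.injectors.audio_injectors are an assumption):

import torch

def pixel_shuffle_1d_sketch(x: torch.Tensor, r: int) -> torch.Tensor:
    # Assumed semantics, mirroring the 2D pixel shuffle: (B, C*r, T) -> (B, C, T*r)
    b, c, t = x.shape
    return x.reshape(b, c // r, r, t).permute(0, 1, 3, 2).reshape(b, c // r, t * r)

folded = torch.randn(1, 16, 1000)           # what the sampler produces: (1, 16, N/16)
wav = pixel_shuffle_1d_sketch(folded, 16)   # -> (1, 1, 16000): the flat waveform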
@@ -227,15 +223,15 @@ class MusicDiffusionFid(evaluator.Evaluator):
mel = self.spec_fn({'in': audio})['out']
mel_norm = normalize_torch_mel(mel)
conditioning = mel_norm[:,:,:1200]
- downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='nearest')
+ downsampled = F.interpolate(mel_norm, scale_factor=1/4, mode='nearest')
stage1_shape = (1, 256, downsampled.shape[-1]*4)
sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
- # Chain super-sampling using 2 stages.
- stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device),
- 'x_prior': downsampled,
- 'conditioning_input': conditioning})
+ # (Eventually) Chain super-sampling using 2 stages.
+ #stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device),
+ # 'x_prior': downsampled,
+ # 'conditioning_input': conditioning})
stage2 = sampler(self.model, mel.shape, model_kwargs={'resolution': torch.tensor([0], device=audio.device),
- 'x_prior': stage1,
+ 'x_prior': downsampled,
'conditioning_input': conditioning})
# Decode into waveform.
output_shape = (1,16,audio.shape[-1]//16)
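With the 4x first stage commented out, the chained-SR path now interpolates the mel down by only 4x (instead of 16x) and passes that downsample directly to the stage-2 model as x_prior, which restores full resolution on its own. A worked shape example under illustrative sizes:

import torch
import torch.nn.functional as F

T = 1600                                             # illustrative number of mel frames
mel_norm = torch.randn(1, 256, T)                    # normalized 256-bin mel

# New path: one 4x hop. (Old path: 1/16 downsample, stage 1 did 4x, stage 2 the rest.)
downsampled = F.interpolate(mel_norm, scale_factor=1/4, mode='nearest')  # (1, 256, T/4) = (1, 256, 400)
stage1_shape = (1, 256, downsampled.shape[-1] * 4)                       # (1, 256, 1600), only needed if stage 1 is re-enabled
# stage 2 now samples directly at mel.shape == (1, 256, T), with x_prior=downsampled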
@@ -318,24 +314,24 @@ class MusicDiffusionFid(evaluator.Evaluator):
if __name__ == '__main__':
"""
# For multilevel SR:
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr.yml', 'generator',
also_load_savepoint=False, strict_load=False,
- load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\4000_generator.pth'
+ load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\6000_generator.pth'
).cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
'diffusion_steps': 128, # basis: 192
- 'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
+ 'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': True,
'diffusion_schedule': 'cosine', 'diffusion_type': 'chained_sr',
}
"""
# For TFD+cheater trainer
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater.yml', 'generator',
also_load_savepoint=False, strict_load=False,
- load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd14_and_cheater_g2\\models\\20000_generator.pth'
+ load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd14_and_cheater_g2\\models\\1000_generator.pth'
).cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
@@ -343,9 +339,9 @@ if __name__ == '__main__':
'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': True,
'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant',
}
"""
- env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 6, 'device': 'cuda', 'opt': {}}
+ env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 7, 'device': 'cuda', 'opt': {}}
eval = MusicDiffusionFid(diffusion, opt_eval, env)
fds = []
for i in range(2):