forked from mrq/DL-Art-School
some fixes to mdf to support new archs
This commit is contained in:
parent
13c263e9fb
commit
76464ca063
|
@ -232,7 +232,7 @@ class TransformerDiffusion(nn.Module):
|
|||
assert torch.all(timesteps - prior_timesteps >= 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}'
|
||||
|
||||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1)
|
||||
else:
|
||||
MIN_COND_LEN = 200
|
||||
MAX_COND_LEN = 1200
|
||||
|
|
|
@ -19,7 +19,7 @@ from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
|||
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
||||
from trainer.injectors.audio_injectors import denormalize_torch_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
|
||||
KmeansQuantizerInjector, normalize_torch_mel
|
||||
from utils.music_utils import get_music_codegen, get_mel2wav_model, get_cheater_decoder, get_cheater_encoder, \
|
||||
from utils.music_utils import get_music_codegen, get_cheater_decoder, get_cheater_encoder, \
|
||||
get_mel2wav_v3_model, get_ar_prior
|
||||
from utils.util import opt_get, load_model_from_config
|
||||
|
||||
|
@ -58,8 +58,9 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
|
||||
self.projector = ContrastiveAudio(model_dim=512, transformer_heads=8, dropout=0, encoder_depth=8, mel_channels=256)
|
||||
self.projector.load_state_dict(torch.load('../experiments/music_eval_projector.pth', map_location=torch.device('cpu')))
|
||||
self.spec_decoder = get_mel2wav_v3_model()
|
||||
|
||||
self.local_modules = {'projector': self.projector}
|
||||
self.local_modules = {'projector': self.projector, 'spec_decoder': self.spec_decoder}
|
||||
if mode == 'spec_decode':
|
||||
self.diffusion_fn = self.perform_diffusion_spec_decode
|
||||
self.squeeze_ratio = opt_eval['squeeze_ratio']
|
||||
|
@ -78,7 +79,6 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
self.spectral_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [16]), model_mean_type='epsilon',
|
||||
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
|
||||
conditioning_free=False, conditioning_free_k=1)
|
||||
self.local_modules['spec_decoder'] = self.spec_decoder
|
||||
elif 'from_ar_prior' == mode:
|
||||
self.diffusion_fn = self.perform_diffusion_from_codes_ar_prior
|
||||
self.local_modules['cheater_encoder'] = get_cheater_encoder()
|
||||
|
@ -88,12 +88,8 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
conditioning_free=True, conditioning_free_k=1)
|
||||
self.kmeans_inj = KmeansQuantizerInjector({'centroids': '../experiments/music_k_means_centroids.pth', 'in': 'in', 'out': 'out'}, {})
|
||||
self.local_modules['ar_prior'] = get_ar_prior()
|
||||
self.local_modules['spec_decoder'] = self.spec_decoder
|
||||
elif 'chained_sr' == mode:
|
||||
self.diffusion_fn = self.perform_chained_sr
|
||||
self.local_modules['spec_decoder'] = self.spec_decoder
|
||||
self.spec_decoder = get_mel2wav_v3_model()
|
||||
self.local_modules['spec_decoder'] = self.spec_decoder
|
||||
self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
|
||||
'normalize': True, 'in': 'in', 'out': 'out'}, {})
|
||||
|
||||
|
@ -149,11 +145,11 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
self.spec_decoder = self.spec_decoder.to(audio.device)
|
||||
sampler = self.spectral_diffuser.ddim_sample_loop if self.ddim else self.spectral_diffuser.p_sample_loop
|
||||
gen_wav = sampler(self.spec_decoder, output_shape,
|
||||
model_kwargs={'aligned_conditioning': gen_mel_denorm})
|
||||
model_kwargs={'codes': gen_mel_denorm})
|
||||
gen_wav = pixel_shuffle_1d(gen_wav, 16)
|
||||
|
||||
real_wav = sampler(self.spec_decoder, output_shape,
|
||||
model_kwargs={'aligned_conditioning': mel})
|
||||
model_kwargs={'codes': mel})
|
||||
real_wav = pixel_shuffle_1d(real_wav, 16)
|
||||
|
||||
return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate
|
||||
|
@ -227,15 +223,15 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
mel = self.spec_fn({'in': audio})['out']
|
||||
mel_norm = normalize_torch_mel(mel)
|
||||
conditioning = mel_norm[:,:,:1200]
|
||||
downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='nearest')
|
||||
downsampled = F.interpolate(mel_norm, scale_factor=1/4, mode='nearest')
|
||||
stage1_shape = (1, 256, downsampled.shape[-1]*4)
|
||||
sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
|
||||
# Chain super-sampling using 2 stages.
|
||||
stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device),
|
||||
'x_prior': downsampled,
|
||||
'conditioning_input': conditioning})
|
||||
# (Eventually) Chain super-sampling using 2 stages.
|
||||
#stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device),
|
||||
# 'x_prior': downsampled,
|
||||
# 'conditioning_input': conditioning})
|
||||
stage2 = sampler(self.model, mel.shape, model_kwargs={'resolution': torch.tensor([0], device=audio.device),
|
||||
'x_prior': stage1,
|
||||
'x_prior': downsampled,
|
||||
'conditioning_input': conditioning})
|
||||
# Decode into waveform.
|
||||
output_shape = (1,16,audio.shape[-1]//16)
|
||||
|
@ -318,24 +314,24 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""
|
||||
# For multilevel SR:
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr.yml', 'generator',
|
||||
also_load_savepoint=False, strict_load=False,
|
||||
load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\4000_generator.pth'
|
||||
load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\6000_generator.pth'
|
||||
).cuda()
|
||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
|
||||
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
|
||||
'diffusion_steps': 128, # basis: 192
|
||||
'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
|
||||
'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': True,
|
||||
'diffusion_schedule': 'cosine', 'diffusion_type': 'chained_sr',
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
# For TFD+cheater trainer
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater.yml', 'generator',
|
||||
also_load_savepoint=False, strict_load=False,
|
||||
load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd14_and_cheater_g2\\models\\20000_generator.pth'
|
||||
load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd14_and_cheater_g2\\models\\1000_generator.pth'
|
||||
).cuda()
|
||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
|
||||
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
|
||||
|
@ -343,9 +339,9 @@ if __name__ == '__main__':
|
|||
'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': True,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant',
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 6, 'device': 'cuda', 'opt': {}}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 7, 'device': 'cuda', 'opt': {}}
|
||||
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
||||
fds = []
|
||||
for i in range(2):
|
||||
|
|
Loading…
Reference in New Issue
Block a user