diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index beb79cdb..ed2bf56b 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -57,14 +57,6 @@ class MusicDiffusionFid(evaluator.Evaluator):
 
         if mode == 'spec_decode':
            self.diffusion_fn = self.perform_diffusion_spec_decode
-        elif 'gap_fill_' in mode:
-            self.diffusion_fn = self.perform_diffusion_gap_fill
-            if '_freq' in mode:
-                self.gap_gen_fn = self.gen_freq_gap
-            else:
-                self.gap_gen_fn = self.gen_time_gap
-        elif 'rerender' in mode:
-            self.diffusion_fn = self.perform_rerender
         elif 'from_codes' == mode:
             self.diffusion_fn = self.perform_diffusion_from_codes
             self.local_modules['codegen'] = get_music_codegen()
@@ -88,75 +80,6 @@ class MusicDiffusionFid(evaluator.Evaluator):
 
         return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate
 
-    def gen_freq_gap(self, mel, band_range=(60,100)):
-        gap_start, gap_end = band_range
-        mask = torch.ones_like(mel)
-        mask[:, gap_start:gap_end] = 0
-        return mel * mask, mask
-
-    def gen_time_gap(self, mel):
-        mask = torch.ones_like(mel)
-        mask[:, :, 86*4:86*6] = 0
-        return mel * mask, mask
-
-    def perform_diffusion_gap_fill(self, audio, sample_rate=22050):
-        if sample_rate != sample_rate:
-            real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
-        else:
-            real_resampled = audio
-        audio = audio.unsqueeze(0)
-
-        # Fetch the MEL and mask out the requested bands.
-        mel = self.spec_fn({'in': audio})['out']
-        mel = normalize_mel(mel)
-        mel, mask = self.gap_gen_fn(mel)
-
-        # Repair the MEL with the given model.
-        spec = self.diffuser.p_sample_loop_with_guidance(self.model, mel, mask, model_kwargs={'truth': mel})
-        spec = denormalize_mel(spec)
-
-        # Re-convert the resulting MEL back into audio using the spectrogram decoder.
-        output_shape = (1, 16, audio.shape[-1] // 16)
-        self.spec_decoder = self.spec_decoder.to(audio.device)
-        # Cool fact: we can re-use the diffuser for the spectrogram diffuser since it has the same parametrization.
-        gen = self.diffuser.p_sample_loop(self.spec_decoder, output_shape,
-                                          model_kwargs={'aligned_conditioning': spec})
-        gen = pixel_shuffle_1d(gen, 16)
-
-        return gen, real_resampled, normalize_mel(spec), mel, sample_rate
-
-    def perform_rerender(self, audio, sample_rate=22050):
-        if sample_rate != sample_rate:
-            real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
-        else:
-            real_resampled = audio
-        audio = audio.unsqueeze(0)
-
-        # Fetch the MEL and mask out the requested bands.
-        mel = self.spec_fn({'in': audio})['out']
-        mel = normalize_mel(mel)
-
-        segments = [(0,10),(10,25),(25,45),(45,60),(60,80),(80,100),(100,130),(130,170),(170,210),(210,256)]
-        shuffle(segments)
-        spec = mel
-        for i, segment in enumerate(segments):
-            mel, mask = self.gen_freq_gap(mel, band_range=segment)
-            # Repair the MEL with the given model.
-            spec = self.diffuser.p_sample_loop_with_guidance(self.model, spec, mask, model_kwargs={'truth': spec})
-            torchvision.utils.save_image((spec.unsqueeze(1) + 1) / 2, f"{i}_rerender.png")
-
-        spec = denormalize_mel(spec)
-
-        # Re-convert the resulting MEL back into audio using the spectrogram decoder.
-        output_shape = (1, 16, audio.shape[-1] // 16)
-        self.spec_decoder = self.spec_decoder.to(audio.device)
-        # Cool fact: we can re-use the diffuser for the spectrogram diffuser since it has the same parametrization.
-        gen = self.diffuser.p_sample_loop(self.spec_decoder, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
-                                          model_kwargs={'aligned_conditioning': spec})
-        gen = pixel_shuffle_1d(gen, 16)
-
-        return gen, real_resampled, normalize_mel(spec), mel, sample_rate
-
 
     def perform_diffusion_from_codes(self, audio, sample_rate=22050):
         if sample_rate != sample_rate:
@@ -164,15 +87,12 @@ class MusicDiffusionFid(evaluator.Evaluator):
             real_resampled = audio
         audio = audio.unsqueeze(0)
 
-        # Fetch the MEL and mask out the requested bands.
         mel = self.spec_fn({'in': audio})['out']
         codegen = self.local_modules['codegen'].to(mel.device)
         codes = codegen.get_codes(mel)
         mel_norm = normalize_mel(mel)
-        precomputed = self.model.timestep_independent(aligned_conditioning=codes, conditioning_input=mel[:,:,:112],
-                                                      expected_seq_len=mel_norm.shape[-1], return_code_pred=False)
-        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, noise=torch.zeros_like(mel_norm),
-                                              model_kwargs={'precomputed_aligned_embeddings': precomputed})
+        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
+                                              model_kwargs={'codes': codes, 'conditioning_input': mel_norm[:,:,:140]})
         gen_mel_denorm = denormalize_mel(gen_mel)
 
         output_shape = (1,16,audio.shape[-1]//16)
diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py
index a2232c99..578c37e3 100644
--- a/codes/trainer/injectors/audio_injectors.py
+++ b/codes/trainer/injectors/audio_injectors.py
@@ -6,6 +6,7 @@ import torchaudio
 
 from models.audio.tts.unet_diffusion_tts_flat import DiffusionTtsFlat
 from trainer.inject import Injector
+from utils.music_utils import get_music_codegen
 from utils.util import opt_get, load_model_from_config, pad_or_truncate
 
 TACOTRON_MEL_MAX = 2.3143386840820312
@@ -326,15 +327,7 @@ class AudioUnshuffleInjector(Injector):
 class Mel2vecCodesInjector(Injector):
     def __init__(self, opt, env):
         super().__init__(opt, env)
-        for_what = opt_get(opt, ['for'], 'music')
-
-        from models.audio.mel2vec import ContrastiveTrainingWrapper
-        self.m2v = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0,
-                                              mask_time_prob=0,
-                                              mask_time_length=6, num_negatives=100, codebook_size=16, codebook_groups=4,
-                                              disable_custom_linear_init=True, do_reconstruction_loss=True)
-        self.m2v.load_state_dict(torch.load(f"../experiments/m2v_{for_what}.pth", map_location=torch.device('cpu')))
-        self.m2v = self.m2v.eval()
+        self.m2v = get_music_codegen()
         del self.m2v.m2v.encoder  # This is a big memory sink which will not get used.
         self.needs_move = True
 
@@ -366,4 +359,14 @@ class ClvpTextInjector(Injector):
         if self.needs_move:
             self.clvp = self.clvp.to(codes.device)
         latents = self.clvp.embed_text(codes)
-        return {self.output: latents}
\ No newline at end of file
+        return {self.output: latents}
+
+
+class NormalizeMelInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+
+    def forward(self, state):
+        mel = state[self.input]
+        with torch.no_grad():
+            return {self.output: normalize_mel(mel)}
\ No newline at end of file
diff --git a/codes/utils/music_utils.py b/codes/utils/music_utils.py
index 64d26f20..24d34e84 100644
--- a/codes/utils/music_utils.py
+++ b/codes/utils/music_utils.py
@@ -13,8 +13,10 @@ def get_mel2wav_model():
     return model
 
 def get_music_codegen():
-    model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0, mask_time_prob=0,
-                                       mask_time_length=6, num_negatives=100, codebook_size=8, codebook_groups=8, disable_custom_linear_init=True)
-    model.load_state_dict(torch.load("../experiments/m2v_music.pth", map_location=torch.device('cpu')))
-    model.eval()
+    model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0,
+                                       mask_time_prob=0,
+                                       mask_time_length=6, num_negatives=100, codebook_size=16, codebook_groups=4,
+                                       disable_custom_linear_init=True, do_reconstruction_loss=True)
+    model.load_state_dict(torch.load(f"../experiments/m2v_music.pth", map_location=torch.device('cpu')))
+    model = model.eval()
     return model
\ No newline at end of file
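
Note (not part of the patch): a minimal sketch of how the new NormalizeMelInjector could be smoke-tested in isolation. It assumes the base Injector only needs the 'in'/'out' keys of its opt dict (the convention the other injectors in audio_injectors.py follow), that an empty env dict is acceptable for this stateless injector, and that normalize_mel is the affine rescaling already defined in audio_injectors.py.

# Hypothetical usage sketch only; see the assumptions stated above.
import torch

from trainer.injectors.audio_injectors import NormalizeMelInjector

inj = NormalizeMelInjector({'in': 'mel', 'out': 'mel_norm'}, env={})

# A state dict shaped like a 256-bin mel spectrogram: [batch, n_mel_channels, frames].
state = {'mel': torch.randn(1, 256, 400)}
state.update(inj.forward(state))

print(state['mel_norm'].shape)  # same shape as the input mel, values rescaled by normalize_mel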