diff --git a/codes/models/audio/music/music_gen_fill_gaps.py b/codes/models/audio/music/music_gen_fill_gaps.py index b2761763..7ea51e72 100644 --- a/codes/models/audio/music/music_gen_fill_gaps.py +++ b/codes/models/audio/music/music_gen_fill_gaps.py @@ -186,15 +186,14 @@ class MusicGenerator(nn.Module): def timestep_independent(self, truth, expected_seq_len, return_code_pred): - code_emb = self.conditioner(truth) + truth_emb = self.conditioner(truth) # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), - device=code_emb.device) < self.unconditioned_percentage - code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(truth.shape[0], 1, 1), - code_emb) - expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest') - return expanded_code_emb + unconditioned_batches = torch.rand((truth_emb.shape[0], 1, 1), + device=truth_emb.device) < self.unconditioned_percentage + truth_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(truth.shape[0], 1, 1), + truth_emb) + return truth_emb def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False): @@ -212,20 +211,21 @@ class MusicGenerator(nn.Module): unused_params = [] if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) + truth_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) unused_params.extend(list(self.conditioner.parameters())) else: if precomputed_aligned_embeddings is not None: - code_emb = precomputed_aligned_embeddings + truth_emb = precomputed_aligned_embeddings else: - truth = self.do_masking(truth) - code_emb = self.timestep_independent(truth, x.shape[-1], True) + if self.training: + truth = self.do_masking(truth) + truth_emb = self.timestep_independent(truth, x.shape[-1], True) unused_params.append(self.unconditioned_embedding) time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) + truth_emb = self.conditioning_timestep_integrator(truth_emb, time_emb) x = self.inp_block(x) - x = torch.cat([x, code_emb], dim=1) + x = torch.cat([x, truth_emb], dim=1) x = self.integrating_conv(x) for i, lyr in enumerate(self.layers): # Do layer drop where applicable. Do not drop first and last layers. diff --git a/codes/models/audio/music/unet_diffusion_waveform_gen.py b/codes/models/audio/music/unet_diffusion_waveform_gen.py index 7aa5bc85..f0ef775f 100644 --- a/codes/models/audio/music/unet_diffusion_waveform_gen.py +++ b/codes/models/audio/music/unet_diffusion_waveform_gen.py @@ -95,7 +95,7 @@ class ResBlock(TimestepBlock): h = self.out_layers(h) return self.skip_connection(x) + h -class DiffusionTts(nn.Module): +class DiffusionWaveformGen(nn.Module): """ The full UNet model with attention and timestep embedding. @@ -465,7 +465,7 @@ class DiffusionTts(nn.Module): @register_model def register_unet_diffusion_waveform_gen(opt_net, opt): - return DiffusionTts(**opt_net['kwargs']) + return DiffusionWaveformGen(**opt_net['kwargs']) if __name__ == '__main__': @@ -473,17 +473,17 @@ if __name__ == '__main__': aligned_latent = torch.randn(2,388,1024) aligned_sequence = torch.randn(2,120,220) ts = torch.LongTensor([600, 600]) - model = DiffusionTts(128, - channel_mult=[1,1.5,2, 3, 4, 6, 8], - num_res_blocks=[2, 2, 2, 2, 2, 2, 1], - token_conditioning_resolutions=[1,4,16,64], - attention_resolutions=[], - num_heads=8, - kernel_size=3, - scale_factor=2, - time_embed_dim_multiplier=4, - super_sampling=False, - efficient_convs=False) + model = DiffusionWaveformGen(128, + channel_mult=[1,1.5,2, 3, 4, 6, 8], + num_res_blocks=[2, 2, 2, 2, 2, 2, 1], + token_conditioning_resolutions=[1,4,16,64], + attention_resolutions=[], + num_heads=8, + kernel_size=3, + scale_factor=2, + time_embed_dim_multiplier=4, + super_sampling=False, + efficient_convs=False) # Test with latent aligned conditioning o = model(clip, ts, aligned_latent) # Test with sequence aligned conditioning diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py index 124adbe0..b5ae28b8 100644 --- a/codes/trainer/eval/audio_diffusion_fid.py +++ b/codes/trainer/eval/audio_diffusion_fid.py @@ -17,7 +17,7 @@ from models.clip.mel_text_clip import MelTextCLIP from models.audio.tts.tacotron2 import text_to_sequence from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \ convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel -from trainer.injectors.audio_injectors import denormalize_tacotron_mel +from trainer.injectors.audio_injectors import denormalize_mel from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate @@ -161,7 +161,7 @@ class AudioDiffusionFid(evaluator.Evaluator): model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': univnet_mel}) # denormalize mel - gen_mel = denormalize_tacotron_mel(gen_mel) + gen_mel = denormalize_mel(gen_mel) gen_wav = self.local_modules['vocoder'].inference(gen_mel) real_dec = self.local_modules['vocoder'].inference(univnet_mel) diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index 9b364982..9b756631 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -16,13 +16,15 @@ import trainer.eval.evaluator as evaluator from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes from data.audio.unsupervised_audio_dataset import load_audio from data.audio.voice_tokenizer import VoiceBpeTokenizer +from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen from models.clip.mel_text_clip import MelTextCLIP from models.audio.tts.tacotron2 import text_to_sequence from models.diffusion.gaussian_diffusion import get_named_beta_schedule from models.diffusion.respace import space_timesteps, SpacedDiffusion from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \ convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel -from trainer.injectors.audio_injectors import denormalize_tacotron_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d +from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \ + normalize_mel from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate @@ -50,15 +52,28 @@ class MusicDiffusionFid(evaluator.Evaluator): conditioning_free=conditioning_free_diffusion_enabled, conditioning_free_k=conditioning_free_k) self.dev = self.env['device'] mode = opt_get(opt_eval, ['diffusion_type'], 'tts') - self.local_modules = {} - if mode == 'standard': - self.diffusion_fn = self.perform_diffusion_standard + + self.spec_decoder = DiffusionWaveformGen(model_channels=256, in_channels=16, in_mel_channels=256, out_channels=32, + channel_mult=[1,2,3,4], num_res_blocks=[3,3,3,3], token_conditioning_resolutions=[1,4], + num_heads=8, + dropout=0, kernel_size=3, scale_factor=2, time_embed_dim_multiplier=4, unconditioned_percentage=0) + self.spec_decoder.load_state_dict(torch.load('../experiments/music_waveform_gen.pth', map_location=torch.device('cpu'))) + self.local_modules = {'spec_decoder': self.spec_decoder} + + if mode == 'spec_decode': + self.diffusion_fn = self.perform_diffusion_spec_decode + elif 'gap_fill_' in mode: + self.diffusion_fn = self.perform_diffusion_gap_fill + if '_freq' in mode: + self.gap_gen_fn = self.gen_freq_gap + else: + self.gap_gen_fn = self.gen_time_gap self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {}) def load_data(self, path): return list(glob(f'{path}/*.wav')) - def perform_diffusion_standard(self, audio, sample_rate=22050): + def perform_diffusion_spec_decode(self, audio, sample_rate=22050): if sample_rate != sample_rate: real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) else: @@ -69,7 +84,47 @@ class MusicDiffusionFid(evaluator.Evaluator): gen = self.diffuser.p_sample_loop(self.model, output_shape, noise=torch.zeros(*output_shape, device=audio.device), model_kwargs={'aligned_conditioning': mel}) gen = pixel_shuffle_1d(gen, 16) - real_resampled = real_resampled + torch.FloatTensor(real_resampled.shape).uniform_(0.0, 1e-5).to(real_resampled.device) + + return gen, real_resampled, sample_rate + + def gen_freq_gap(self, mel, band_range=(130,150)): + gap_start, gap_end = band_range + mel[:, gap_start:gap_end] = 0 + return mel + + def gen_time_gap(self, mel): + mel[:, :, 22050*5:22050*6] = 0 + return mel + + def perform_diffusion_gap_fill(self, audio, sample_rate=22050, band_range=(130,150)): + if sample_rate != sample_rate: + real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) + else: + real_resampled = audio + audio = audio.unsqueeze(0) + + # Fetch the MEL and mask out the requested bands. + mel = self.spec_fn({'in': audio})['out'] + mel = normalize_mel(mel) + mel = self.gap_gen_fn(mel) + output_shape = (1, mel.shape[1], mel.shape[2]) + + # Repair the MEL with the given model. + spec = self.diffuser.p_sample_loop(self.model, output_shape, noise=torch.zeros(*output_shape, device=audio.device), + model_kwargs={'truth': mel}) + import torchvision + torchvision.utils.save_image((spec.unsqueeze(1) + 1) / 2, 'gen.png') + torchvision.utils.save_image((mel.unsqueeze(1) + 1) / 2, 'mel.png') + spec = denormalize_mel(spec) + + # Re-convert the resulting MEL back into audio using the spectrogram decoder. + output_shape = (1, 16, audio.shape[-1] // 16) + self.spec_decoder = self.spec_decoder.to(audio.device) + # Cool fact: we can re-use the diffuser for the spectrogram diffuser since it has the same parametrization. + gen = self.diffuser.p_sample_loop(self.spec_decoder, output_shape, noise=torch.zeros(*output_shape, device=audio.device), + model_kwargs={'aligned_conditioning': spec}) + gen = pixel_shuffle_1d(gen, 16) + return gen, real_resampled, sample_rate def load_projector(self): @@ -148,12 +203,12 @@ class MusicDiffusionFid(evaluator.Evaluator): if __name__ == '__main__': - diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_waveform_gen3.yml', 'generator', + diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_gap_filler.yml', 'generator', also_load_savepoint=False, - load_path='X:\\dlas\\experiments\\train_music_waveform_gen3_r1\\models\\10000_generator_ema.pth').cuda() - opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 50, + load_path='X:\\dlas\\experiments\\train_music_gap_filler\\models\\5000_generator.pth').cuda() + opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100, 'conditioning_free': False, 'conditioning_free_k': 1, - 'diffusion_schedule': 'linear', 'diffusion_type': 'standard'} - env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 1, 'device': 'cuda', 'opt': {}} + 'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_freq'} + env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 2, 'device': 'cuda', 'opt': {}} eval = MusicDiffusionFid(diffusion, opt_eval, env) print(eval.perform_eval()) diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py index 554f70ba..a76910da 100644 --- a/codes/trainer/injectors/audio_injectors.py +++ b/codes/trainer/injectors/audio_injectors.py @@ -282,6 +282,20 @@ class ConditioningLatentDistributionDivergenceInjector(Injector): return {self.output: mean_loss, self.var_loss_key: var_loss} +class RandomScaleInjector(Injector): + def __init__(self, opt, env): + super().__init__(opt, env) + self.min_samples = opt['min_samples'] + + def forward(self, state): + inp = state[self.input] + if self.min_samples < inp.shape[-1]: + samples = random.randint(self.min_samples, inp.shape[-1]) + start = random.randint(0, inp.shape[-1]-samples) + inp = inp[:, :, start:start+samples] + return {self.output: inp} + + def pixel_shuffle_1d(x, upscale_factor): batch_size, channels, steps = x.size() channels //= upscale_factor