From 6c8032b4beb5da0b6ffe9e19bb3b71a6b4ac08f2 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 6 May 2022 21:56:49 -0600 Subject: [PATCH] more work --- .../models/audio/music/music_gen_fill_gaps.py | 6 +-- codes/models/clip/contrastive_audio.py | 6 +++ codes/trainer/eval/music_diffusion_fid.py | 54 ++++++++----------- 3 files changed, 30 insertions(+), 36 deletions(-) diff --git a/codes/models/audio/music/music_gen_fill_gaps.py b/codes/models/audio/music/music_gen_fill_gaps.py index 7ea51e72..ec4c6f28 100644 --- a/codes/models/audio/music/music_gen_fill_gaps.py +++ b/codes/models/audio/music/music_gen_fill_gaps.py @@ -117,8 +117,8 @@ class MusicGenerator(nn.Module): layer_drop=.1, unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. # Masking parameters. - time_mask_percent_max=.4, - frequency_mask_percent_max=.4, + frequency_mask_percent_max=0, + time_mask_percent_max=0, ): super().__init__() @@ -172,7 +172,7 @@ class MusicGenerator(nn.Module): def do_masking(self, truth): b, c, s = truth.shape mask = torch.ones_like(truth) - if random.random() < .5: + if self.frequency_mask_percent_mask > 0: # Frequency mask cs = random.randint(0, c-10) ce = min(c-1, cs+random.randint(1, int(self.frequency_mask_percent_mask*c))) diff --git a/codes/models/clip/contrastive_audio.py b/codes/models/clip/contrastive_audio.py index e1aeb915..6bd77e55 100644 --- a/codes/models/clip/contrastive_audio.py +++ b/codes/models/clip/contrastive_audio.py @@ -219,6 +219,12 @@ class ContrastiveAudio(nn.Module): def update_for_step(self, step, __): self.to_latent2.weight.data = self.to_latent2.weight.data * .99 + self.to_latent.weight.data * .01 + def project(self, mel): + h1 = self.emb(mel).permute(0, 2, 1) + h1 = self.transformer(h1) + h1 = self.to_latent(h1) + return h1 + def forward( self, mel_input1, diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index 9b756631..ec789a9f 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -17,6 +17,7 @@ from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes from data.audio.unsupervised_audio_dataset import load_audio from data.audio.voice_tokenizer import VoiceBpeTokenizer from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen +from models.clip.contrastive_audio import ContrastiveAudio from models.clip.mel_text_clip import MelTextCLIP from models.audio.tts.tacotron2 import text_to_sequence from models.diffusion.gaussian_diffusion import get_named_beta_schedule @@ -58,7 +59,9 @@ class MusicDiffusionFid(evaluator.Evaluator): num_heads=8, dropout=0, kernel_size=3, scale_factor=2, time_embed_dim_multiplier=4, unconditioned_percentage=0) self.spec_decoder.load_state_dict(torch.load('../experiments/music_waveform_gen.pth', map_location=torch.device('cpu'))) - self.local_modules = {'spec_decoder': self.spec_decoder} + self.projector = ContrastiveAudio(model_dim=512, transformer_heads=8, dropout=0, encoder_depth=8, mel_channels=256) + #self.projector.load_state_dict(torch.load('../experiments/music_eval_projector.pth', map_location=torch.device('cpu'))) + self.local_modules = {'spec_decoder': self.spec_decoder, 'projector': self.projector} if mode == 'spec_decode': self.diffusion_fn = self.perform_diffusion_spec_decode @@ -127,20 +130,11 @@ class MusicDiffusionFid(evaluator.Evaluator): return gen, real_resampled, sample_rate - def load_projector(self): - # TODO: implement for music. - model = MelTextCLIP(dim_text=512, dim_latent=512, dim_speech=512, num_text_tokens=148, text_enc_depth=8, - text_seq_len=400, text_heads=8, speech_enc_depth=10, speech_heads=8, speech_seq_len=1000, - text_mask_percentage=.15, voice_mask_percentage=.15) - weights = torch.load('../experiments/clip_text_to_voice_for_speech_fid.pth') - model.load_state_dict(weights) - return model - - def project(self, projector, sample, sample_rate): - # TODO: implement for music. + def project(self, sample, sample_rate): sample = torchaudio.functional.resample(sample, sample_rate, 22050) - mel = wav_to_mel(sample) - return projector.get_speech_projection(mel).squeeze(0) # Getting rid of the batch dimension means it's just [hidden_dim] + mel = self.spec_fn({'in': sample})['out'] + projection = self.projector.project(mel) + return projection.squeeze(0) # Getting rid of the batch dimension means it's just [hidden_dim] def compute_frechet_distance(self, proj1, proj2): # I really REALLY FUCKING HATE that this is going to numpy. Why does "pytorch_fid" operate in numpy land. WHY? @@ -156,41 +150,35 @@ class MusicDiffusionFid(evaluator.Evaluator): save_path = osp.join(self.env['base_path'], "../", "audio_eval", str(self.env["step"])) os.makedirs(save_path, exist_ok=True) - #projector = self.load_projector().to(self.env['device']) - #projector.eval() + self.projector = self.projector.to(self.dev) + self.projector.eval() # Attempt to fix the random state as much as possible. RNG state will be restored before returning. rng_state = torch.get_rng_state() torch.manual_seed(5) self.model.eval() - frechet_distance = 0 with torch.no_grad(): gen_projections = [] real_projections = [] for i in tqdm(list(range(0, len(self.data), self.skip))): path = self.data[i + self.env['rank']] audio = load_audio(path, 22050).to(self.dev) - mel = self.spec_fn({'in': audio})['out'] - mel_norm = (mel + mel.min().abs()) - mel_norm = mel_norm / mel_norm.max(dim=-1, keepdim=True).values - torchvision.utils.save_image(mel_norm.unsqueeze(1), 'mel.png') + audio = audio[:, :22050*5] sample, ref, sample_rate = self.diffusion_fn(audio) - #gen_projections.append(self.project(projector, sample, sample_rate).cpu()) # Store on CPU to avoid wasting GPU memory. - #real_projections.append(self.project(projector, ref, sample_rate).cpu()) + gen_projections.append(self.project(sample, sample_rate).cpu()) # Store on CPU to avoid wasting GPU memory. + real_projections.append(self.project(ref, sample_rate).cpu()) torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate) torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.cpu(), sample_rate) - #gen_projections = torch.stack(gen_projections, dim=0) - #real_projections = torch.stack(real_projections, dim=0) - #frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device']) + gen_projections = torch.stack(gen_projections, dim=0) + real_projections = torch.stack(real_projections, dim=0) + frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device']) - #if distributed.is_initialized() and distributed.get_world_size() > 1: - # distributed.all_reduce(frechet_distance) - # frechet_distance = frechet_distance / distributed.get_world_size() - # distributed.all_reduce(intelligibility_loss) - # intelligibility_loss = intelligibility_loss / distributed.get_world_size() + if distributed.is_initialized() and distributed.get_world_size() > 1: + distributed.all_reduce(frechet_distance) + frechet_distance = frechet_distance / distributed.get_world_size()\ self.model.train() torch.set_rng_state(rng_state) @@ -205,8 +193,8 @@ class MusicDiffusionFid(evaluator.Evaluator): if __name__ == '__main__': diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_gap_filler.yml', 'generator', also_load_savepoint=False, - load_path='X:\\dlas\\experiments\\train_music_gap_filler\\models\\5000_generator.pth').cuda() - opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100, + load_path='X:\\dlas\\experiments\\train_music_gap_filler\\models\\14000_generator.pth').cuda() + opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500, 'conditioning_free': False, 'conditioning_free_k': 1, 'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_freq'} env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 2, 'device': 'cuda', 'opt': {}}