diff --git a/codes/models/clip/mel_text_clip.py b/codes/models/clip/mel_text_clip.py index a802f422..dda54203 100644 --- a/codes/models/clip/mel_text_clip.py +++ b/codes/models/clip/mel_text_clip.py @@ -62,6 +62,26 @@ class MelTextCLIP(nn.Module): self.voice_mask_percentage = voice_mask_percentage self.mel_compression = mel_compression + def get_text_projections(self, text, text_mask=None): + if text_mask is None: + text_mask = torch.ones_like(text.float()).bool() + text_emb = self.text_emb(text) + text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=text.device)) + with torch.autocast(text.device.type): + enc_text = self.text_transformer(text_emb, mask=text_mask) + text_latents = masked_mean(enc_text, text_mask, dim=1) + return self.to_text_latent(text_latents).float() + + def get_speech_projection(self, mel, voice_mask=None): + if voice_mask is None: + voice_mask = torch.ones_like(mel[:,0,:].float()).bool() + speech_emb = self.speech_enc(mel).permute(0,2,1) + speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=mel.device)) + with torch.autocast(speech_emb.device.type): + enc_speech = self.speech_transformer(speech_emb, mask=voice_mask) + speech_latents = masked_mean(enc_speech, voice_mask, dim=1) + return self.to_speech_latent(speech_latents).float() + def forward( self, text, @@ -82,25 +102,11 @@ class MelTextCLIP(nn.Module): text_mask = torch.rand_like(text.float()) > self.text_mask_percentage voice_mask = torch.rand_like(mel[:,0,:].float()) > self.voice_mask_percentage else: - text_mask = torch.ones_like(text.float()).bool() - voice_mask = torch.ones_like(mel[:,0,:].float()).bool() + text_mask = None + voice_mask = None - text_emb = self.text_emb(text) - text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device)) - - speech_emb = self.speech_enc(mel).permute(0,2,1) - speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device)) - - # Only autocast the transformer part. The MEL encoder loses accuracy if you autcast it. - with torch.autocast(speech_emb.device.type): - enc_text = self.text_transformer(text_emb, mask=text_mask) - enc_speech = self.speech_transformer(speech_emb, mask=voice_mask) - - text_latents = masked_mean(enc_text, text_mask, dim=1) - speech_latents = masked_mean(enc_speech, voice_mask, dim=1) - - text_latents = self.to_text_latent(text_latents).float() - speech_latents = self.to_speech_latent(speech_latents).float() + text_latents = self.get_text_projections(text, text_mask) + speech_latents = self.get_speech_projection(mel, voice_mask) text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents)) @@ -116,6 +122,7 @@ class MelTextCLIP(nn.Module): return loss + @register_model def register_mel_text_clip(opt_net, opt): return MelTextCLIP(**opt_get(opt_net, ['kwargs'], {})) diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py index 9de11787..1773c28f 100644 --- a/codes/trainer/eval/audio_diffusion_fid.py +++ b/codes/trainer/eval/audio_diffusion_fid.py @@ -14,7 +14,8 @@ import numpy as np import trainer.eval.evaluator as evaluator from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes from data.audio.unsupervised_audio_dataset import load_audio -from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser +from models.clip.mel_text_clip import MelTextCLIP +from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel from utils.util import ceil_multiple, opt_get @@ -54,10 +55,22 @@ class AudioDiffusionFid(evaluator.Evaluator): 'conditioning_input': real_resampled}) return gen, real_resampled, sample_rate + def load_projector(self): + """ + Builds the CLIP model used to project speech into a latent. This model has fixed parameters and a fixed loading + path for the time being. + """ + model = MelTextCLIP(dim_text=512, dim_latent=512, dim_speech=512, num_text_tokens=148, text_enc_depth=8, + text_seq_len=400, text_heads=8, speech_enc_depth=10, speech_heads=8, speech_seq_len=1000, + text_mask_percentage=.15, voice_mask_percentage=.15) + weights = torch.load('../experiments/clip_text_to_voice_for_speech_fid.pth') + model.load_state_dict(weights) + return model + def project(self, projector, sample, sample_rate): - sample = torchaudio.functional.resample(sample, sample_rate, 16000) - sample = (sample - sample.mean()) / torch.sqrt(sample.var() + 1e-7) - return projector(sample.squeeze(1), output_hidden_states=True).hidden_states[-1].squeeze(0) # Getting rid of the batch dimension means it's just [seq_len,hidden_states] + sample = torchaudio.resample(sample, sample_rate, 22050) + mel = wav_to_mel(sample) + return projector.get_speech_projection(mel).squeeze(0) # Getting rid of the batch dimension means it's just [hidden_dim] def compute_frechet_distance(self, proj1, proj2): # I really REALLY FUCKING HATE that this is going to numpy. Why does "pytorch_fid" operate in numpy land. WHY? @@ -73,7 +86,7 @@ class AudioDiffusionFid(evaluator.Evaluator): save_path = osp.join(self.env['base_path'], "../", "audio_eval", str(self.env["step"])) os.makedirs(save_path, exist_ok=True) - projector = Wav2Vec2ForCTC.from_pretrained(f"facebook/wav2vec2-large").to(self.dev) + projector = self.load_projector().to(self.env['device']) projector.eval() # Attempt to fix the random state as much as possible. RNG state will be restored before returning. @@ -90,30 +103,30 @@ class AudioDiffusionFid(evaluator.Evaluator): codes = codes.to(self.dev) sample, ref, sample_rate = self.perform_diffusion(audio, codes) - gen_projections.append(self.project(projector, sample, sample_rate).cpu()) # Store on CPU to avoid wasting GPU memory. - real_projections.append(self.project(projector, ref, sample_rate).cpu()) + gen_projections.append(self.project(projector, sample).cpu(), sample_rate) # Store on CPU to avoid wasting GPU memory. + real_projections.append(self.project(projector, ref).cpu(), sample_rate) torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate) torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.squeeze(0).cpu(), sample_rate) - gen_projections = torch.cat(gen_projections, dim=0) - real_projections = torch.cat(real_projections, dim=0) - fid = self.compute_frechet_distance(gen_projections, real_projections) + gen_projections = torch.stack(gen_projections, dim=0) + real_projections = torch.stack(real_projections, dim=0) + frechet_distance = self.compute_frechet_distance(gen_projections, real_projections) if distributed.is_initialized() and distributed.get_world_size() > 1: - fid = distributed.all_reduce(fid) / distributed.get_world_size() + frechet_distance = distributed.all_reduce(frechet_distance) / distributed.get_world_size() self.model.train() torch.set_rng_state(rng_state) - return {"fid": fid} + return {"frechet_distance": frechet_distance} if __name__ == '__main__': from utils.util import load_model_from_config - diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts5_medium.yml', 'generator', - also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts5_medium\\models\\73000_generator_ema.pth').cuda() - opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 50} - env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 500, 'device': 'cuda'} + diffusion = load_model_from_config('X:\\dlas\\experiments\\sweep_diffusion_tts6\\baseline\\train_diffusion_tts6.yml', 'generator', + also_load_savepoint=False, load_path='X:\\dlas\\experiments\\sweep_diffusion_tts6\\baseline\\models\\102000_generator_ema.pth').cuda() + opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 50, 'diffusion_schedule': 'linear'} + env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 500, 'device': 'cuda', 'opt': {}} eval = AudioDiffusionFid(diffusion, opt_eval, env) eval.perform_eval() \ No newline at end of file