From 58019a2ce35e2d3a560e348e9c3f5efd4650789f Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 3 Mar 2022 21:53:32 -0700
Subject: [PATCH] audio diffusion fid updates

---
 codes/models/clip/mel_text_clip.py        | 23 ++++++++++++++++++-----
 codes/trainer/eval/audio_diffusion_fid.py |  6 +++---
 2 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/codes/models/clip/mel_text_clip.py b/codes/models/clip/mel_text_clip.py
index dda54203..a8c1dda5 100644
--- a/codes/models/clip/mel_text_clip.py
+++ b/codes/models/clip/mel_text_clip.py
@@ -102,11 +102,25 @@ class MelTextCLIP(nn.Module):
             text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
             voice_mask = torch.rand_like(mel[:,0,:].float()) > self.voice_mask_percentage
         else:
-            text_mask = None
-            voice_mask = None
+            text_mask = torch.ones_like(text.float()).bool()
+            voice_mask = torch.ones_like(mel[:,0,:].float()).bool()
 
-        text_latents = self.get_text_projections(text, text_mask)
-        speech_latents = self.get_speech_projection(mel, voice_mask)
+        text_emb = self.text_emb(text)
+        text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
+
+        speech_emb = self.speech_enc(mel).permute(0,2,1)
+        speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
+
+        # Only autocast the transformer part. The MEL encoder loses accuracy if you autocast it.
+        with torch.autocast(speech_emb.device.type):
+            enc_text = self.text_transformer(text_emb, mask=text_mask)
+            enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
+
+            text_latents = masked_mean(enc_text, text_mask, dim=1)
+            speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
+
+            text_latents = self.to_text_latent(text_latents).float()
+            speech_latents = self.to_speech_latent(speech_latents).float()
 
         text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
 
@@ -122,7 +136,6 @@ class MelTextCLIP(nn.Module):
 
         return loss
 
-
 @register_model
 def register_mel_text_clip(opt_net, opt):
     return MelTextCLIP(**opt_get(opt_net, ['kwargs'], {}))
diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py
index 0be66264..62e4ccbc 100644
--- a/codes/trainer/eval/audio_diffusion_fid.py
+++ b/codes/trainer/eval/audio_diffusion_fid.py
@@ -162,10 +162,10 @@ if __name__ == '__main__':
     from utils.util import load_model_from_config
 
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text.yml', 'generator',
-                                       also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text\\models\\39500_generator_ema.pth').cuda()
+                                       also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text\\models\\47500_generator_ema.pth').cuda()
     opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
-                'conditioning_free': True, 'conditioning_free_k': 2,
+                'conditioning_free': True, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'vocoder'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 202, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 1, 'device': 'cuda', 'opt': {}}
     eval = AudioDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
\ No newline at end of file
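
Note on the first hunk: it keeps the convolutional MEL encoder in full precision and scopes torch.autocast to the transformer stack only, explicitly casting the latents back to fp32 before they reach F.normalize. A minimal, self-contained sketch of that selective-autocast pattern follows; the module names and shapes here are hypothetical stand-ins, not the actual MelTextCLIP internals.

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins for speech_enc / speech_transformer /
    # to_speech_latent; the real MelTextCLIP modules differ.
    mel_encoder = nn.Conv1d(80, 512, kernel_size=3, padding=1)
    enc_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
    transformer = nn.TransformerEncoder(enc_layer, num_layers=2)
    to_latent = nn.Linear(512, 512)

    mel = torch.randn(2, 80, 200)  # (batch, n_mels, frames)

    # The MEL encoder runs outside autocast, in full fp32 precision
    # (per the patch comment, autocasting it costs accuracy); permute
    # moves channels last for the transformer.
    speech_emb = mel_encoder(mel).permute(0, 2, 1)

    # Only the transformer and the latent projection run under autocast,
    # in reduced precision; .float() casts the result back to fp32.
    with torch.autocast(speech_emb.device.type):
        enc_speech = transformer(speech_emb)
        speech_latents = enc_speech.mean(dim=1)  # stand-in for masked_mean
        speech_latents = to_latent(speech_latents).float()

    print(speech_latents.dtype)  # torch.float32

The key point is the scope: the numerically sensitive encoder stays outside the autocast region, while everything downstream of it runs in reduced precision and is cast back to fp32 before any full-precision consumer such as F.normalize.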