audio diffusion fid updates

This commit is contained in:
James Betker 2022-03-03 21:53:32 -07:00
parent 998c53ad4f
commit 58019a2ce3
2 changed files with 21 additions and 8 deletions

View File

@@ -102,11 +102,25 @@ class MelTextCLIP(nn.Module):
text_mask = torch.rand_like(text.float()) > self.text_mask_percentage text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
voice_mask = torch.rand_like(mel[:,0,:].float()) > self.voice_mask_percentage voice_mask = torch.rand_like(mel[:,0,:].float()) > self.voice_mask_percentage
else: else:
text_mask = None text_mask = torch.ones_like(text.float()).bool()
voice_mask = None voice_mask = torch.ones_like(mel[:,0,:].float()).bool()
text_latents = self.get_text_projections(text, text_mask) text_emb = self.text_emb(text)
speech_latents = self.get_speech_projection(mel, voice_mask) text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
speech_emb = self.speech_enc(mel).permute(0,2,1)
speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
# Only autocast the transformer part. The MEL encoder loses accuracy if you autocast it.
with torch.autocast(speech_emb.device.type):
enc_text = self.text_transformer(text_emb, mask=text_mask)
enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
text_latents = masked_mean(enc_text, text_mask, dim=1)
speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
text_latents = self.to_text_latent(text_latents).float()
speech_latents = self.to_speech_latent(speech_latents).float()
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents)) text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
@@ -122,7 +136,6 @@ class MelTextCLIP(nn.Module):
return loss return loss
@register_model @register_model
def register_mel_text_clip(opt_net, opt): def register_mel_text_clip(opt_net, opt):
return MelTextCLIP(**opt_get(opt_net, ['kwargs'], {})) return MelTextCLIP(**opt_get(opt_net, ['kwargs'], {}))

View File

@@ -162,10 +162,10 @@ if __name__ == '__main__':
from utils.util import load_model_from_config from utils.util import load_model_from_config
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text.yml', 'generator', diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text.yml', 'generator',
also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text\\models\\39500_generator_ema.pth').cuda() also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text\\models\\47500_generator_ema.pth').cuda()
opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100, opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
'conditioning_free': True, 'conditioning_free_k': 2, 'conditioning_free': True, 'conditioning_free_k': 1,
'diffusion_schedule': 'linear', 'diffusion_type': 'vocoder'} 'diffusion_schedule': 'linear', 'diffusion_type': 'vocoder'}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 202, 'device': 'cuda', 'opt': {}} env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 1, 'device': 'cuda', 'opt': {}}
eval = AudioDiffusionFid(diffusion, opt_eval, env) eval = AudioDiffusionFid(diffusion, opt_eval, env)
print(eval.perform_eval()) print(eval.perform_eval())