audio diffusion fid updates

parent 998c53ad4f
commit 58019a2ce3
@@ -102,11 +102,25 @@ class MelTextCLIP(nn.Module):
             text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
             voice_mask = torch.rand_like(mel[:,0,:].float()) > self.voice_mask_percentage
         else:
-            text_mask = None
-            voice_mask = None
+            text_mask = torch.ones_like(text.float()).bool()
+            voice_mask = torch.ones_like(mel[:,0,:].float()).bool()
 
-        text_latents = self.get_text_projections(text, text_mask)
-        speech_latents = self.get_speech_projection(mel, voice_mask)
+        text_emb = self.text_emb(text)
+        text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
+
+        speech_emb = self.speech_enc(mel).permute(0,2,1)
+        speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
+
+        # Only autocast the transformer part. The MEL encoder loses accuracy if you autocast it.
+        with torch.autocast(speech_emb.device.type):
+            enc_text = self.text_transformer(text_emb, mask=text_mask)
+            enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
+
+            text_latents = masked_mean(enc_text, text_mask, dim=1)
+            speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
+
+            text_latents = self.to_text_latent(text_latents).float()
+            speech_latents = self.to_speech_latent(speech_latents).float()
 
         text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
 
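Note on the new forward path: both transformers now run under torch.autocast while the MEL encoder stays in full precision, and pooling goes through masked_mean, which is why the eval branch must now supply real all-ones boolean masks instead of None. For reference, a minimal sketch of a masked_mean helper consistent with the call sites above (an assumption; check the repo for the actual implementation):

    import torch

    def masked_mean(t, mask, dim=1):
        # Hypothetical sketch inferred from the call sites above.
        # t: (batch, seq, dim) transformer output; mask: (batch, seq) bool.
        mask = mask.unsqueeze(-1)                    # (batch, seq, 1)
        t = t.masked_fill(~mask, 0.)                 # zero padded positions
        return t.sum(dim=dim) / mask.sum(dim=dim).clamp(min=1)

Averaging only over unmasked positions keeps padding from diluting the pooled latent, and it degenerates to a plain mean under the all-ones eval masks.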
@@ -122,7 +136,6 @@ class MelTextCLIP(nn.Module):
         return loss
 
 
-
 @register_model
 def register_mel_text_clip(opt_net, opt):
     return MelTextCLIP(**opt_get(opt_net, ['kwargs'], {}))
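register_model adds this factory to the DLAS model registry so a YAML config can instantiate it by name, and opt_get reads nested keys from the options dict with a fallback. A plausible sketch of opt_get, inferred from this call site rather than copied from utils:

    def opt_get(opt, keys, default=None):
        # Hypothetical helper: walk a nested options dict along `keys`,
        # returning `default` if any level is missing.
        for k in keys:
            if not isinstance(opt, dict) or k not in opt:
                return default
            opt = opt[k]
        return opt

So opt_get(opt_net, ['kwargs'], {}) yields the model's kwargs block when present and an empty dict otherwise.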
@@ -162,10 +162,10 @@ if __name__ == '__main__':
     from utils.util import load_model_from_config
 
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text.yml', 'generator',
-                                       also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text\\models\\39500_generator_ema.pth').cuda()
+                                       also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text\\models\\47500_generator_ema.pth').cuda()
     opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
-                'conditioning_free': True, 'conditioning_free_k': 2,
+                'conditioning_free': True, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'vocoder'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 202, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 1, 'device': 'cuda', 'opt': {}}
     eval = AudioDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
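The eval changes point the harness at a later EMA checkpoint (47500 instead of 39500) and halve conditioning_free_k. That key scales classifier-free-style guidance during sampling; a minimal sketch of the usual combination, assuming the standard formulation rather than the repo's exact (possibly ramped) schedule:

    def conditioning_free_guidance(eps_cond, eps_uncond, k):
        # Standard classifier-free guidance mix (an assumption; the repo
        # may ramp k across diffusion timesteps). k=0 recovers the
        # conditioned prediction; larger k pushes outputs harder toward
        # the conditioning at some cost in naturalness.
        return (1 + k) * eps_cond - k * eps_uncond

Under that reading, dropping k from 2 to 1 softens the guidance term, trading conditioning adherence for naturalness in the FID measurement.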