Reformat mel_text_clip for use in eval

This commit is contained in:
James Betker 2022-02-19 20:37:26 -07:00
parent bcba65c539
commit 7b12799370
2 changed files with 54 additions and 34 deletions

View File

@@ -62,6 +62,26 @@ class MelTextCLIP(nn.Module):
self.voice_mask_percentage = voice_mask_percentage
self.mel_compression = mel_compression
def get_text_projections(self, text, text_mask=None):
    """Project a batch of token sequences into the shared CLIP latent space.

    Args:
        text: integer token tensor of shape (batch, seq_len).
        text_mask: optional boolean attention mask; when omitted, every
            position in the sequence is attended to.
    Returns:
        float32 tensor of text latents from ``to_text_latent``.
    """
    # No mask supplied means "attend to the whole sequence".
    if text_mask is None:
        text_mask = torch.ones_like(text.float()).bool()
    positions = torch.arange(text.shape[1], device=text.device)
    embeddings = self.text_emb(text) + self.text_pos_emb(positions)
    # Run the transformer and the masked pooling under autocast; the final
    # projection result is cast back to full precision below.
    with torch.autocast(text.device.type):
        encoded = self.text_transformer(embeddings, mask=text_mask)
        pooled = masked_mean(encoded, text_mask, dim=1)
    return self.to_text_latent(pooled).float()
def get_speech_projection(self, mel, voice_mask=None):
    """Project a batch of MEL spectrograms into the shared CLIP latent space.

    Args:
        mel: MEL spectrogram tensor; the encoder output is permuted to
            (batch, seq, channels) before the transformer.
        voice_mask: optional boolean attention mask over the time axis; when
            omitted, every frame is attended to.
    Returns:
        float32 tensor of speech latents from ``to_speech_latent``.
    """
    # No mask supplied means "attend to every frame".
    if voice_mask is None:
        voice_mask = torch.ones_like(mel[:, 0, :].float()).bool()
    encoded_mel = self.speech_enc(mel).permute(0, 2, 1)
    positions = torch.arange(encoded_mel.shape[1], device=mel.device)
    encoded_mel = encoded_mel + self.speech_pos_emb(positions)
    # Transformer + pooling run under autocast; result is returned in fp32.
    with torch.autocast(encoded_mel.device.type):
        encoded = self.speech_transformer(encoded_mel, mask=voice_mask)
        pooled = masked_mean(encoded, voice_mask, dim=1)
    return self.to_speech_latent(pooled).float()
def forward(
self,
text,
@@ -82,25 +102,11 @@ class MelTextCLIP(nn.Module):
text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
voice_mask = torch.rand_like(mel[:,0,:].float()) > self.voice_mask_percentage
else:
text_mask = torch.ones_like(text.float()).bool()
voice_mask = torch.ones_like(mel[:,0,:].float()).bool()
text_mask = None
voice_mask = None
text_emb = self.text_emb(text)
text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
speech_emb = self.speech_enc(mel).permute(0,2,1)
speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
# Only autocast the transformer part. The MEL encoder loses accuracy if you autocast it.
with torch.autocast(speech_emb.device.type):
enc_text = self.text_transformer(text_emb, mask=text_mask)
enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
text_latents = masked_mean(enc_text, text_mask, dim=1)
speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
text_latents = self.to_text_latent(text_latents).float()
speech_latents = self.to_speech_latent(speech_latents).float()
text_latents = self.get_text_projections(text, text_mask)
speech_latents = self.get_speech_projection(mel, voice_mask)
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
@@ -116,6 +122,7 @@ class MelTextCLIP(nn.Module):
return loss
@register_model
def register_mel_text_clip(opt_net, opt):
    """Model-registry factory: build a MelTextCLIP from opt_net['kwargs']."""
    # Missing 'kwargs' falls back to an empty dict, i.e. all model defaults.
    kwargs = opt_get(opt_net, ['kwargs'], {})
    return MelTextCLIP(**kwargs)

View File

@@ -14,7 +14,8 @@ import numpy as np
import trainer.eval.evaluator as evaluator
from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes
from data.audio.unsupervised_audio_dataset import load_audio
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser
from models.clip.mel_text_clip import MelTextCLIP
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel
from utils.util import ceil_multiple, opt_get
@@ -54,10 +55,22 @@ class AudioDiffusionFid(evaluator.Evaluator):
'conditioning_input': real_resampled})
return gen, real_resampled, sample_rate
def load_projector(self):
    """
    Builds the CLIP model used to project speech into a latent. This model has fixed parameters and a fixed loading
    path for the time being.

    Returns:
        A MelTextCLIP with the fixed evaluation checkpoint loaded (on CPU; the
        caller is responsible for moving it to the evaluation device).
    """
    model = MelTextCLIP(dim_text=512, dim_latent=512, dim_speech=512, num_text_tokens=148, text_enc_depth=8,
                        text_seq_len=400, text_heads=8, speech_enc_depth=10, speech_heads=8, speech_seq_len=1000,
                        text_mask_percentage=.15, voice_mask_percentage=.15)
    # map_location='cpu' lets a checkpoint saved on a GPU load on any host;
    # without it, torch.load fails on CPU-only machines or when the saving
    # device index does not exist. The caller does .to(device) afterwards.
    weights = torch.load('../experiments/clip_text_to_voice_for_speech_fid.pth', map_location='cpu')
    model.load_state_dict(weights)
    return model
# NOTE(review): this span renders BOTH the removed and the added body of this
# method from a diff with the +/- markers stripped, so the three lines after
# the first `return` are unreachable as written. Lines 2-4 look like the
# pre-commit wav2vec2 path (16 kHz resample, hidden-state extraction); the
# final three lines look like the replacement MelTextCLIP path.
def project(self, projector, sample, sample_rate):
    # Old path: resample to 16 kHz and normalize to zero mean / unit variance.
    sample = torchaudio.functional.resample(sample, sample_rate, 16000)
    sample = (sample - sample.mean()) / torch.sqrt(sample.var() + 1e-7)
    return projector(sample.squeeze(1), output_hidden_states=True).hidden_states[-1].squeeze(0) # Getting rid of the batch dimension means it's just [seq_len,hidden_states]
    # NOTE(review): torchaudio has no top-level `resample`; this presumably
    # should be torchaudio.functional.resample — confirm before enabling.
    sample = torchaudio.resample(sample, sample_rate, 22050)
    mel = wav_to_mel(sample)
    return projector.get_speech_projection(mel).squeeze(0) # Getting rid of the batch dimension means it's just [hidden_dim]
def compute_frechet_distance(self, proj1, proj2):
# NOTE: "pytorch_fid" operates on numpy arrays, so the projections must be moved off the GPU and converted to numpy here.
@@ -73,7 +86,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
save_path = osp.join(self.env['base_path'], "../", "audio_eval", str(self.env["step"]))
os.makedirs(save_path, exist_ok=True)
projector = Wav2Vec2ForCTC.from_pretrained(f"facebook/wav2vec2-large").to(self.dev)
projector = self.load_projector().to(self.env['device'])
projector.eval()
# Attempt to fix the random state as much as possible. RNG state will be restored before returning.
@@ -90,30 +103,30 @@ class AudioDiffusionFid(evaluator.Evaluator):
codes = codes.to(self.dev)
sample, ref, sample_rate = self.perform_diffusion(audio, codes)
gen_projections.append(self.project(projector, sample, sample_rate).cpu()) # Store on CPU to avoid wasting GPU memory.
real_projections.append(self.project(projector, ref, sample_rate).cpu())
gen_projections.append(self.project(projector, sample).cpu(), sample_rate) # Store on CPU to avoid wasting GPU memory.
real_projections.append(self.project(projector, ref).cpu(), sample_rate)
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.squeeze(0).cpu(), sample_rate)
gen_projections = torch.cat(gen_projections, dim=0)
real_projections = torch.cat(real_projections, dim=0)
fid = self.compute_frechet_distance(gen_projections, real_projections)
gen_projections = torch.stack(gen_projections, dim=0)
real_projections = torch.stack(real_projections, dim=0)
frechet_distance = self.compute_frechet_distance(gen_projections, real_projections)
if distributed.is_initialized() and distributed.get_world_size() > 1:
fid = distributed.all_reduce(fid) / distributed.get_world_size()
frechet_distance = distributed.all_reduce(frechet_distance) / distributed.get_world_size()
self.model.train()
torch.set_rng_state(rng_state)
return {"fid": fid}
return {"frechet_distance": frechet_distance}
# NOTE(review): this span renders both the removed and the added configuration
# lines of a diff without +/- markers, so `diffusion`, `opt_eval`, and `env`
# are each assigned twice; at runtime the later (tts6 sweep) assignments win,
# but the first `load_model_from_config` call would still load a full model
# before being discarded — confirm intent before running as-is.
if __name__ == '__main__':
    from utils.util import load_model_from_config
    # Old config: tts5_medium checkpoint (superseded by the assignments below).
    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts5_medium.yml', 'generator',
                                       also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts5_medium\\models\\73000_generator_ema.pth').cuda()
    opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 50}
    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 500, 'device': 'cuda'}
    # New config: tts6 sweep baseline checkpoint with an explicit linear schedule.
    diffusion = load_model_from_config('X:\\dlas\\experiments\\sweep_diffusion_tts6\\baseline\\train_diffusion_tts6.yml', 'generator',
                                       also_load_savepoint=False, load_path='X:\\dlas\\experiments\\sweep_diffusion_tts6\\baseline\\models\\102000_generator_ema.pth').cuda()
    opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 50, 'diffusion_schedule': 'linear'}
    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 500, 'device': 'cuda', 'opt': {}}
    eval = AudioDiffusionFid(diffusion, opt_eval, env)
    eval.perform_eval()