Reformat mel_text_clip for use in eval
This commit is contained in:
parent
bcba65c539
commit
7b12799370
|
@ -62,6 +62,26 @@ class MelTextCLIP(nn.Module):
|
|||
self.voice_mask_percentage = voice_mask_percentage
|
||||
self.mel_compression = mel_compression
|
||||
|
||||
def get_text_projections(self, text, text_mask=None):
|
||||
if text_mask is None:
|
||||
text_mask = torch.ones_like(text.float()).bool()
|
||||
text_emb = self.text_emb(text)
|
||||
text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=text.device))
|
||||
with torch.autocast(text.device.type):
|
||||
enc_text = self.text_transformer(text_emb, mask=text_mask)
|
||||
text_latents = masked_mean(enc_text, text_mask, dim=1)
|
||||
return self.to_text_latent(text_latents).float()
|
||||
|
||||
def get_speech_projection(self, mel, voice_mask=None):
|
||||
if voice_mask is None:
|
||||
voice_mask = torch.ones_like(mel[:,0,:].float()).bool()
|
||||
speech_emb = self.speech_enc(mel).permute(0,2,1)
|
||||
speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=mel.device))
|
||||
with torch.autocast(speech_emb.device.type):
|
||||
enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
|
||||
speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
|
||||
return self.to_speech_latent(speech_latents).float()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text,
|
||||
|
@ -82,25 +102,11 @@ class MelTextCLIP(nn.Module):
|
|||
text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
|
||||
voice_mask = torch.rand_like(mel[:,0,:].float()) > self.voice_mask_percentage
|
||||
else:
|
||||
text_mask = torch.ones_like(text.float()).bool()
|
||||
voice_mask = torch.ones_like(mel[:,0,:].float()).bool()
|
||||
text_mask = None
|
||||
voice_mask = None
|
||||
|
||||
text_emb = self.text_emb(text)
|
||||
text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
|
||||
|
||||
speech_emb = self.speech_enc(mel).permute(0,2,1)
|
||||
speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
|
||||
|
||||
# Only autocast the transformer part. The MEL encoder loses accuracy if you autcast it.
|
||||
with torch.autocast(speech_emb.device.type):
|
||||
enc_text = self.text_transformer(text_emb, mask=text_mask)
|
||||
enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
|
||||
|
||||
text_latents = masked_mean(enc_text, text_mask, dim=1)
|
||||
speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
|
||||
|
||||
text_latents = self.to_text_latent(text_latents).float()
|
||||
speech_latents = self.to_speech_latent(speech_latents).float()
|
||||
text_latents = self.get_text_projections(text, text_mask)
|
||||
speech_latents = self.get_speech_projection(mel, voice_mask)
|
||||
|
||||
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
|
||||
|
||||
|
@ -116,6 +122,7 @@ class MelTextCLIP(nn.Module):
|
|||
return loss
|
||||
|
||||
|
||||
|
||||
@register_model
|
||||
def register_mel_text_clip(opt_net, opt):
|
||||
return MelTextCLIP(**opt_get(opt_net, ['kwargs'], {}))
|
||||
|
|
|
@ -14,7 +14,8 @@ import numpy as np
|
|||
import trainer.eval.evaluator as evaluator
|
||||
from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes
|
||||
from data.audio.unsupervised_audio_dataset import load_audio
|
||||
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser
|
||||
from models.clip.mel_text_clip import MelTextCLIP
|
||||
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel
|
||||
from utils.util import ceil_multiple, opt_get
|
||||
|
||||
|
||||
|
@ -54,10 +55,22 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
'conditioning_input': real_resampled})
|
||||
return gen, real_resampled, sample_rate
|
||||
|
||||
def load_projector(self):
|
||||
"""
|
||||
Builds the CLIP model used to project speech into a latent. This model has fixed parameters and a fixed loading
|
||||
path for the time being.
|
||||
"""
|
||||
model = MelTextCLIP(dim_text=512, dim_latent=512, dim_speech=512, num_text_tokens=148, text_enc_depth=8,
|
||||
text_seq_len=400, text_heads=8, speech_enc_depth=10, speech_heads=8, speech_seq_len=1000,
|
||||
text_mask_percentage=.15, voice_mask_percentage=.15)
|
||||
weights = torch.load('../experiments/clip_text_to_voice_for_speech_fid.pth')
|
||||
model.load_state_dict(weights)
|
||||
return model
|
||||
|
||||
def project(self, projector, sample, sample_rate):
|
||||
sample = torchaudio.functional.resample(sample, sample_rate, 16000)
|
||||
sample = (sample - sample.mean()) / torch.sqrt(sample.var() + 1e-7)
|
||||
return projector(sample.squeeze(1), output_hidden_states=True).hidden_states[-1].squeeze(0) # Getting rid of the batch dimension means it's just [seq_len,hidden_states]
|
||||
sample = torchaudio.resample(sample, sample_rate, 22050)
|
||||
mel = wav_to_mel(sample)
|
||||
return projector.get_speech_projection(mel).squeeze(0) # Getting rid of the batch dimension means it's just [hidden_dim]
|
||||
|
||||
def compute_frechet_distance(self, proj1, proj2):
|
||||
# I really REALLY FUCKING HATE that this is going to numpy. Why does "pytorch_fid" operate in numpy land. WHY?
|
||||
|
@ -73,7 +86,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
save_path = osp.join(self.env['base_path'], "../", "audio_eval", str(self.env["step"]))
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
projector = Wav2Vec2ForCTC.from_pretrained(f"facebook/wav2vec2-large").to(self.dev)
|
||||
projector = self.load_projector().to(self.env['device'])
|
||||
projector.eval()
|
||||
|
||||
# Attempt to fix the random state as much as possible. RNG state will be restored before returning.
|
||||
|
@ -90,30 +103,30 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
codes = codes.to(self.dev)
|
||||
sample, ref, sample_rate = self.perform_diffusion(audio, codes)
|
||||
|
||||
gen_projections.append(self.project(projector, sample, sample_rate).cpu()) # Store on CPU to avoid wasting GPU memory.
|
||||
real_projections.append(self.project(projector, ref, sample_rate).cpu())
|
||||
gen_projections.append(self.project(projector, sample).cpu(), sample_rate) # Store on CPU to avoid wasting GPU memory.
|
||||
real_projections.append(self.project(projector, ref).cpu(), sample_rate)
|
||||
|
||||
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
|
||||
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.squeeze(0).cpu(), sample_rate)
|
||||
gen_projections = torch.cat(gen_projections, dim=0)
|
||||
real_projections = torch.cat(real_projections, dim=0)
|
||||
fid = self.compute_frechet_distance(gen_projections, real_projections)
|
||||
gen_projections = torch.stack(gen_projections, dim=0)
|
||||
real_projections = torch.stack(real_projections, dim=0)
|
||||
frechet_distance = self.compute_frechet_distance(gen_projections, real_projections)
|
||||
|
||||
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||
fid = distributed.all_reduce(fid) / distributed.get_world_size()
|
||||
frechet_distance = distributed.all_reduce(frechet_distance) / distributed.get_world_size()
|
||||
|
||||
self.model.train()
|
||||
torch.set_rng_state(rng_state)
|
||||
|
||||
return {"fid": fid}
|
||||
return {"frechet_distance": frechet_distance}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from utils.util import load_model_from_config
|
||||
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts5_medium.yml', 'generator',
|
||||
also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts5_medium\\models\\73000_generator_ema.pth').cuda()
|
||||
opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 50}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 500, 'device': 'cuda'}
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\sweep_diffusion_tts6\\baseline\\train_diffusion_tts6.yml', 'generator',
|
||||
also_load_savepoint=False, load_path='X:\\dlas\\experiments\\sweep_diffusion_tts6\\baseline\\models\\102000_generator_ema.pth').cuda()
|
||||
opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 50, 'diffusion_schedule': 'linear'}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 500, 'device': 'cuda', 'opt': {}}
|
||||
eval = AudioDiffusionFid(diffusion, opt_eval, env)
|
||||
eval.perform_eval()
|
Loading…
Reference in New Issue
Block a user