add support for the original vocoder to audio_diffusion_fid; also add a new "intelligibility" metric
parent 3e5da71b16
commit c4e4cf91a0
@@ -47,6 +47,10 @@ class AudioDiffusionFid(evaluator.Evaluator):
         mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
         if mode == 'tts':
             self.diffusion_fn = self.perform_diffusion_tts
+        elif mode == 'original_vocoder':
+            self.dvae = load_speech_dvae().to(self.env['device'])
+            self.dvae.eval()
+            self.diffusion_fn = self.perform_original_diffusion_vocoder
         elif mode == 'vocoder':
             self.dvae = load_speech_dvae().to(self.env['device'])
             self.dvae.eval()
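For context on the new dispatch branch: a minimal sketch of how the mode lands on 'original_vocoder', assuming opt_get(opt, keys, default) walks the nested keys and falls back to the default when they are absent (the real helper lives elsewhere in the repo, so this is an approximation of its behaviour):

# Illustrative only; opt_get is assumed to behave like this simplified version.
def opt_get(opt, keys, default=None):
    for k in keys:
        if not isinstance(opt, dict) or k not in opt:
            return default
        opt = opt[k]
    return opt

assert opt_get({'diffusion_type': 'original_vocoder'}, ['diffusion_type'], 'tts') == 'original_vocoder'
assert opt_get({}, ['diffusion_type'], 'tts') == 'tts'  # without the key, the evaluator stays in 'tts' mode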
@@ -67,6 +71,31 @@ class AudioDiffusionFid(evaluator.Evaluator):
                                                         'conditioning_input': real_resampled})
         return gen, real_resampled, sample_rate
 
+    def perform_original_diffusion_vocoder(self, audio, codes, text, sample_rate=11025):
+        mel = wav_to_mel(audio)
+        mel_codes = convert_mel_to_codes(self.dvae, mel)
+        back_to_mel = self.dvae.decode(mel_codes)[0]
+        orig_audio = audio
+        real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
+
+        output_size = real_resampled.shape[-1]
+        aligned_mel_compression_factor = output_size // back_to_mel.shape[-1]
+        padded_size = ceil_multiple(output_size, 2048)
+        padding_added = padded_size - output_size
+        padding_needed_for_codes = padding_added // aligned_mel_compression_factor
+        if padding_needed_for_codes > 0:
+            back_to_mel = F.pad(back_to_mel, (0, padding_needed_for_codes))
+        output_shape = (1, 1, padded_size)
+        gen = self.diffuser.p_sample_loop(self.model, output_shape,
+                                          model_kwargs={'spectrogram': back_to_mel,
+                                                        'conditioning_input': orig_audio.unsqueeze(0)})
+
+        # Pop it back down to 5.5kHz for an accurate comparison with the other diffusers.
+        real_resampled = torchaudio.functional.resample(real_resampled.squeeze(0), sample_rate, 5500).unsqueeze(0)
+        gen = torchaudio.functional.resample(gen.squeeze(0), sample_rate, 5500).unsqueeze(0)
+        return gen, real_resampled, 5500
+
+
     def perform_diffusion_vocoder(self, audio, codes, text, sample_rate=5500):
         mel = wav_to_mel(audio)
         mel_codes = convert_mel_to_codes(self.dvae, mel)
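A minimal sketch of the length-alignment arithmetic in perform_original_diffusion_vocoder, assuming ceil_multiple(x, m) rounds x up to the nearest multiple of m; the mel length and compression factor below are made-up example values, not taken from the diff:

# Illustrative numbers only; the real values come from the dVAE mel and the resampled audio.
def ceil_multiple(x, m):
    return ((x + m - 1) // m) * m

output_size = 25600                                  # samples in real_resampled (example)
mel_len = 100                                        # frames in back_to_mel (example)
factor = output_size // mel_len                      # aligned_mel_compression_factor -> 256
padded_size = ceil_multiple(output_size, 2048)       # 26624
padding_added = padded_size - output_size            # 1024
padding_needed_for_codes = padding_added // factor   # 4 extra mel frames to pad
assert padded_size % 2048 == 0
assert padding_needed_for_codes * factor == padding_added

Padding the mel by that many frames keeps the spectrogram aligned with the 2048-multiple waveform length the diffuser is asked to generate.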
@@ -104,6 +133,25 @@ class AudioDiffusionFid(evaluator.Evaluator):
         mel = wav_to_mel(sample)
         return projector.get_speech_projection(mel).squeeze(0)  # Getting rid of the batch dimension means it's just [hidden_dim]
 
+    def load_w2v(self):
+        return Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli")
+
+    def intelligibility_loss(self, w2v, sample, real_sample, sample_rate, real_text):
+        """
+        Measures the differences between CTC losses using wav2vec2 against the real sample and the generated sample.
+        """
+        text_codes = torch.tensor(text_to_sequence(real_text), device=sample.device)
+        results = []
+        for s in [sample, real_sample]:
+            s = torchaudio.functional.resample(s, sample_rate, 16000)
+            norm_s = (s - s.mean()) / torch.sqrt(s.var() + 1e-7)
+            norm_s = norm_s.squeeze(1)
+            loss = w2v(input_values=norm_s, labels=text_codes).loss
+            results.append(loss)
+        gen_loss, real_loss = results
+        return gen_loss - real_loss
+
+
     def compute_frechet_distance(self, proj1, proj2):
         # I really REALLY FUCKING HATE that this is going to numpy. Why does "pytorch_fid" operate in numpy land. WHY?
         proj1 = proj1.cpu().numpy()
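The intelligibility metric scores how well a CTC ASR model can align the same transcript against the generated and the real audio, and reports the difference of the two losses. A self-contained sketch of the same idea, using the public facebook/wav2vec2-base-960h checkpoint as a stand-in for the fine-tuned model returned by load_w2v(), and the Wav2Vec2 tokenizer in place of text_to_sequence():

# Sketch only: stand-in checkpoint and tokenizer-based labels, not the exact pipeline in the diff.
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

def ctc_intelligibility_delta(gen_wav, real_wav, sample_rate, text):
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    w2v = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").eval()
    labels = processor.tokenizer(text.upper(), return_tensors="pt").input_ids
    losses = []
    for wav in (gen_wav, real_wav):                                  # wav: (1, num_samples)
        wav = torchaudio.functional.resample(wav, sample_rate, 16000)
        wav = (wav - wav.mean()) / torch.sqrt(wav.var() + 1e-7)      # same normalization as the method above
        with torch.no_grad():
            losses.append(w2v(input_values=wav, labels=labels).loss)
    gen_loss, real_loss = losses
    return gen_loss - real_loss   # positive => the generated clip is harder to transcribe than the reference

In the diff itself the label ids come from text_to_sequence() and the model is the LibriTTS/VoxPopuli fine-tuned checkpoint, but the loss-difference structure is the same.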
@@ -123,6 +171,9 @@ class AudioDiffusionFid(evaluator.Evaluator):
         if hasattr(self, 'dvae'):
             self.dvae = self.dvae.to(self.env['device'])
 
+        w2v = self.load_w2v().to(self.env['device'])
+        w2v.eval()
+
         # Attempt to fix the random state as much as possible. RNG state will be restored before returning.
         rng_state = torch.get_rng_state()
         torch.manual_seed(5)
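The RNG handling around this hunk pins the sampler to a fixed seed so successive evaluations are comparable, then hands the original RNG stream back to training. The save/reseed/restore pattern in isolation:

# Save, reseed, and restore the global torch RNG (CPU generator shown).
import torch

rng_state = torch.get_rng_state()     # remember where training left the RNG
torch.manual_seed(5)
a = torch.randn(4)
torch.manual_seed(5)
b = torch.randn(4)
assert torch.equal(a, b)              # draws made during eval are reproducible
torch.set_rng_state(rng_state)        # training resumes from its original RNG state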
@@ -131,6 +182,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
         with torch.no_grad():
             gen_projections = []
             real_projections = []
+            intelligibility_losses = []
             for i in tqdm(list(range(0, len(self.data), self.skip))):
                 path, text, codes = self.data[i + self.env['rank']]
                 audio = load_audio(path, 22050).to(self.dev)
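intelligibility_losses collects one scalar per evaluated clip; the next hunk stacks the list and averages it into a single metric. In miniature, with example values:

import torch

intelligibility_losses = [torch.tensor(2.31), torch.tensor(1.87), torch.tensor(2.05)]  # per-clip CTC deltas (examples)
intelligibility_loss = torch.stack(intelligibility_losses, dim=0).mean()
print(intelligibility_loss)  # tensor(2.0767)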
@@ -139,33 +191,51 @@ class AudioDiffusionFid(evaluator.Evaluator):
 
                 gen_projections.append(self.project(projector, sample, sample_rate).cpu())  # Store on CPU to avoid wasting GPU memory.
                 real_projections.append(self.project(projector, ref, sample_rate).cpu())
+                intelligibility_losses.append(self.intelligibility_loss(w2v, sample, ref, sample_rate, text))
 
                 torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
                 torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.squeeze(0).cpu(), sample_rate)
             gen_projections = torch.stack(gen_projections, dim=0)
             real_projections = torch.stack(real_projections, dim=0)
+            intelligibility_loss = torch.stack(intelligibility_losses, dim=0).mean()
             frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device'])
 
             if distributed.is_initialized() and distributed.get_world_size() > 1:
                 distributed.all_reduce(frechet_distance)
                 frechet_distance = frechet_distance / distributed.get_world_size()
+                distributed.all_reduce(intelligibility_loss)
+                intelligibility_loss = intelligibility_loss / distributed.get_world_size()
 
         self.model.train()
         if hasattr(self, 'dvae'):
             self.dvae = self.dvae.to('cpu')
         torch.set_rng_state(rng_state)
 
-        return {"frechet_distance": frechet_distance}
+        return {"frechet_distance": frechet_distance, "intelligibility_loss": intelligibility_loss}
 
+
+"""
+if __name__ == '__main__':
+    from utils.util import load_model_from_config
+
+    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text.yml', 'generator',
+                                       also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text\\models\\56500_generator_ema.pth').cuda()
+    opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
+                'conditioning_free': False, 'conditioning_free_k': 1,
+                'diffusion_schedule': 'linear', 'diffusion_type': 'vocoder'}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 2, 'device': 'cuda', 'opt': {}}
+    eval = AudioDiffusionFid(diffusion, opt_eval, env)
+    print(eval.perform_eval())
+"""
+
 
 if __name__ == '__main__':
     from utils.util import load_model_from_config
 
-    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text.yml', 'generator',
-                                       also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text\\models\\47500_generator_ema.pth').cuda()
+    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_vocoder_clips_from_dvae_archived_r3_b256_conditioning\\config.yml', 'generator',
+                                       also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_vocoder_clips_from_dvae_archived_r3_b256_conditioning\\models\\80800_generator_ema.pth').cuda()
     opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
-                'conditioning_free': True, 'conditioning_free_k': 1,
-                'diffusion_schedule': 'linear', 'diffusion_type': 'vocoder'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 1, 'device': 'cuda', 'opt': {}}
+                'conditioning_free': False, 'conditioning_free_k': 1,
+                'diffusion_schedule': 'linear', 'diffusion_type': 'original_vocoder'}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 4, 'device': 'cuda', 'opt': {}}
     eval = AudioDiffusionFid(diffusion, opt_eval, env)
-    print(eval.perform_eval())
+    print(eval.perform_eval())
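When the evaluator runs under torch.distributed, both metrics are summed across ranks with all_reduce and then divided by the world size, which turns the sum into a mean; the returned dict then carries the new "intelligibility_loss" key alongside the existing FID. A sketch of that reduction step on its own, with a hypothetical helper name:

import torch
import torch.distributed as distributed

def average_over_ranks(metric: torch.Tensor) -> torch.Tensor:
    # all_reduce defaults to SUM, so dividing by world_size yields the cross-rank mean.
    if distributed.is_initialized() and distributed.get_world_size() > 1:
        distributed.all_reduce(metric)
        metric = metric / distributed.get_world_size()
    return metric

# perform_eval() now returns something shaped like:
# {"frechet_distance": tensor(...), "intelligibility_loss": tensor(...)}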