forked from mrq/DL-Art-School
audio diffusion frechet distance measurement!
This commit is contained in:
parent
23a310b488
commit
d1d1ae32a1
|
@ -9,14 +9,8 @@ from data.util import find_files_of_type, is_audio_file
|
||||||
from scripts.audio.gen.speech_synthesis_utils import do_spectrogram_diffusion, \
|
from scripts.audio.gen.speech_synthesis_utils import do_spectrogram_diffusion, \
|
||||||
load_discrete_vocoder_diffuser, wav_to_mel, convert_mel_to_codes
|
load_discrete_vocoder_diffuser, wav_to_mel, convert_mel_to_codes
|
||||||
from utils.audio import plot_spectrogram
|
from utils.audio import plot_spectrogram
|
||||||
from utils.util import load_model_from_config
|
from utils.util import load_model_from_config, ceil_multiple
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
def ceil_multiple(base, multiple):
|
|
||||||
res = base % multiple
|
|
||||||
if res == 0:
|
|
||||||
return base
|
|
||||||
return base + (multiple - res)
|
|
||||||
|
|
||||||
|
|
||||||
def get_ctc_codes_for(src_clip_path):
|
def get_ctc_codes_for(src_clip_path):
|
||||||
|
@ -34,6 +28,18 @@ def get_ctc_codes_for(src_clip_path):
|
||||||
return torch.argmax(logits, dim=-1), clip
|
return torch.argmax(logits, dim=-1), clip
|
||||||
|
|
||||||
|
|
||||||
|
def determine_output_size(codes, base_sample_rate)
|
||||||
|
aligned_codes_compression_factor = base_sample_rate * 221 // 11025
|
||||||
|
output_size = codes.shape[-1]*aligned_codes_compression_factor
|
||||||
|
padded_size = ceil_multiple(output_size, 2048)
|
||||||
|
padding_added = padded_size - output_size
|
||||||
|
padding_needed_for_codes = padding_added // aligned_codes_compression_factor
|
||||||
|
if padding_needed_for_codes > 0:
|
||||||
|
codes = F.pad(codes, (0, padding_needed_for_codes))
|
||||||
|
output_shape = (1, 1, padded_size)
|
||||||
|
return output_shape, codes
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
provided_voices = {
|
provided_voices = {
|
||||||
# Male
|
# Male
|
||||||
|
@ -70,7 +76,6 @@ if __name__ == '__main__':
|
||||||
diffusion = load_model_from_config(args.opt, args.diffusion_model_name, also_load_savepoint=False,
|
diffusion = load_model_from_config(args.opt, args.diffusion_model_name, also_load_savepoint=False,
|
||||||
load_path=args.diffusion_model_path, device='cpu').eval()
|
load_path=args.diffusion_model_path, device='cpu').eval()
|
||||||
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=args.diffusion_steps, schedule='cosine')
|
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=args.diffusion_steps, schedule='cosine')
|
||||||
aligned_codes_compression_factor = base_sample_rate * 221 // 11025
|
|
||||||
sr_diffusion = load_model_from_config(args.sr_opt, args.sr_diffusion_model_name, also_load_savepoint=False,
|
sr_diffusion = load_model_from_config(args.sr_opt, args.sr_diffusion_model_name, also_load_savepoint=False,
|
||||||
load_path=args.sr_diffusion_model_path, device='cpu').eval()
|
load_path=args.sr_diffusion_model_path, device='cpu').eval()
|
||||||
sr_diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=args.diffusion_steps, schedule='linear')
|
sr_diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=args.diffusion_steps, schedule='linear')
|
||||||
|
@ -90,7 +95,7 @@ if __name__ == '__main__':
|
||||||
torchaudio.save(os.path.join(args.output_path, f'{e}_source_clip.wav'), src_clip.unsqueeze(0).cpu(), 16000)
|
torchaudio.save(os.path.join(args.output_path, f'{e}_source_clip.wav'), src_clip.unsqueeze(0).cpu(), 16000)
|
||||||
|
|
||||||
print("Performing initial diffusion..")
|
print("Performing initial diffusion..")
|
||||||
output_shape = (1, 1, ceil_multiple(aligned_codes.shape[-1]*aligned_codes_compression_factor, 2048))
|
output_shape, aligned_codes = determine_output_size(aligned_codes, base_sample_rate)
|
||||||
diffusion = diffusion.cuda()
|
diffusion = diffusion.cuda()
|
||||||
output_base = diffuser.p_sample_loop(diffusion, output_shape, noise=torch.zeros(output_shape, device=args.device),
|
output_base = diffuser.p_sample_loop(diffusion, output_shape, noise=torch.zeros(output_shape, device=args.device),
|
||||||
model_kwargs={'tokens': aligned_codes,
|
model_kwargs={'tokens': aligned_codes,
|
||||||
|
|
115
codes/trainer/eval/audio_diffusion_fid.py
Normal file
115
codes/trainer/eval/audio_diffusion_fid.py
Normal file
|
@ -0,0 +1,115 @@
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
import torchvision
|
||||||
|
from pytorch_fid import fid_score
|
||||||
|
from pytorch_fid.fid_score import calculate_frechet_distance
|
||||||
|
from torch import distributed
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import trainer.eval.evaluator as evaluator
|
||||||
|
from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes
|
||||||
|
from data.audio.unsupervised_audio_dataset import load_audio
|
||||||
|
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser
|
||||||
|
|
||||||
|
|
||||||
|
class StyleTransferEvaluator(evaluator.Evaluator):
|
||||||
|
"""
|
||||||
|
Evaluator produces generate from a diffusion model, then uses a pretrained wav2vec model to compute a frechet
|
||||||
|
distance between real and fake samples.
|
||||||
|
"""
|
||||||
|
def __init__(self, model, opt_eval, env):
|
||||||
|
super().__init__(model, opt_eval, env, uses_all_ddp=True)
|
||||||
|
self.real_path = opt_eval['eval_tsv']
|
||||||
|
self.data = load_tsv_aligned_codes(self.real_path)
|
||||||
|
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||||
|
self.skip = distributed.get_world_size() # One batch element per GPU.
|
||||||
|
else:
|
||||||
|
self.skip = 1
|
||||||
|
diffusion_steps = opt_get(opt_eval, ['diffusion_steps'], 50)
|
||||||
|
diffusion_schedule = opt_get(opt_eval, ['diffusion_schedule'], 'cosine')
|
||||||
|
self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule)
|
||||||
|
self.dev = self.env['device']
|
||||||
|
|
||||||
|
def perform_diffusion(self, audio, codes, sample_rate=5500):
|
||||||
|
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
||||||
|
aligned_codes_compression_factor = sample_rate * 221 // 11025
|
||||||
|
output_size = codes.shape[-1]*aligned_codes_compression_factor
|
||||||
|
padded_size = ceil_multiple(output_size, 2048)
|
||||||
|
padding_added = padded_size - output_size
|
||||||
|
padding_needed_for_codes = padding_added // aligned_codes_compression_factor
|
||||||
|
if padding_needed_for_codes > 0:
|
||||||
|
codes = F.pad(codes, (0, padding_needed_for_codes))
|
||||||
|
output_shape = (1, 1, padded_size)
|
||||||
|
gen = self.diffuser.p_sample_loop(self.model, output_shape,
|
||||||
|
model_kwargs={'tokens': codes.unsqueeze(0),
|
||||||
|
'conditioning_input': real_resampled})
|
||||||
|
return gen, real_resampled, sample_rate
|
||||||
|
|
||||||
|
def project(self, projector, sample, sample_rate):
|
||||||
|
sample = torchaudio.functional.resample(sample, sample_rate, 16000)
|
||||||
|
sample = (sample - sample.mean()) / torch.sqrt(sample.var() + 1e-7)
|
||||||
|
return projector(sample.squeeze(1), output_hidden_states=True).hidden_states[-1].squeeze(0) # Getting rid of the batch dimension means it's just [seq_len,hidden_states]
|
||||||
|
|
||||||
|
def compute_frechet_distance(self, proj1, proj2):
|
||||||
|
# I really REALLY FUCKING HATE that this is going to numpy. Why does "pytorch_fid" operate in numpy land. WHY?
|
||||||
|
proj1 = proj1.cpu().numpy()
|
||||||
|
proj2 = proj2.cpu().numpy()
|
||||||
|
mu1 = np.mean(proj1, axis=0)
|
||||||
|
mu2 = np.mean(proj2, axis=0)
|
||||||
|
sigma1 = np.cov(proj1, rowvar=False)
|
||||||
|
sigma2 = np.cov(proj2, rowvar=False)
|
||||||
|
return torch.tensor(calculate_frechet_distance(mu1, sigma1, mu2, sigma2))
|
||||||
|
|
||||||
|
def perform_eval(self):
|
||||||
|
save_path = osp.join(self.env['base_path'], "../", "audio_eval", str(self.env["step"]))
|
||||||
|
os.makedirs(save_path, exist_ok=True)
|
||||||
|
|
||||||
|
projector = Wav2Vec2ForCTC.from_pretrained(f"facebook/wav2vec2-large").to(self.dev)
|
||||||
|
projector.eval()
|
||||||
|
|
||||||
|
# Attempt to fix the random state as much as possible. RNG state will be restored before returning.
|
||||||
|
rng_state = torch.get_rng_state()
|
||||||
|
torch.manual_seed(5)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
gen_projections = []
|
||||||
|
real_projections = []
|
||||||
|
for i in tqdm(list(range(0, len(self.data), self.skip))):
|
||||||
|
path, text, codes = self.data[i + self.env['rank']]
|
||||||
|
audio = load_audio(path, 22050).to(self.dev)
|
||||||
|
codes = codes.to(self.dev)
|
||||||
|
sample, ref, sample_rate = self.perform_diffusion(audio, codes)
|
||||||
|
|
||||||
|
gen_projections.append(self.project(projector, sample, sample_rate).cpu()) # Store on CPU to avoid wasting GPU memory.
|
||||||
|
real_projections.append(self.project(projector, ref, sample_rate).cpu())
|
||||||
|
|
||||||
|
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
|
||||||
|
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.squeeze(0).cpu(), sample_rate)
|
||||||
|
gen_projections = torch.cat(gen_projections, dim=0)
|
||||||
|
real_projections = torch.cat(real_projections, dim=0)
|
||||||
|
fid = self.compute_frechet_distance(gen_projections, real_projections)
|
||||||
|
|
||||||
|
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||||
|
fid = distributed.all_reduce(fid) / distributed.get_world_size()
|
||||||
|
|
||||||
|
self.model.train()
|
||||||
|
torch.set_rng_state(rng_state)
|
||||||
|
|
||||||
|
return {"fid": fid}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from utils.util import load_model_from_config, ceil_multiple, opt_get
|
||||||
|
|
||||||
|
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts5_medium.yml', 'generator',
|
||||||
|
also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts5_medium\\models\\73000_generator_ema.pth').cuda()
|
||||||
|
opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 50}
|
||||||
|
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 500, 'device': 'cuda'}
|
||||||
|
eval = StyleTransferEvaluator(diffusion, opt_eval, env)
|
||||||
|
eval.perform_eval()
|
|
@ -498,3 +498,13 @@ def map_cuda_to_correct_device(storage, loc):
|
||||||
return storage.cuda(torch.cuda.current_device())
|
return storage.cuda(torch.cuda.current_device())
|
||||||
else:
|
else:
|
||||||
return storage.cpu()
|
return storage.cpu()
|
||||||
|
|
||||||
|
|
||||||
|
def ceil_multiple(base, multiple):
|
||||||
|
"""
|
||||||
|
Returns the next closest multiple >= base.
|
||||||
|
"""
|
||||||
|
res = base % multiple
|
||||||
|
if res == 0:
|
||||||
|
return base
|
||||||
|
return base + (multiple - res)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user