2021-12-11 15:16:50 +00:00
|
|
|
import argparse
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import yaml
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
2023-03-21 15:39:28 +00:00
|
|
|
from dlas.data import create_dataloader, create_dataset
|
|
|
|
from dlas.scripts.audio.gen.speech_synthesis_utils import wav_to_univnet_mel
|
|
|
|
from dlas.utils.options import Loader
|
2021-12-11 15:16:50 +00:00
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: walks the training dataset described by an options
    # YAML file, converts each audio clip to a univnet mel spectrogram, and
    # saves aggregate normalization statistics to 'univnet_mel_norms.pth'.
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model',
                        default='D:\\dlas\\options\\train_diffusion_tts9.yml')
    parser.add_argument(
        '-key', type=str, help='Key where audio data is stored', default='wav')
    parser.add_argument('-num_batches', type=int,
                        help='Number of batches to collect to compute the norm', default=50000)
    args = parser.parse_args()

    # NOTE(review): yaml.load with a project-defined Loader can construct
    # arbitrary objects depending on how that Loader is defined; only point
    # -opt at trusted, locally authored option files.
    with open(args.opt, mode='r') as f:
        opt = yaml.load(f, Loader=Loader)

    # Build the training dataset/dataloader exactly as the options describe.
    dopt = opt['datasets']['train']
    dopt['phase'] = 'train'
    dataset, collate = create_dataset(dopt, return_collate=True)
    dataloader = create_dataloader(
        dataset, dopt, collate_fn=collate, shuffle=True)

    # One entry per processed clip for the list accumulators; the global
    # extrema are tracked as plain Python floats.
    mel_means = []
    mel_stds = []
    mel_vars = []
    mel_max = float('-inf')
    mel_min = float('inf')

    for batch in tqdm(dataloader):
        # NOTE(review): the accumulators grow once per clip, not once per
        # batch, so this cap is effectively measured in clips and is only
        # checked between batches. Left as-is to preserve the run length.
        if len(mel_means) > args.num_batches:
            break

        clip = batch[args.key].cuda()
        lengths = batch[f'{args.key}_lengths']
        for b in range(clip.shape[0]):
            # Isolate sample b (restoring a leading batch dim) and trim it to
            # its recorded length so collate padding does not skew the stats.
            wav = clip[b].unsqueeze(0)
            wav = wav[:, :, :lengths[b]]

            # Caution: make sure this isn't already normed.
            # Use the trimmed per-sample `wav`; passing the whole padded
            # `clip` here would repeat identical full-batch statistics for
            # every sample and include the padding.
            mel = wav_to_univnet_mel(wav)

            # Reduce over dims 0 and 2, keeping dim 1
            # (presumably the mel-channel axis - TODO confirm).
            mel_means.append(mel.mean((0, 2)).cpu())
            mel_stds.append(mel.std((0, 2)).cpu())
            mel_vars.append(mel.var((0, 2)).cpu())
            mel_max = max(mel.max().item(), mel_max)
            mel_min = min(mel.min().item(), mel_min)

    if not mel_means:
        # torch.stack would fail on an empty list with a less helpful message.
        raise RuntimeError(
            'No clips were collected from the dataloader; cannot compute mel norms.')

    # Average the per-clip statistics across all collected clips.
    mel_means = torch.stack(mel_means).mean(0)
    mel_stds = torch.stack(mel_stds).mean(0)
    mel_vars = torch.stack(mel_vars).mean(0)

    # Saved tuple order: (means, max, min, stds, vars).
    torch.save((mel_means, mel_max, mel_min, mel_stds,
                mel_vars), 'univnet_mel_norms.pth')
|