# 2021-12-11 15:16:50 +00:00
import argparse
import torch
import yaml
from tqdm import tqdm
from data import create_dataset , create_dataloader
# 2022-03-16 04:16:59 +00:00
from scripts . audio . gen . speech_synthesis_utils import wav_to_univnet_mel
# 2021-12-11 15:16:50 +00:00
from utils . options import Loader
def trim_clips(clip, lengths):
    """Yield every sample of a padded batch with its trailing padding removed.

    Args:
        clip: Batched 3-D audio tensor; it is indexed as ``clip[b]`` and then
            sliced along the last axis. Presumably laid out as
            (batch, channels, samples) -- TODO confirm against the dataset.
        lengths: Per-sample valid lengths along the last axis of ``clip``.

    Yields:
        ``clip[b]`` with a leading batch dimension of 1, truncated to
        ``lengths[b]`` along its last axis.
    """
    for sample, length in zip(clip, lengths):
        yield sample.unsqueeze(0)[:, :, :int(length)]


def compute_mel_norms(mels):
    """Fold mel spectrograms into the normalization statistics that get saved.

    Mean, std and variance are computed per clip over dims 0 and 2 (leaving
    dim 1, presumably the mel-channel axis -- TODO confirm) and then averaged
    across clips. Max and min are global scalars over every value seen.

    Args:
        mels: Iterable of 3-D mel tensors.

    Returns:
        Tuple ``(means, max, min, stds, vars)``; ``means``/``stds``/``vars``
        are CPU tensors, ``max``/``min`` are Python floats.

    Raises:
        ValueError: If ``mels`` yields nothing, so there is nothing to average.
    """
    means, stds, variances = [], [], []
    highest, lowest = float('-inf'), float('inf')
    for mel in mels:
        means.append(mel.mean((0, 2)).cpu())
        highest = max(mel.max().item(), highest)
        lowest = min(mel.min().item(), lowest)
        stds.append(mel.std((0, 2)).cpu())
        variances.append(mel.var((0, 2)).cpu())
    if not means:
        raise ValueError('No mel spectrograms were collected; cannot compute norms.')
    return (
        torch.stack(means).mean(0),
        highest,
        lowest,
        torch.stack(stds).mean(0),
        torch.stack(variances).mean(0),
    )


def iter_clip_mels(dataloader, key, num_batches):
    """Yield one mel spectrogram per (unpadded) clip drawn from ``dataloader``.

    Args:
        dataloader: Iterable of batch dicts holding the audio under ``key`` and
            the per-sample lengths under ``f'{key}_lengths'``.
        key: Batch key where the audio data is stored.
        num_batches: Collection budget taken from the ``-num_batches`` flag.

    Yields:
        The result of ``wav_to_univnet_mel`` for each trimmed clip.
    """
    collected = 0
    for batch in tqdm(dataloader):
        # NOTE(review): as in the original script, this compares the number of
        # per-clip entries collected so far (not the number of batches), and
        # only between batches. Kept as-is; confirm whether `-num_batches` is
        # really meant to count batches.
        if collected > num_batches:
            break
        clip = batch[key].cuda()
        for wav in trim_clips(clip, batch[f'{key}_lengths']):
            # Use the trimmed clip, not the whole padded batch, so padding does
            # not leak into the statistics.
            mel = wav_to_univnet_mel(wav)  # Caution: make sure this isn't already normed.
            collected += 1
            yield mel


def main():
    """Compute UnivNet mel normalization stats over the training set and save them."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='D:\\dlas\\options\\train_diffusion_tts9.yml')
    parser.add_argument('-key', type=str, help='Key where audio data is stored', default='wav')
    parser.add_argument('-num_batches', type=int, help='Number of batches to collect to compute the norm', default=50000)
    args = parser.parse_args()

    # The options file is a local training config loaded with the project's own
    # YAML Loader; do not point this at untrusted input.
    with open(args.opt, mode='r') as f:
        opt = yaml.load(f, Loader=Loader)
    dopt = opt['datasets']['train']
    dopt['phase'] = 'train'
    dataset, collate = create_dataset(dopt, return_collate=True)
    dataloader = create_dataloader(dataset, dopt, collate_fn=collate, shuffle=True)

    norms = compute_mel_norms(iter_clip_mels(dataloader, args.key, args.num_batches))
    # Saved as the tuple (mel_means, mel_max, mel_min, mel_stds, mel_vars).
    torch.save(norms, 'univnet_mel_norms.pth')


if __name__ == '__main__':
    main()