add mel_norm to std injector

This commit is contained in:
James Betker 2022-03-15 22:16:59 -06:00
parent 0fc877cbc8
commit 3f244f6a68
2 changed files with 36 additions and 13 deletions

View File

@ -5,14 +5,14 @@ import yaml
from tqdm import tqdm
from data import create_dataset, create_dataloader
from trainer.injectors.base_injectors import TorchMelSpectrogramInjector
from scripts.audio.gen.speech_synthesis_utils import wav_to_univnet_mel
from utils.options import Loader
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='D:\\dlas\\options\\train_dvae_audio_clips.yml')
parser.add_argument('-key', type=str, help='Key where audio data is stored', default='clip')
parser.add_argument('-num_batches', type=int, help='Number of batches to collect to compute the norm', default=10)
parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='D:\\dlas\\options\\train_diffusion_tts9.yml')
parser.add_argument('-key', type=str, help='Key where audio data is stored', default='wav')
parser.add_argument('-num_batches', type=int, help='Number of batches to collect to compute the norm', default=50000)
args = parser.parse_args()
with open(args.opt, mode='r') as f:
@ -21,14 +21,26 @@ if __name__ == '__main__':
dopt['phase'] = 'train'
dataset, collate = create_dataset(dopt, return_collate=True)
dataloader = create_dataloader(dataset, dopt, collate_fn=collate, shuffle=True)
inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel'},{}).cuda()
mels = []
mel_means = []
mel_max = -999999999
mel_min = 999999999
mel_stds = []
mel_vars = []
for batch in tqdm(dataloader):
clip = batch[args.key].cuda()
mel = inj({'wav': clip})['mel']
mels.append(mel.mean((0,2)).cpu())
if len(mels) > args.num_batches:
if len(mel_means) > args.num_batches:
break
mel_norms = torch.stack(mels).mean(0)
torch.save(mel_norms, 'mel_norms.pth')
clip = batch[args.key].cuda()
for b in range(clip.shape[0]):
wav = clip[b].unsqueeze(0)
wav = wav[:, :, :batch[f'{args.key}_lengths'][b]]
mel = wav_to_univnet_mel(clip) # Caution: make sure this isn't already normed.
mel_means.append(mel.mean((0,2)).cpu())
mel_max = max(mel.max().item(), mel_max)
mel_min = min(mel.min().item(), mel_min)
mel_stds.append(mel.std((0,2)).cpu())
mel_vars.append(mel.var((0,2)).cpu())
mel_means = torch.stack(mel_means).mean(0)
mel_stds = torch.stack(mel_stds).mean(0)
mel_vars = torch.stack(mel_vars).mean(0)
torch.save((mel_means,mel_max,mel_min,mel_stds,mel_vars), 'univnet_mel_norms.pth')

View File

@ -21,6 +21,14 @@ class MelSpectrogramInjector(Injector):
mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
self.stft = TacotronSTFT(filter_length, hop_length, win_length, n_mel_channels, sampling_rate, mel_fmin, mel_fmax)
self.mel_norm_file = opt_get(opt, ['mel_norm_file'], None)
if self.mel_norm_file is not None:
# Note that this format is different from TorchMelSpectrogramInjector.
mel_means, mel_max, mel_min, mel_stds, mel_vars = torch.load(self.mel_norm_file)
self.mel_max = mel_max
self.mel_min = mel_min
else:
self.mel_norms = None
def forward(self, state):
inp = state[self.input]
@ -28,7 +36,10 @@ class MelSpectrogramInjector(Injector):
inp = inp.squeeze(1)
assert len(inp.shape) == 2
self.stft = self.stft.to(inp.device)
return {self.output: self.stft.mel_spectrogram(inp)}
mel = self.stft.mel_spectrogram(inp)
if self.mel_norms is not None:
mel = 2 * ((mel - self.mel_min) / (self.mel_max - self.mel_min)) - 1
return {self.output: mel}
class TorchMelSpectrogramInjector(Injector):