add mel_norm to std injector
This commit is contained in:
parent
0fc877cbc8
commit
3f244f6a68
|
@ -5,14 +5,14 @@ import yaml
|
|||
from tqdm import tqdm
|
||||
|
||||
from data import create_dataset, create_dataloader
|
||||
from trainer.injectors.base_injectors import TorchMelSpectrogramInjector
|
||||
from scripts.audio.gen.speech_synthesis_utils import wav_to_univnet_mel
|
||||
from utils.options import Loader
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='D:\\dlas\\options\\train_dvae_audio_clips.yml')
|
||||
parser.add_argument('-key', type=str, help='Key where audio data is stored', default='clip')
|
||||
parser.add_argument('-num_batches', type=int, help='Number of batches to collect to compute the norm', default=10)
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='D:\\dlas\\options\\train_diffusion_tts9.yml')
|
||||
parser.add_argument('-key', type=str, help='Key where audio data is stored', default='wav')
|
||||
parser.add_argument('-num_batches', type=int, help='Number of batches to collect to compute the norm', default=50000)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.opt, mode='r') as f:
|
||||
|
@ -21,14 +21,26 @@ if __name__ == '__main__':
|
|||
dopt['phase'] = 'train'
|
||||
dataset, collate = create_dataset(dopt, return_collate=True)
|
||||
dataloader = create_dataloader(dataset, dopt, collate_fn=collate, shuffle=True)
|
||||
inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel'},{}).cuda()
|
||||
|
||||
mels = []
|
||||
mel_means = []
|
||||
mel_max = -999999999
|
||||
mel_min = 999999999
|
||||
mel_stds = []
|
||||
mel_vars = []
|
||||
for batch in tqdm(dataloader):
|
||||
clip = batch[args.key].cuda()
|
||||
mel = inj({'wav': clip})['mel']
|
||||
mels.append(mel.mean((0,2)).cpu())
|
||||
if len(mels) > args.num_batches:
|
||||
if len(mel_means) > args.num_batches:
|
||||
break
|
||||
mel_norms = torch.stack(mels).mean(0)
|
||||
torch.save(mel_norms, 'mel_norms.pth')
|
||||
clip = batch[args.key].cuda()
|
||||
for b in range(clip.shape[0]):
|
||||
wav = clip[b].unsqueeze(0)
|
||||
wav = wav[:, :, :batch[f'{args.key}_lengths'][b]]
|
||||
mel = wav_to_univnet_mel(clip) # Caution: make sure this isn't already normed.
|
||||
mel_means.append(mel.mean((0,2)).cpu())
|
||||
mel_max = max(mel.max().item(), mel_max)
|
||||
mel_min = min(mel.min().item(), mel_min)
|
||||
mel_stds.append(mel.std((0,2)).cpu())
|
||||
mel_vars.append(mel.var((0,2)).cpu())
|
||||
mel_means = torch.stack(mel_means).mean(0)
|
||||
mel_stds = torch.stack(mel_stds).mean(0)
|
||||
mel_vars = torch.stack(mel_vars).mean(0)
|
||||
torch.save((mel_means,mel_max,mel_min,mel_stds,mel_vars), 'univnet_mel_norms.pth')
|
|
@ -21,6 +21,14 @@ class MelSpectrogramInjector(Injector):
|
|||
mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
|
||||
sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
|
||||
self.stft = TacotronSTFT(filter_length, hop_length, win_length, n_mel_channels, sampling_rate, mel_fmin, mel_fmax)
|
||||
self.mel_norm_file = opt_get(opt, ['mel_norm_file'], None)
|
||||
if self.mel_norm_file is not None:
|
||||
# Note that this format is different from TorchMelSpectrogramInjector.
|
||||
mel_means, mel_max, mel_min, mel_stds, mel_vars = torch.load(self.mel_norm_file)
|
||||
self.mel_max = mel_max
|
||||
self.mel_min = mel_min
|
||||
else:
|
||||
self.mel_norms = None
|
||||
|
||||
def forward(self, state):
|
||||
inp = state[self.input]
|
||||
|
@ -28,7 +36,10 @@ class MelSpectrogramInjector(Injector):
|
|||
inp = inp.squeeze(1)
|
||||
assert len(inp.shape) == 2
|
||||
self.stft = self.stft.to(inp.device)
|
||||
return {self.output: self.stft.mel_spectrogram(inp)}
|
||||
mel = self.stft.mel_spectrogram(inp)
|
||||
if self.mel_norms is not None:
|
||||
mel = 2 * ((mel - self.mel_min) / (self.mel_max - self.mel_min)) - 1
|
||||
return {self.output: mel}
|
||||
|
||||
|
||||
class TorchMelSpectrogramInjector(Injector):
|
||||
|
|
Loading…
Reference in New Issue
Block a user