Two more tools to test the audio segmentor

This commit is contained in:
James Betker 2021-08-17 09:09:11 -06:00
parent 7c086d0c2c
commit 8332923f5c
5 changed files with 148 additions and 12 deletions

View File

@ -219,7 +219,7 @@ class TextMelCollate():
def save_mel_buffer_to_file(mel, path):
np.save(path, mel.numpy())
np.save(path, mel.cpu().numpy())
def dump_mels_to_disk():

View File

@ -64,24 +64,26 @@ class GptSegmentor(nn.Module):
self.final_norm = nn.LayerNorm(model_dim)
self.stop_head = nn.Linear(model_dim, 1)
def forward(self, mel_inputs, termination_points):
def forward(self, mel_inputs, termination_points=None):
mel_emb = self.mel_encoder(mel_inputs)
mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
enc = self.gpt(mel_emb)
stop_logits = self.final_norm(enc)
stop_logits = self.stop_head(stop_logits)
if termination_points is not None:
# The MEL gets decimated to 1/4 the size by the encoder, so we need to do the same to the termination points.
termination_points = F.interpolate(termination_points.unsqueeze(1), size=mel_emb.shape[1], mode='area').squeeze()
termination_points = (termination_points > 0).float()
# Compute loss
b, s, _ = enc.shape
stop_logits = self.final_norm(enc)
stop_logits = self.stop_head(stop_logits)
loss = F.binary_cross_entropy_with_logits(stop_logits.squeeze(-1), termination_points)
return loss.mean()
else:
return stop_logits
@register_model

View File

@ -0,0 +1,106 @@
import os.path as osp
import logging
import random
import argparse
import audio2numpy
import torchvision
from munch import munchify
import utils
import utils.options as option
import utils.util as util
from data.audio.nv_tacotron_dataset import save_mel_buffer_to_file
from models.tacotron2 import hparams
from models.tacotron2.layers import TacotronSTFT
from models.tacotron2.text import sequence_to_text
from scripts.audio.use_vocoder import Vocoder
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch
import numpy as np
from scipy.io import wavfile
def forward_pass(model, data, output_dir, opt, b):
with torch.no_grad():
model.feed_data(data, 0)
model.test()
if 'real_text' in opt['eval'].keys():
real = data[opt['eval']['real_text']][0]
print(f'{b} Real text: "{real}"')
pred_seq = model.eval_state[opt['eval']['gen_text']][0]
pred_text = [sequence_to_text(ts) for ts in pred_seq]
audio = model.eval_state[opt['eval']['audio']][0].cpu().numpy()
wavfile.write(osp.join(output_dir, f'{b}_clip.wav'), 22050, audio)
for i, text in enumerate(pred_text):
print(f'{b} Predicted text {i}: "{text}"')
if __name__ == "__main__":
input_file = "E:\\audio\\books\\Roald Dahl Audiobooks\\Roald Dahl - The BFG\\(Roald Dahl) The BFG - 07.mp3"
config = "../options/train_gpt_stop_libritts.yml"
cutoff_pred_percent = .2
# Set seeds
torch.manual_seed(5555)
random.seed(5555)
np.random.seed(5555)
#### options
torch.backends.cudnn.benchmark = True
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default=config)
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt
hp = munchify(hparams.create_hparams())
util.mkdirs(
(path for key, path in opt['path'].items()
if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
screen=True, tofile=True)
logger = logging.getLogger('base')
logger.info(option.dict2str(opt))
model = ExtensibleTrainer(opt)
assert len(model.networks) == 1
model = model.networks[next(iter(model.networks.keys()))].module.to('cuda')
model.eval()
vocoder = Vocoder()
audio, sr = audio2numpy.audio_from_file(input_file)
if len(audio.shape) == 2:
audio = audio[:, 0]
audio = torch.tensor(audio, device='cuda').unsqueeze(0).unsqueeze(0)
audio = torch.nn.functional.interpolate(audio, scale_factor=hp.sampling_rate/sr, mode='nearest').squeeze(1)
stft = TacotronSTFT(hp.filter_length, hp.hop_length, hp.win_length, hp.n_mel_channels, hp.sampling_rate, hp.mel_fmin, hp.mel_fmax).to('cuda')
mels = stft.mel_spectrogram(audio)
with torch.no_grad():
sentence_number = 0
last_detection_start = 0
start = 0
clip_size = model.MAX_MEL_FRAMES
while start+clip_size < mels.shape[-1]:
clip = mels[:, :, start:start+clip_size]
preds = torch.nn.functional.sigmoid(model(clip)).squeeze(-1).squeeze(0) # Squeeze off the batch and sigmoid dimensions, leaving only the sequence dimension.
indices = torch.nonzero(preds > cutoff_pred_percent)
for i in indices:
i = i.item()
sentence = mels[0, :, last_detection_start:start+i]
if sentence.shape[-1] > 400 and sentence.shape[-1] < 1600:
save_mel_buffer_to_file(sentence, f'{sentence_number}.npy')
wav = vocoder.transform_mel_to_audio(sentence)
wavfile.write(f'{sentence_number}.wav', 22050, wav[0].cpu().numpy())
sentence_number += 1
last_detection_start = start+i
start += 4
if last_detection_start > start:
start = last_detection_start

View File

@ -0,0 +1,28 @@
import numpy
import torch
from scipy.io import wavfile
from models.waveglow.waveglow import WaveGlow
class Vocoder:
def __init__(self):
self.model = WaveGlow(n_mel_channels=80, n_flows=12, n_group=8, n_early_size=2, n_early_every=4, WN_config={'n_layers': 8, 'n_channels': 256, 'kernel_size': 3})
sd = torch.load('../experiments/waveglow_256channels_universal_v5.pth')
self.model.load_state_dict(sd)
self.model = self.model.to('cuda')
self.model.eval()
def transform_mel_to_audio(self, mel):
if len(mel.shape) == 2: # Assume it's missing the batch dimension and fix that.
mel = mel.unsqueeze(0)
with torch.no_grad():
return self.model.infer(mel)
if __name__ == '__main__':
inp = '3.npy'
mel = torch.tensor(numpy.load(inp)).to('cuda')
vocoder = Vocoder()
wav = vocoder.transform_mel_to_audio(mel)
wavfile.write(f'{inp}.wav', 22050, wav[0].cpu().numpy())

View File

@ -282,7 +282,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_stop_libritts.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_clips.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()