Two more tools to test the audio segmentor

This commit is contained in:
James Betker 2021-08-17 09:09:11 -06:00
parent 7c086d0c2c
commit 8332923f5c
5 changed files with 148 additions and 12 deletions

View File

@ -219,7 +219,7 @@ class TextMelCollate():
def save_mel_buffer_to_file(mel, path):
    """Write a MEL tensor to `path` in NumPy ``.npy`` format.

    The tensor is detached and copied to host memory first, so tensors that
    live on a CUDA device (or that still track gradients) can be saved too.

    :param mel: torch tensor holding the MEL spectrogram.
    :param path: destination file path handed to ``numpy.save``.
    """
    import numpy as np  # Local import keeps this function usable on its own.

    np.save(path, mel.detach().cpu().numpy())
def dump_mels_to_disk():

View File

@ -64,24 +64,26 @@ class GptSegmentor(nn.Module):
self.final_norm = nn.LayerNorm(model_dim)
self.stop_head = nn.Linear(model_dim, 1)
def forward(self, mel_inputs, termination_points=None):
    """
    Compute stop logits for a MEL clip, or the training loss when targets are given.

    :param mel_inputs: MEL batch passed straight to self.mel_encoder
                       (presumably (batch, n_mels, frames) - TODO confirm against the encoder).
    :param termination_points: Optional 2-D float tensor (batch, frames) marking stop frames.
                               When omitted, no loss is computed.
    :return: A scalar binary-cross-entropy loss when termination_points is provided,
             otherwise the raw stop logits with shape (batch, sequence, 1).
    """
    mel_emb = self.mel_encoder(mel_inputs)
    # Move the sequence axis to dim 1 so positions can be embedded per step.
    mel_emb = mel_emb.permute(0, 2, 1).contiguous()
    positions = torch.arange(mel_emb.shape[1], device=mel_emb.device)
    mel_emb = mel_emb + self.mel_pos_embedding(positions)
    enc = self.gpt(mel_emb)
    stop_logits = self.stop_head(self.final_norm(enc))

    if termination_points is None:
        # Inference: the caller applies its own sigmoid / threshold.
        return stop_logits

    # The MEL gets decimated to 1/4 the size by the encoder, so we need to do the same to the termination points.
    termination_points = F.interpolate(termination_points.unsqueeze(1), size=mel_emb.shape[1], mode='area')
    # Remove only the channel axis added above. A bare squeeze() would also drop a
    # batch axis of size 1 and make the target shape disagree with the logits.
    termination_points = termination_points.squeeze(1)
    # Any stop frame that falls inside a decimated step marks that step as a stop.
    termination_points = (termination_points > 0).float()

    loss = F.binary_cross_entropy_with_logits(stop_logits.squeeze(-1), termination_points)
    return loss.mean()

View File

@ -0,0 +1,106 @@
import os.path as osp
import logging
import random
import argparse
import audio2numpy
import torchvision
from munch import munchify
import utils
import utils.options as option
import utils.util as util
from import save_mel_buffer_to_file
from models.tacotron2 import hparams
from models.tacotron2.layers import TacotronSTFT
from models.tacotron2.text import sequence_to_text
from import Vocoder
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch
import numpy as np
from scipy.io import wavfile
def forward_pass(model, data, output_dir, opt, b):
    """
    Feed one batch to the trainer, print the real and predicted text, and save the audio clip.

    :param model: trainer object exposing feed_data() and an eval_state mapping.
    :param data: batch dictionary; indexed by the key named in opt['eval']['real_text'] when present.
    :param output_dir: directory that receives '<b>_clip.wav'.
    :param opt: options mapping; its 'eval' section names the keys to read.
    :param b: batch identifier used in printed messages and in the output file name.
    """
    eval_opt = opt['eval']

    with torch.no_grad():
        model.feed_data(data, 0)
    # NOTE(review): nothing visible here runs an evaluation step between feed_data()
    # and the eval_state reads below - confirm eval_state is populated by feed_data().

    if 'real_text' in eval_opt.keys():
        ground_truth = data[eval_opt['real_text']][0]
        print(f'{b} Real text: "{ground_truth}"')

    # Only the first element of each eval_state entry is inspected.
    predicted_sequences = model.eval_state[eval_opt['gen_text']][0]
    decoded = [sequence_to_text(seq) for seq in predicted_sequences]

    clip = model.eval_state[eval_opt['audio']][0].cpu().numpy()
    clip_path = osp.join(output_dir, f'{b}_clip.wav')
    wavfile.write(clip_path, 22050, clip)

    for idx, text in enumerate(decoded):
        print(f'{b} Predicted text {idx}: "{text}"')
if __name__ == "__main__":
    # Splits a long audio file into sentence-sized clips: a trained stop-prediction
    # model is slid across the file's MEL spectrogram, and every span between two
    # accepted stop detections is written out as a .npy MEL plus a vocoded .wav.
    import os  # Only `os.path` is imported (as `osp`) at file level.

    input_file = "E:\\audio\\books\\Roald Dahl Audiobooks\\Roald Dahl - The BFG\\(Roald Dahl) The BFG - 07.mp3"
    config = "../options/train_gpt_stop_libritts.yml"
    # Sigmoid probability above which a position counts as a stop detection.
    cutoff_pred_percent = .2

    #### options
    torch.backends.cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default=config)
    parser.add_argument('-input', type=str, help='Path to the audio file to segment.', default=input_file)
    args = parser.parse_args()
    input_file = args.input
    opt = option.parse(args.opt, is_train=False)
    opt = option.dict_to_nonedict(opt)
    utils.util.loaded_options = opt
    hp = munchify(hparams.create_hparams())

    # NOTE(review): the call that consumed this path selection was lost from the
    # source this was reconstructed from; creating the directories is assumed
    # because the logger below writes into opt['path']['log'] - confirm.
    for key, path in opt['path'].items():
        if key != 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key:
            os.makedirs(path, exist_ok=True)
    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
                      screen=True, tofile=True)
    logger = logging.getLogger('base')

    model = ExtensibleTrainer(opt)
    assert len(model.networks) == 1
    model = model.networks[next(iter(model.networks.keys()))].to('cuda')
    # Inference only: make sure dropout and similar layers are in eval mode.
    model.eval()
    vocoder = Vocoder()

    audio, sr = audio2numpy.audio_from_file(input_file)
    if len(audio.shape) == 2:
        # Multi-channel input: keep the first channel only.
        audio = audio[:, 0]
    audio = torch.tensor(audio, device='cuda').unsqueeze(0).unsqueeze(0)
    # Crude nearest-neighbour resample to the sampling rate the MEL front-end expects.
    audio = torch.nn.functional.interpolate(audio, scale_factor=hp.sampling_rate/sr, mode='nearest').squeeze(1)
    stft = TacotronSTFT(hp.filter_length, hp.hop_length, hp.win_length, hp.n_mel_channels, hp.sampling_rate, hp.mel_fmin, hp.mel_fmax).to('cuda')
    mels = stft.mel_spectrogram(audio)

    with torch.no_grad():
        sentence_number = 0
        last_detection_start = 0
        start = 0
        clip_size = model.MAX_MEL_FRAMES
        total_frames = mels.shape[-1]
        while start + clip_size < total_frames:
            clip = mels[:, :, start:start + clip_size]
            # Squeeze off the batch and sigmoid dimensions, leaving only the sequence dimension.
            # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
            preds = torch.sigmoid(model(clip)).squeeze(-1).squeeze(0)
            indices = torch.nonzero(preds > cutoff_pred_percent).flatten().tolist()
            # NOTE(review): the segmentor's training code says its encoder decimates the
            # MEL to 1/4 length. If so, `i` indexes decimated positions rather than MEL
            # frames and `start + i` understates the cut position - confirm.
            for i in indices:
                sentence = mels[0, :, last_detection_start:start + i]
                # Only keep spans of a plausible sentence length (in MEL frames).
                # NOTE(review): nesting reconstructed from a flattened source; the
                # detection start is assumed to advance only when a clip is accepted.
                if 400 < sentence.shape[-1] < 1600:
                    save_mel_buffer_to_file(sentence, f'{sentence_number}.npy')
                    wav = vocoder.transform_mel_to_audio(sentence)
                    wavfile.write(f'{sentence_number}.wav', 22050, wav[0].cpu().numpy())
                    sentence_number += 1
                    last_detection_start = start + i
            # Slide the window, jumping straight past anything already emitted.
            start += 4
            if last_detection_start > start:
                start = last_detection_start

View File

@ -0,0 +1,28 @@
import numpy
import torch
from scipy.io import wavfile
from models.waveglow.waveglow import WaveGlow
class Vocoder:
    """Thin wrapper around a pretrained WaveGlow model that turns MEL spectrograms into audio."""

    def __init__(self, checkpoint_path='../experiments/waveglow_256channels_universal_v5.pth', device='cuda'):
        """
        Build the WaveGlow network and load its pretrained weights.

        :param checkpoint_path: file containing the weights; defaults to the original hard-coded location.
        :param device: device the model is moved to; defaults to 'cuda' as before.
        """
        self.model = WaveGlow(n_mel_channels=80, n_flows=12, n_group=8, n_early_size=2, n_early_every=4,
                              WN_config={'n_layers': 8, 'n_channels': 256, 'kernel_size': 3})
        sd = torch.load(checkpoint_path, map_location=device)
        # The loaded weights were previously never applied to the model.
        # NOTE(review): assumes the checkpoint is a plain state_dict for WaveGlow - confirm.
        self.model.load_state_dict(sd)
        self.model = self.model.to(device)
        # Inference only.
        self.model.eval()

    def transform_mel_to_audio(self, mel):
        """
        Run the vocoder on a MEL spectrogram.

        :param mel: 2-D MEL tensor (treated as a single unbatched clip) or an already batched tensor.
        :return: whatever self.model.infer() returns for the batched MEL.
        """
        if mel.dim() == 2:  # Assume it's missing the batch dimension and fix that.
            mel = mel.unsqueeze(0)
        with torch.no_grad():
            return self.model.infer(mel)
if __name__ == '__main__':
    # Manual smoke test: vocode one saved MEL file and write the audio next to it.
    inp = '3.npy'
    spectrogram = numpy.load(inp)
    mel = torch.tensor(spectrogram).to('cuda')
    waveform = Vocoder().transform_mel_to_audio(mel)
    # Only the first batch element is written out.
    samples = waveform[0].cpu().numpy()
    wavfile.write(f'{inp}.wav', 22050, samples)

View File

@ -282,7 +282,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_stop_libritts.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_clips.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()