forked from mrq/DL-Art-School
108 lines
4.1 KiB
Python
108 lines
4.1 KiB
Python
import os.path as osp
|
|
import logging
|
|
import random
|
|
import argparse
|
|
|
|
import audio2numpy
|
|
import torchvision
|
|
from munch import munchify
|
|
|
|
import utils
|
|
import utils.options as option
|
|
import utils.util as util
|
|
from data.audio.nv_tacotron_dataset import save_mel_buffer_to_file
|
|
from models.tacotron2 import hparams
|
|
from models.tacotron2.layers import TacotronSTFT
|
|
from models.tacotron2.text import sequence_to_text
|
|
from scripts.audio.use_vocoder import Vocoder
|
|
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
|
from data import create_dataset, create_dataloader
|
|
from tqdm import tqdm
|
|
import torch
|
|
import numpy as np
|
|
from scipy.io import wavfile
|
|
|
|
|
|
def forward_pass(model, data, output_dir, opt, b):
|
|
with torch.no_grad():
|
|
model.feed_data(data, 0)
|
|
model.test()
|
|
|
|
if 'real_text' in opt['eval'].keys():
|
|
real = data[opt['eval']['real_text']][0]
|
|
print(f'{b} Real text: "{real}"')
|
|
|
|
pred_seq = model.eval_state[opt['eval']['gen_text']][0]
|
|
pred_text = [sequence_to_text(ts) for ts in pred_seq]
|
|
audio = model.eval_state[opt['eval']['audio']][0].cpu().numpy()
|
|
wavfile.write(osp.join(output_dir, f'{b}_clip.wav'), 22050, audio)
|
|
for i, text in enumerate(pred_text):
|
|
print(f'{b} Predicted text {i}: "{text}"')
|
|
|
|
|
|
if __name__ == "__main__":
|
|
input_file = "E:\\audio\\books\\Roald Dahl Audiobooks\\Roald Dahl - The BFG\\(Roald Dahl) The BFG - 07.mp3"
|
|
config = "../options/train_gpt_stop_libritts.yml"
|
|
cutoff_pred_percent = .2
|
|
|
|
# Set seeds
|
|
torch.manual_seed(5555)
|
|
random.seed(5555)
|
|
np.random.seed(5555)
|
|
|
|
#### options
|
|
torch.backends.cudnn.benchmark = True
|
|
want_metrics = False
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default=config)
|
|
opt = option.parse(parser.parse_args().opt, is_train=False)
|
|
opt = option.dict_to_nonedict(opt)
|
|
utils.util.loaded_options = opt
|
|
hp = munchify(hparams.create_hparams())
|
|
|
|
util.mkdirs(
|
|
(path for key, path in opt['path'].items()
|
|
if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
|
|
util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
|
|
screen=True, tofile=True)
|
|
logger = logging.getLogger('base')
|
|
logger.info(option.dict2str(opt))
|
|
|
|
model = ExtensibleTrainer(opt)
|
|
assert len(model.networks) == 1
|
|
model = model.networks[next(iter(model.networks.keys()))].module.to('cuda')
|
|
model.eval()
|
|
|
|
vocoder = Vocoder()
|
|
|
|
audio, sr = audio2numpy.audio_from_file(input_file)
|
|
if len(audio.shape) == 2:
|
|
audio = audio[:, 0]
|
|
audio = torch.tensor(audio, device='cuda').unsqueeze(0).unsqueeze(0)
|
|
audio = torch.nn.functional.interpolate(audio, scale_factor=hp.sampling_rate/sr, mode='nearest').squeeze(1)
|
|
stft = TacotronSTFT(hp.filter_length, hp.hop_length, hp.win_length, hp.n_mel_channels, hp.sampling_rate, hp.mel_fmin, hp.mel_fmax).to('cuda')
|
|
mels = stft.mel_spectrogram(audio)
|
|
|
|
with torch.no_grad():
|
|
sentence_number = 0
|
|
last_detection_start = 0
|
|
start = 0
|
|
clip_size = model.max_mel_frames
|
|
while start+clip_size < mels.shape[-1]:
|
|
clip = mels[:, :, start:start+clip_size]
|
|
pred_starts, pred_ends = model(clip)
|
|
pred_ends = torch.nn.functional.sigmoid(pred_ends).squeeze(-1).squeeze(0) # Squeeze off the batch and sigmoid dimensions, leaving only the sequence dimension.
|
|
indices = torch.nonzero(pred_ends > cutoff_pred_percent)
|
|
for i in indices:
|
|
i = i.item()
|
|
sentence = mels[0, :, last_detection_start:start+i]
|
|
if sentence.shape[-1] > 400 and sentence.shape[-1] < 1600:
|
|
save_mel_buffer_to_file(sentence, f'{sentence_number}.npy')
|
|
wav = vocoder.transform_mel_to_audio(sentence)
|
|
wavfile.write(f'{sentence_number}.wav', 22050, wav[0].cpu().numpy())
|
|
sentence_number += 1
|
|
last_detection_start = start+i
|
|
start += 4
|
|
if last_detection_start > start:
|
|
start = last_detection_start
|