forked from mrq/DL-Art-School
GPT_asr for inference
This commit is contained in:
parent
e1bdd3f7c7
commit
007976082b
|
@ -62,9 +62,9 @@ class TextMelLoader(torch.utils.data.Dataset):
|
|||
# separate filename and text
|
||||
audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
|
||||
audiopath = os.path.join(self.path, audiopath)
|
||||
text = self.get_text(text)
|
||||
text_seq = self.get_text(text)
|
||||
mel = self.get_mel(audiopath)
|
||||
return (text, mel, audiopath_and_text[0])
|
||||
return (text_seq, mel, text, audiopath_and_text[0])
|
||||
|
||||
def get_mel(self, filename):
|
||||
if not self.load_mel_from_disk:
|
||||
|
@ -106,30 +106,31 @@ class TextMelLoader(torch.utils.data.Dataset):
|
|||
return text_norm
|
||||
|
||||
def __getitem__(self, index):
|
||||
t, m, p = self.get_mel_text_pair(self.audiopaths_and_text[index])
|
||||
if m is None or \
|
||||
(self.max_mel_len is not None and m.shape[-1] > self.max_mel_len) or \
|
||||
(self.max_text_len is not None and t.shape[0] > self.max_text_len):
|
||||
if m is not None:
|
||||
print(f"Exception {index} mel_len:{m.shape[-1]} text_len:{t.shape[0]} fname: {p}")
|
||||
tseq, mel, text, path = self.get_mel_text_pair(self.audiopaths_and_text[index])
|
||||
if mel is None or \
|
||||
(self.max_mel_len is not None and mel.shape[-1] > self.max_mel_len) or \
|
||||
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
|
||||
if mel is not None:
|
||||
print(f"Exception {index} mel_len:{mel.shape[-1]} text_len:{tseq.shape[0]} fname: {path}")
|
||||
# It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
|
||||
rv = random.randint(0,len(self)-1)
|
||||
return self[rv]
|
||||
orig_output = m.shape[-1]
|
||||
orig_text_len = t.shape[0]
|
||||
orig_output = mel.shape[-1]
|
||||
orig_text_len = tseq.shape[0]
|
||||
if not self.needs_collate:
|
||||
if m.shape[-1] != self.max_mel_len:
|
||||
m = F.pad(m, (0, self.max_mel_len - m.shape[-1]))
|
||||
if t.shape[0] != self.max_text_len:
|
||||
t = F.pad(t, (0, self.max_text_len - t.shape[0]))
|
||||
if mel.shape[-1] != self.max_mel_len:
|
||||
mel = F.pad(mel, (0, self.max_mel_len - mel.shape[-1]))
|
||||
if tseq.shape[0] != self.max_text_len:
|
||||
tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0]))
|
||||
return {
|
||||
'padded_text': t,
|
||||
'real_text': text,
|
||||
'padded_text': tseq,
|
||||
'input_lengths': torch.tensor(orig_text_len, dtype=torch.long),
|
||||
'padded_mel': m,
|
||||
'padded_mel': mel,
|
||||
'output_lengths': torch.tensor(orig_output, dtype=torch.long),
|
||||
'filenames': p
|
||||
'filenames': path
|
||||
}
|
||||
return t, m, p
|
||||
return tseq, mel, path, text
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths_and_text)
|
||||
|
@ -156,10 +157,12 @@ class TextMelCollate():
|
|||
text_padded = torch.LongTensor(len(batch), max_input_len)
|
||||
text_padded.zero_()
|
||||
filenames = []
|
||||
real_text = []
|
||||
for i in range(len(ids_sorted_decreasing)):
|
||||
text = batch[ids_sorted_decreasing[i]][0]
|
||||
text_padded[i, :text.size(0)] = text
|
||||
filenames.append(batch[ids_sorted_decreasing[i]][2])
|
||||
real_text.append(batch[ids_sorted_decreasing[i]][3])
|
||||
|
||||
# Right zero-pad mel-spec
|
||||
num_mels = batch[0][1].size(0)
|
||||
|
@ -186,7 +189,8 @@ class TextMelCollate():
|
|||
'padded_mel': mel_padded,
|
||||
'padded_gate': gate_padded,
|
||||
'output_lengths': output_lengths,
|
||||
'filenames': filenames
|
||||
'filenames': filenames,
|
||||
'real_text': real_text,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from munch import munchify
|
|||
|
||||
from models.gpt_voice.lucidrains_gpt import Transformer
|
||||
from models.tacotron2.taco_utils import get_mask_from_lengths
|
||||
from models.tacotron2.text import symbols
|
||||
from models.tacotron2.text import symbols, sequence_to_text
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
||||
|
@ -125,11 +125,11 @@ class GptAsr(nn.Module):
|
|||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
|
||||
if text_emb.shape[0] != mel_emb.shape[0]:
|
||||
mel_emb = mel_emb.repeat(text_emb.shape[0], 1, 1)
|
||||
emb = torch.cat([text_emb, mel_emb], dim=1)
|
||||
emb = torch.cat([mel_emb, text_emb], dim=1)
|
||||
enc = self.gpt(emb)
|
||||
mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
|
||||
mel_logits = self.mel_head(mel_logits)
|
||||
topk = sampler_fn(F.softmax(temperature * mel_logits[:, -1], dim=-1), k=beam_width)
|
||||
text_logits = self.final_norm(enc[:, mel_emb.shape[1]:])
|
||||
text_logits = self.text_head(text_logits)
|
||||
topk = sampler_fn(F.softmax(temperature * text_logits[:, -1], dim=-1), k=beam_width)
|
||||
probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
|
||||
probabilities, sort_indices = torch.sort(probabilities, descending=True)
|
||||
probabilities = probabilities[:beam_width]
|
||||
|
@ -140,15 +140,12 @@ class GptAsr(nn.Module):
|
|||
text_seq = text_seq[sort_indices]
|
||||
text_seq = text_seq[:beam_width]
|
||||
|
||||
if torch.all(torch.any(text_seq == self.MEL_STOP_TOKEN, dim=1)):
|
||||
# PAD doubles as a stop token. PAD=0.
|
||||
if torch.all(torch.any(text_seq == 0, dim=1)):
|
||||
break
|
||||
|
||||
if text_seq.shape[1] >= self.max_mel_frames:
|
||||
print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
|
||||
|
||||
# Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
|
||||
text_seq = text_seq[0, 1:-1].unsqueeze(0) # Pick most likely outcome, remove first and last tokens, which were artificially added for GPT
|
||||
text_seq = text_seq * (text_seq < 512) # The DVAE doesn't understand BOS/EOS/PAD tokens.
|
||||
print("Warning! Encountered frame limit before a pad token. Output is likely wrong.")
|
||||
|
||||
return text_seq
|
||||
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
""" from https://github.com/keithito/tacotron """
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
from models.tacotron2.text import cleaners
|
||||
from models.tacotron2.text.symbols import symbols
|
||||
|
||||
|
@ -44,6 +47,8 @@ def sequence_to_text(sequence):
|
|||
'''Converts a sequence of IDs back to a string'''
|
||||
result = ''
|
||||
for symbol_id in sequence:
|
||||
if isinstance(symbol_id, torch.Tensor):
|
||||
symbol_id = symbol_id.item()
|
||||
if symbol_id in _id_to_symbol:
|
||||
s = _id_to_symbol[symbol_id]
|
||||
# Enclose ARPAbet back in curly braces:
|
||||
|
|
76
codes/scripts/audio/test_audio_speech_recognition.py
Normal file
76
codes/scripts/audio/test_audio_speech_recognition.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
import os.path as osp
|
||||
import logging
|
||||
import random
|
||||
import argparse
|
||||
|
||||
import torchvision
|
||||
|
||||
import utils
|
||||
import utils.options as option
|
||||
import utils.util as util
|
||||
from models.tacotron2.text import sequence_to_text
|
||||
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||
from data import create_dataset, create_dataloader
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import numpy as np
|
||||
from scipy.io import wavfile
|
||||
|
||||
|
||||
def forward_pass(model, data, output_dir, opt, b):
|
||||
with torch.no_grad():
|
||||
model.feed_data(data, 0)
|
||||
model.test()
|
||||
|
||||
real = data[opt['eval']['real_text']][0]
|
||||
pred_seq = model.eval_state[opt['eval']['gen_text']][0]
|
||||
pred_text = [sequence_to_text(ts) for ts in pred_seq]
|
||||
audio = model.eval_state[opt['eval']['audio']][0].cpu().numpy()
|
||||
wavfile.write(osp.join(output_dir, f'{b}_clip.wav'), 22050, audio)
|
||||
print(f'{b} Real text: "{real}"')
|
||||
for i, text in enumerate(pred_text):
|
||||
print(f'{b} Predicted text {i}: "{text}"')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Set seeds
|
||||
torch.manual_seed(5555)
|
||||
random.seed(5555)
|
||||
np.random.seed(5555)
|
||||
|
||||
#### options
|
||||
torch.backends.cudnn.benchmark = True
|
||||
want_metrics = False
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_mozcv.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
utils.util.loaded_options = opt
|
||||
|
||||
util.mkdirs(
|
||||
(path for key, path in opt['path'].items()
|
||||
if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
|
||||
util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
|
||||
screen=True, tofile=True)
|
||||
logger = logging.getLogger('base')
|
||||
logger.info(option.dict2str(opt))
|
||||
|
||||
test_loaders = []
|
||||
for phase, dataset_opt in sorted(opt['datasets'].items()):
|
||||
test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
|
||||
test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
|
||||
logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
|
||||
test_loaders.append(test_loader)
|
||||
|
||||
model = ExtensibleTrainer(opt)
|
||||
|
||||
batch = 0
|
||||
for test_loader in test_loaders:
|
||||
dataset_dir = opt['path']['results_root']
|
||||
util.mkdir(dataset_dir)
|
||||
|
||||
tq = tqdm(test_loader)
|
||||
for data in tq:
|
||||
forward_pass(model, data, dataset_dir, opt, batch)
|
||||
batch += 1
|
||||
|
Loading…
Reference in New Issue
Block a user