GrandConjoinedDataset
commit e55d949855
parent b9de8a8eda
@@ -86,6 +86,8 @@ def create_dataset(dataset_opt, return_collate=False):
         from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset as D
     elif mode == 'unsupervised_audio_with_noise':
         from data.audio.audio_with_noise_dataset import AudioWithNoiseDataset as D
+    elif mode == 'grand_conjoined_voice':
+        from data.audio.grand_conjoined_dataset import GrandConjoinedDataset as D
     else:
         raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
     dataset = D(dataset_opt)
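A hedged usage sketch (not part of the commit): any options dict whose 'mode' is 'grand_conjoined_voice' now resolves to GrandConjoinedDataset through this dispatch. The nested argument dicts are elided here; real values appear in the __main__ test harness of codes/data/audio/grand_conjoined_dataset.py below.

from data import create_dataset

opt = {
    'mode': 'grand_conjoined_voice',
    'max_paired_audio_length': 255995,
    'max_paired_text_length': 80,
    'max_solo_audio_length': 300000,
    'max_solo_text_length': 330,
    'paired_dataset_args': {...},        # elided; see the test harness below
    'unsupervised_audio_args': {...},
    'text_corpus_args': {...},
}
ds = create_dataset(opt)  # dispatches on 'mode' and returns D(opt)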
File diff suppressed because one or more lines are too long
codes/data/audio/grand_conjoined_dataset.py (new file, 157 lines)
@@ -0,0 +1,157 @@
import os
import random

import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
from munch import munchify
from tqdm import tqdm
from transformers import GPT2TokenizerFast

from data.audio.unsupervised_audio_dataset import load_audio, UnsupervisedAudioDataset
from data.text.hf_datasets_wrapper import HfDataset
from data.util import find_files_of_type, is_audio_file
from models.tacotron2.taco_utils import load_filepaths_and_text
from models.tacotron2.text import text_to_sequence
from utils.util import opt_get


def build_paired_voice_dataset(args):
    from data.audio.paired_voice_audio_dataset import TextWavLoader as D
    from models.tacotron2.hparams import create_hparams
    default_params = create_hparams()
    default_params.update(args)
    dataset_opt = munchify(default_params)
    return D(dataset_opt)


def clamp(x, minimum, maximum):
    return max(minimum, min(x, maximum))


class GrandConjoinedDataset(torch.utils.data.Dataset):
    """
    A joint text & speech dataset that joins three separate datasets into a single batch:
    1. Unpaired text
    2. Unpaired speech
    3. Paired speech & text

    Supports situations where the underlying data sources for these three elements are differently sized, e.g. you can
    have a massive text corpus of 1B elements, a smaller unpaired speech corpus, and a small paired speech<->text corpus.

    Performs tokenization at this level, ignoring any tokenization performed by upstream datasets.
    """
    def __init__(self, opt):
        paired_dataset_args = opt['paired_dataset_args']
        unsupervised_audio_args = opt['unsupervised_audio_args']
        text_corpus_args = opt['text_corpus_args']
        sample_rate = 22050

        self.max_paired_audio_length = opt['max_paired_audio_length']
        self.max_paired_text_length = opt['max_paired_text_length']
        self.max_solo_audio_length = opt['max_solo_audio_length']
        self.max_solo_text_length = opt['max_solo_text_length']
        self.sample_rate = sample_rate

        # Set some sane arguments for all three datasets.
        paired_dataset_args['needs_collate'] = False
        paired_dataset_args['load_conditioning'] = False
        paired_dataset_args['sample_rate'] = sample_rate
        paired_dataset_args['max_wav_length'] = self.max_paired_audio_length
        paired_dataset_args['max_text_length'] = self.max_paired_text_length
        unsupervised_audio_args['sampling_rate'] = sample_rate
        unsupervised_audio_args['do_augmentation'] = False
        unsupervised_audio_args['resample_clip'] = False
        unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length

        self.speech_and_text = build_paired_voice_dataset(paired_dataset_args)
        self.speech = UnsupervisedAudioDataset(unsupervised_audio_args)
        self.text = HfDataset(**text_corpus_args)

    def fetch_text_at(self, i):
        try:
            txt = self.text[i % len(self.text)]['text']
            tok = self.speech_and_text.get_text(txt)
            padding_required = self.max_solo_text_length - tok.shape[0]
            if padding_required < 0:
                # Just truncate since there is no conditioning required.
                tok = tok[:self.max_solo_text_length]
            elif padding_required > 0:
                tok = F.pad(tok, (0, padding_required))
            return txt, tok
        except Exception:
            # This is fully expected: there are a lot of text strings we intentionally do not
            # handle (e.g. ones with emojis, or other languages). Just return another one.
            return self.fetch_text_at((i+1) % len(self.text))

    def __getitem__(self, i):
        snt = self.speech_and_text[i % len(self.speech_and_text)]
        sp = self.speech[i % len(self.speech)]
        txt, txt_tok = self.fetch_text_at(i % len(self.text))

        return {
            'paired_audio': snt['wav'],
            'paired_audio_lengths': snt['wav_lengths'],
            'paired_text': snt['real_text'],
            'paired_text_tokens': snt['padded_text'],
            'paired_file': snt['filenames'],
            'speech_audio': sp['clip'],
            'speech_lengths': clamp(sp['clip_lengths'], 0, self.max_solo_audio_length),
            'speech_file': sp['path'],
            'text_text': txt,
            'text_tokens': txt_tok,
        }

    def __len__(self):
        return max(len(self.speech), len(self.speech_and_text), len(self.text))


if __name__ == '__main__':
    batch_sz = 8
    params = {
        'mode': 'grand_conjoined_voice',
        'phase': 'train',
        'n_workers': 0,
        'batch_size': batch_sz,

        'max_paired_audio_length': 255995,
        'max_paired_text_length': 80,
        'max_solo_text_length': 330,
        'max_solo_audio_length': 300000,
        'paired_dataset_args': {
            'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'],
            'fetcher_mode': ['libritts'],
        },
        'unsupervised_audio_args': {
            'path': ['Z:\\bigasr_dataset\\librispeech\\test_clean'],
            'cache_path': 'test_cache_delete_me.pth',
        },
        'text_corpus_args': {
            'corpi': [['bookcorpus', '']],
            'cache_path': 'Z:\\huggingface_datasets\\cache',
        },
    }
    from data import create_dataset, create_dataloader

    ds = create_dataset(params)
    dl = create_dataloader(ds, params)

    def save(b, i, ib, key):
        torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)

    def decode(b, ib, key):
        return ds.speech_and_text.tokenizer.decode(b[key][ib].cpu().numpy())

    i = 0
    m = None
    for i, b in tqdm(enumerate(dl)):
        for ib in range(batch_sz):
            #save(b, i, ib, 'paired_audio')
            print(f'Paired text: {b["paired_text"][ib]}')
            print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
            #save(b, i, ib, 'speech_audio')
            print(f'Text: {b["text_text"][ib]}')
            print(f'Text decoded: {decode(b, ib, "text_tokens")}')
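The docstring's central claim (differently sized corpora consumed in a single epoch) reduces to the modulo indexing in __getitem__ plus the max() in __len__. A minimal standalone sketch of that pattern, not part of the commit; ModuloJoin is a hypothetical name:

import torch.utils.data

class ModuloJoin(torch.utils.data.Dataset):
    """Index each wrapped dataset modulo its own length; report the max length."""
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        # Smaller datasets wrap around and repeat while the largest is traversed once.
        return tuple(d[i % len(d)] for d in self.datasets)

    def __len__(self):
        return max(len(d) for d in self.datasets)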
codes/data/audio/paired_voice_audio_dataset.py
@@ -85,7 +85,7 @@ class TextWavLoader(torch.utils.data.Dataset):
         self.needs_collate = opt_get(hparams, ['needs_collate'], True)
         if not self.needs_collate:
             assert self.max_wav_len is not None and self.max_text_len is not None
-        self.tokenizer = Tokenizer.from_file(opt_get(hparams, ['tokenizer_vocab'], '../experiments/custom_lowercase_gptvoice_tokenizer.json'))
+        self.tokenizer = Tokenizer.from_file(opt_get(hparams, ['tokenizer_vocab'], '../experiments/custom_lowercase_gptvoice_tokenizer_r2.json'))

     def get_wav_text_pair(self, audiopath_and_text):
         # separate filename and text
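A hedged sketch (not part of the commit) of how the swapped-in vocab file is consumed through the tokenizers API used above; the path is the new default from this hunk:

from tokenizers import Tokenizer

tok = Tokenizer.from_file('../experiments/custom_lowercase_gptvoice_tokenizer_r2.json')
ids = tok.encode('this is a test sentence.').ids   # list[int] of token ids
text = tok.decode(ids)                             # round-trips back to text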
codes/data/audio/unsupervised_audio_dataset.py
@@ -143,6 +143,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):

         output = {
             'clip': clips[0],
+            'clip_lengths': audio_norm.shape[-1],
             'path': filename,
         }
         if self.should_resample_clip:
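A short note on why 'clip_lengths' is surfaced here: UnsupervisedAudioDataset pads every clip to 'pad_to_samples', so the true sample count must travel alongside the padded tensor. GrandConjoinedDataset then clamps it, so a clip longer than the pad target reports the pad length. A minimal sketch, not part of the commit:

def clamp(x, minimum, maximum):
    return max(minimum, min(x, maximum))

pad_to_samples = 300000            # the max_solo_audio_length from the test harness above
true_length = 312001               # a raw clip longer than the pad target
reported = clamp(true_length, 0, pad_to_samples)
assert reported == pad_to_samples  # reported length never exceeds the padded tensor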
@@ -4,6 +4,7 @@ import datasets
 from tokenizers import Tokenizer
 from tokenizers.models import BPE
 from tokenizers.pre_tokenizers import Whitespace
+from tokenizers.processors import ByteLevel
 from tokenizers.trainers import BpeTrainer

 from data.audio.paired_voice_audio_dataset import load_mozilla_cv, load_voxpopuli, load_tsv

@@ -33,9 +34,8 @@ def train():
     with open('all_texts.txt', 'r', encoding='utf-8') as at:
         ttsd = at.readlines()
     bcd = datasets.load_dataset('bookcorpus', cache_dir='Z:\\huggingface_datasets\\cache')['train']
-    wkd = datasets.load_dataset('wikipedia', '20200501.en', cache_dir='Z:\\huggingface_datasets\\cache')['train']

-    allowed_characters_re = re.compile(r'^[0-9a-z!@#%_=:;"/, \-\$\^&\*\(\)\+\{\[\]\}\\\.\'\?—ʼ]+$')
+    allowed_characters_re = re.compile(r'^[0-9a-z!@#%_=:;"/, \-\$\^&\*\(\)\+\{\[\]\}\\\.\'\?—–ʼ]+$')
     def preprocess_word(word, report=False):
         word = word.strip().lower()
         if not bool(allowed_characters_re.match(word)):

@@ -53,14 +53,13 @@ def train():
         for i in range(0, len(bcd), batch_size):
             yield [preprocess_word(t) for t in bcd[i:i+batch_size]['text']]

-        print("Processing wikipedia.")
-        for i in range(0, len(wkd), batch_size):
-            yield [preprocess_word(t) for t in wkd[i:i+batch_size]['text']]
-
-    trainer = BpeTrainer(special_tokens=['[STOP]', '[UNK]'], vocab_size=9999)
+    trainer = BpeTrainer(special_tokens=['[STOP]', '[UNK]'], vocab_size=9999, continuing_subword_prefix='$$$')
     tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
     tokenizer.pre_tokenizer = Whitespace()
-    tokenizer.train_from_iterator(batch_iterator(), trainer, length=len(ttsd)+len(bcd)+len(wkd))
+    tokenizer.train_from_iterator(batch_iterator(), trainer, length=len(ttsd)+len(bcd))
+
+    print(tokenizer.decode(tokenizer.encode("i was traveling throughhadslfghds the woods in 1235375t137{{}}").ids))

     tokenizer.save('gpt_tts_tokenizer.json')

@@ -74,4 +73,4 @@ if __name__ == '__main__':
                    ('Y:\\clips\\books2-transcribed.tsv', 'tsv'),
                    ('Y:\\clips\\podcasts-0-transcribed.tsv', 'tsv')], 'all_texts.txt')
     '''
     train()
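A hedged sketch (not part of the commit) of what continuing_subword_prefix does: the trainer marks vocabulary entries that continue a word with the '$$$' prefix, which lets the decoder rejoin subwords without inserting spaces. The exact splits depend on the trained merges:

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

tok = Tokenizer(BPE(unk_token='[UNK]'))
tok.pre_tokenizer = Whitespace()
trainer = BpeTrainer(special_tokens=['[STOP]', '[UNK]'], vocab_size=100,
                     continuing_subword_prefix='$$$')
tok.train_from_iterator([['hello world', 'hello there']], trainer)
# An unseen word decomposes into a word-initial piece plus '$$$'-prefixed
# continuations, e.g. something like ['hel', '$$$d'].
print(tok.encode('held').tokens)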
codes/data/text/hf_datasets_wrapper.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from torch.utils.data import Dataset
import datasets


class HfDataset(Dataset):
    """
    Simple wrapper for a HuggingFace dataset that can re-map keys if desired.
    """
    def __init__(self, corpi, cache_path=None, key_maps=None, dataset_spec_key='train'):
        self.hfd = []
        for corpus in corpi:
            dataset_name, config = corpus
            if config == '':
                config = None
            self.hfd.append(datasets.load_dataset(dataset_name, config, cache_dir=cache_path)[dataset_spec_key])
        self.key_maps = key_maps

    def __getitem__(self, item):
        for dataset in self.hfd:
            if item < len(dataset):
                val = dataset[item]
                if self.key_maps is None:
                    return val
                else:
                    return {k: val[v] for k, v in self.key_maps.items()}
            else:
                item -= len(dataset)
        raise IndexError()

    def __len__(self):
        return sum([len(h) for h in self.hfd])


if __name__ == '__main__':
    d = HfDataset([['wikipedia', '20200501.en'], ['bookcorpus', '']], dataset_spec_key='train', cache_path='Z:\\huggingface_datasets\\cache')
    print(d[5])
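A hedged example (not in the commit) of the key_maps remapping: for a hypothetical corpus that stores its strings under 'content', key_maps presents them under the 'text' key the rest of this commit expects:

d = HfDataset([['some_corpus', '']],            # hypothetical dataset name
              cache_path='Z:\\huggingface_datasets\\cache',
              key_maps={'text': 'content'})     # returned['text'] = item['content']
print(d[0]['text'])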
codes/models/gpt_voice/unified_voice.py (new file, 165 lines)
@@ -0,0 +1,165 @@
import random
from time import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map

from models.arch_util import AttentionBlock
from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
from models.gpt_voice.mini_encoder import AudioMiniEncoder
from models.tacotron2.text import symbols
from trainer.networks import register_model
from utils.util import opt_get


class ConditioningEncoder(nn.Module):
    def __init__(self,
                 spec_dim,
                 embedding_dim,
                 attn_blocks=6,
                 num_attn_heads=4,
                 do_checkpointing=False):
        super().__init__()
        attn = []
        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
        for a in range(attn_blocks):
            attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
        self.attn = nn.Sequential(*attn)
        self.dim = embedding_dim
        self.do_checkpointing = do_checkpointing

    def forward(self, x):
        h = self.init(x)
        h = self.attn(h)
        return h[:, :, 0]


class GptTtsHf(nn.Module):
    NUMBER_TEXT_TOKENS = 10000  # The number of tokens produced by our bespoke BPE tokenizer.
    START_TEXT_TOKEN = 9999
    STOP_TEXT_TOKEN = 0
    NUMBER_MEL_CODES = 8194
    START_MEL_TOKEN = 8192
    STOP_MEL_TOKEN = 8193

    def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=80, max_mel_tokens=250, max_conditioning_inputs=3,
                 checkpointing=True, mel_length_compression=1024, max_conditioning_length=60):
        super().__init__()

        self.max_mel_tokens = max_mel_tokens
        self.max_symbols_per_phrase = max_symbols_per_phrase
        self.model_dim = model_dim
        self.max_conditioning_inputs = max_conditioning_inputs
        self.mel_length_compression = mel_length_compression
        self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
        self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
        seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
        self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES,
                                     n_positions=seq_length,
                                     n_ctx=seq_length,
                                     n_embd=model_dim,
                                     n_layer=layers,
                                     n_head=heads,
                                     gradient_checkpointing=checkpointing,
                                     use_cache=not checkpointing)
        self.gpt = GPT2Model(self.gpt_config)
        self.final_norm = nn.LayerNorm(model_dim)
        self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
        self.mel_head = nn.Linear(model_dim, self.NUMBER_MEL_CODES)
        self.max_conditioning_length = max_conditioning_length

    def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
        inp = F.pad(input, (1,0), value=start_token)
        tar = F.pad(input, (0,1), value=stop_token)
        return inp, tar

    def get_logits(self, text_inputs, cond_input, mel_inputs, get_attns=False):
        text_emb = self.text_embedding(text_inputs)
        cond = self.conditioning_encoder(cond_input).unsqueeze(1)
        mel_emb = self.gpt.get_input_embeddings()(mel_inputs)

        emb = torch.cat([text_emb, cond, mel_emb], dim=1)
        gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
        if get_attns:
            return gpt_out.attentions
        enc = gpt_out.last_hidden_state

        text_logits = self.final_norm(enc[:, :text_emb.shape[1]])
        text_logits = self.text_head(text_logits)
        text_logits = text_logits.permute(0,2,1)
        mel_logits = self.final_norm(enc[:, -mel_emb.shape[1]:])
        mel_logits = self.mel_head(mel_logits)
        mel_logits = mel_logits.permute(0,2,1)

        return text_logits, mel_logits

    def forward(self, text_inputs, cond_input, mel_targets, wav_lengths, return_attentions=False):
        """
        Forward pass
        text_inputs: long tensor, (b,t)
        cond_input: MEL float tensor, (b,80,s)
        mel_targets: long tensor, (b,m)
        wav_lengths: long tensor, (b,)
        """
        # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
        mel_lengths = wav_lengths // self.mel_length_compression
        for b in range(len(mel_lengths)):
            if mel_lengths[b] < mel_targets.shape[-1]:
                mel_targets[b, mel_lengths[b]:] = self.STOP_MEL_TOKEN

        # Randomly permute the conditioning spectrogram, to destroy any structure present.
        cond_input = cond_input[:,:,torch.randperm(cond_input.shape[-1])]
        if cond_input.shape[-1] > self.max_conditioning_length:
            cond_input = cond_input[:,:,:self.max_conditioning_length]

        text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
        mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_targets, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN)
        text_logits, mel_logits = self.get_logits(text_inputs, cond_input, mel_inputs, get_attns=return_attentions)
        if return_attentions:
            return mel_logits
        loss_text = F.cross_entropy(text_logits, text_targets.long())
        loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
        return loss_text.mean(), loss_mel.mean(), mel_logits

    def inference(self, text_inputs, cond_input, **hf_generate_kwargs):
        if not hasattr(self, 'inference_model'):
            self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, None, self.final_norm, self.mel_head)

        text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN)
        text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
        text_emb = self.text_embedding(text_inputs)

        # Randomly permute the conditioning spectrogram, to destroy any structure present.
        cond_input = cond_input[:,:,torch.randperm(cond_input.shape[-1])]
        if cond_input.shape[-1] > self.max_conditioning_length:
            cond_input = cond_input[:,:,:self.max_conditioning_length]
        cond = self.conditioning_encoder(cond_input).unsqueeze(1)

        emb = torch.cat([text_emb, cond], dim=1)
        self.inference_model.store_mel_emb(emb)

        fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
        fake_inputs[:,-1] = self.START_MEL_TOKEN

        gen = self.inference_model.generate(fake_inputs, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
                                            max_length=emb.shape[1]+self.max_mel_tokens, **hf_generate_kwargs)
        return gen[:, fake_inputs.shape[1]:]


@register_model
def register_gpt_tts_hf(opt_net, opt):
    return GptTtsHf(**opt_get(opt_net, ['kwargs'], {}))


if __name__ == '__main__':
    gpt = GptTtsHf(model_dim=1024, heads=16)
    l = gpt(torch.randint(high=len(symbols), size=(2,200)),
            torch.arange(0, 80, 1, dtype=torch.float).view(1,80,1).repeat(2,1,800),
            torch.randint(high=8192, size=(2,250)),
            torch.tensor([150*256,195*256]))
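A small sketch (not part of the commit) of what build_aligned_inputs_and_targets produces: the input gains a leading start token and the target a trailing stop token, so at every position the GPT predicts the next element of the original sequence:

import torch
import torch.nn.functional as F

START_TEXT_TOKEN, STOP_TEXT_TOKEN = 9999, 0
tokens = torch.tensor([[12, 34, 56]])
inp = F.pad(tokens, (1, 0), value=START_TEXT_TOKEN)  # [[9999, 12, 34, 56]]
tar = F.pad(tokens, (0, 1), value=STOP_TEXT_TOKEN)   # [[12, 34, 56, 0]]
assert inp.shape == tar.shape                        # aligned for teacher forcing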