forked from mrq/DL-Art-School
support tts typing
This commit is contained in:
parent 48cb6a5abd
commit 8fe0dff33c
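
In outline: each dataset loader now attaches an integer type id to every (audio path, text) record, the datasets surface it as a 'type' entry in the returned batch, and UnifiedVoice accepts a per-sample types tensor that remaps text token ids into an enlarged embedding table. A minimal sketch of the consumer side with illustrative values (the variable names below are not part of the diff):

import torch

# A batch shaped like what the modified datasets emit (values invented).
batch = {
    'padded_text': torch.randint(high=256, size=(2, 120)),
    'type': torch.tensor([0, 1]),   # one type id per sample
}
# UnifiedVoice.forward(..., types=batch['type']) then rescales the token ids
# so the same text addresses type-dependent rows of the enlarged table:
text_inputs = batch['padded_text'] * (1 + batch['type']).unsqueeze(-1)
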
@@ -48,6 +48,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             self.paths = [self.paths]
         self.paths_size_bytes = [os.path.getsize(p) for p in self.paths]
         self.total_size_bytes = sum(self.paths_size_bytes)
+        self.types = opt_get(hparams, ['types'], [0 for _ in self.paths])
 
         self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
         self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
@@ -101,22 +102,23 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             else:
                 rand_offset -= self.paths_size_bytes[i]
         path = self.paths[i]
+        type = self.types[i]
         with open(path, 'r', encoding='utf-8') as f:
             f.seek(rand_offset)
             # Read the rest of the line we seeked to, then the line after that.
             try: # This can fail when seeking to a UTF-8 escape byte.
                 f.readline()
             except:
-                return self.load_random_line(depth=depth + 1) # On failure, just recurse and try again.
+                return self.load_random_line(depth=depth + 1), type # On failure, just recurse and try again.
             l2 = f.readline()
 
         if l2:
             try:
                 base_path = os.path.dirname(path)
-                return parse_tsv_aligned_codes(l2, base_path)
+                return parse_tsv_aligned_codes(l2, base_path), type
             except:
                 print(f"error parsing random offset: {sys.exc_info()}")
-        return self.load_random_line(depth=depth+1) # On failure, just recurse and try again.
+        return self.load_random_line(depth=depth+1), type # On failure, just recurse and try again.
 
     def get_ctc_metadata(self, codes):
         grouped = groupby(codes.tolist())
@@ -155,7 +157,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
     def __getitem__(self, index):
         start = time.time()
         self.skipped_items += 1
-        apt = self.load_random_line()
+        apt, type = self.load_random_line()
        try:
             tseq, wav, text, path = self.get_wav_text_pair(apt)
             if text is None or len(text.strip()) == 0:
@@ -204,7 +206,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
             'filenames': path,
             'skipped_items': actually_skipped_items,
-            'load_time': self.load_times.mean()
+            'load_time': self.load_times.mean(),
+            'type': type,
         }
         if self.load_conditioning:
             res['conditioning'] = cond
@@ -266,10 +269,11 @@ if __name__ == '__main__':
         'path': ['y:/libritts/train-other-500/transcribed-oco.tsv',
                  'y:/libritts/train-clean-100/transcribed-oco.tsv',
                  'y:/libritts/train-clean-360/transcribed-oco.tsv',
-                 'y:/clips/books1/transcribed-w2v.tsv',
-                 'y:/clips/books2/transcribed-w2v.tsv',
-                 'y:/bigasr_dataset/hifi_tts/transcribed-w2v.tsv',
+                 'y:/clips/books1/transcribed-oco.tsv',
+                 'y:/clips/books2/transcribed-oco.tsv',
+                 'y:/bigasr_dataset/hifi_tts/transcribed-oco.tsv',
                  'y:/clips/podcasts-1/transcribed-oco.tsv',],
+        'types': [0,1,1,1,2,2,0],
         'phase': 'train',
         'n_workers': 0,
         'batch_size': batch_sz,
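
Note that 'types' pairs positionally with 'path'; a quick sketch of the association implied by the hparams above (paths shortened):

# Positional pairing of transcript files to type ids, per the hparams above.
paths = ['train-other-500', 'train-clean-100', 'train-clean-360',
         'books1', 'books2', 'hifi_tts', 'podcasts-1']
types = [0, 1, 1, 1, 2, 2, 0]
assignments = dict(zip(paths, types))   # e.g. {'books1': 1, 'hifi_tts': 2, ...}
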
@@ -9,11 +9,28 @@ import torchaudio
 from tqdm import tqdm
 
 from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
-from models.audio.tts.tacotron2 import load_filepaths_and_text
+from models.audio.tts.tacotron2 import load_filepaths_and_text, load_filepaths_and_text_type
 from models.audio.tts.tacotron2 import text_to_sequence, sequence_to_text
 from utils.util import opt_get
 
 
+def load_tsv_type(filename, type):
+    with open(filename, encoding='utf-8') as f:
+        filepaths_and_text = []
+        base = os.path.dirname(filename)
+        bad_lines = 0
+        for line in f:
+            components = line.strip().split('\t')
+            if len(components) < 2:
+                bad_lines += 1
+                if bad_lines > 1000:
+                    print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
+                    raise ValueError
+                continue
+            filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0]] + [type])
+    return filepaths_and_text
+
+
 def load_tsv(filename):
     with open(filename, encoding='utf-8') as f:
         filepaths_and_text = []
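
For reference, load_tsv_type reads column 1 as the audio path relative to the TSV and column 0 as the transcript; a made-up row to show the resulting entry shape:

import os

base = '/data/libritts'                           # hypothetical TSV directory
line = 'the quick brown fox\t1234/5678/0001.wav'  # invented two-column row
components = line.strip().split('\t')
entry = [os.path.join(base, components[1]), components[0]] + [1]
# -> ['/data/libritts/1234/5678/0001.wav', 'the quick brown fox', 1]
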
@@ -40,6 +57,23 @@ def convert_string_list_to_tensor(strlist):
     return torch.tensor(as_ints)
 
 
+def load_tsv_aligned_codes_type(filename, type):
+    with open(filename, encoding='utf-8') as f:
+        filepaths_and_text = []
+        base = os.path.dirname(filename)
+        bad_lines = 0
+        for line in f:
+            components = line.strip().split('\t')
+            if len(components) < 3:
+                bad_lines += 1
+                if bad_lines > 1000:
+                    print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
+                    raise ValueError
+                continue
+            filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])] + [type])
+    return filepaths_and_text
+
+
 def load_tsv_aligned_codes(filename):
     with open(filename, encoding='utf-8') as f:
         filepaths_and_text = []
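
load_tsv_aligned_codes_type additionally parses a third tab column through convert_string_list_to_tensor; the on-disk format of that column isn't visible in this diff, so the sketch below fakes the tensor and only shows the entry layout:

import os
import torch

base = '/data/libritts'                                        # hypothetical
components = ['the quick brown fox', '1234/5678/0001.wav', '<codes column>']
codes = torch.tensor([12, 12, 0, 87])  # stand-in for convert_string_list_to_tensor(components[2])
entry = [os.path.join(base, components[1]), components[0], codes] + [1]
# -> [audio path, transcript, aligned code tensor, type]
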
@@ -57,15 +91,15 @@ def load_tsv_aligned_codes(filename):
     return filepaths_and_text
 
 
-def load_mozilla_cv(filename):
+def load_mozilla_cv(filename, type):
     with open(filename, encoding='utf-8') as f:
         components = [line.strip().split('\t') for line in f][1:] # First line is the header
         base = os.path.dirname(filename)
-        filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2]] for component in components]
+        filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2], type] for component in components]
     return filepaths_and_text
 
 
-def load_voxpopuli(filename):
+def load_voxpopuli(filename, type):
     with open(filename, encoding='utf-8') as f:
         lines = [line.strip().split('\t') for line in f][1:] # First line is the header
         base = os.path.dirname(filename)
@@ -75,7 +109,7 @@ def load_voxpopuli(filename):
             continue
         file, raw_text, norm_text, speaker_id, split, gender = line
         year = file[:4]
-        filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text])
+        filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text, type])
     return filepaths_and_text
 
 
@@ -92,6 +126,7 @@ class TextWavLoader(torch.utils.data.Dataset):
         self.path = hparams['path']
         if not isinstance(self.path, list):
             self.path = [self.path]
+        self.types = opt_get(hparams, ['types'], [0 for _ in self.path])
 
         fetcher_mode = opt_get(hparams, ['fetcher_mode'], 'lj')
         if not isinstance(fetcher_mode, list):
@@ -105,11 +140,11 @@ class TextWavLoader(torch.utils.data.Dataset):
         self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False)
         self.aligned_codes_to_audio_ratio = opt_get(hparams, ['aligned_codes_ratio'], 443)
         self.audiopaths_and_text = []
-        for p, fm in zip(self.path, fetcher_mode):
+        for p, fm, type in zip(self.path, fetcher_mode, self.types):
             if fm == 'lj' or fm == 'libritts':
-                fetcher_fn = load_filepaths_and_text
+                fetcher_fn = load_filepaths_and_text_type
             elif fm == 'tsv':
-                fetcher_fn = load_tsv_aligned_codes if self.load_aligned_codes else load_tsv
+                fetcher_fn = load_tsv_aligned_codes_type if self.load_aligned_codes else load_tsv_type
             elif fm == 'mozilla_cv':
                 assert not self.load_conditioning # Conditioning inputs are incompatible with mozilla_cv
                 fetcher_fn = load_mozilla_cv
@@ -118,7 +153,7 @@ class TextWavLoader(torch.utils.data.Dataset):
                 fetcher_fn = load_voxpopuli
             else:
                 raise NotImplementedError()
-            self.audiopaths_and_text.extend(fetcher_fn(p))
+            self.audiopaths_and_text.extend(fetcher_fn(p, type))
         self.text_cleaners = hparams.text_cleaners
         self.sample_rate = hparams.sample_rate
         random.seed(hparams.seed)
@@ -138,10 +173,10 @@ class TextWavLoader(torch.utils.data.Dataset):
 
     def get_wav_text_pair(self, audiopath_and_text):
         # separate filename and text
-        audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
+        audiopath, text, type = audiopath_and_text[0], audiopath_and_text[1], audiopath_and_text[2]
         text_seq = self.get_text(text)
         wav = load_audio(audiopath, self.sample_rate)
-        return (text_seq, wav, text, audiopath_and_text[0])
+        return (text_seq, wav, text, audiopath_and_text[0], type)
 
     def get_text(self, text):
         tokens = self.tokenizer.encode(text)
@@ -156,7 +191,7 @@ class TextWavLoader(torch.utils.data.Dataset):
     def __getitem__(self, index):
         self.skipped_items += 1
         try:
-            tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
+            tseq, wav, text, path, type = self.get_wav_text_pair(self.audiopaths_and_text[index])
             if text is None or len(text.strip()) == 0:
                 raise ValueError
             if wav is None or wav.shape[-1] < (.6 * self.sample_rate):
@@ -202,6 +237,7 @@ class TextWavLoader(torch.utils.data.Dataset):
             'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
             'filenames': path,
             'skipped_items': actually_skipped_items,
+            'type': type,
         }
         if self.load_conditioning:
             res['conditioning'] = cond
@@ -249,7 +285,7 @@ if __name__ == '__main__':
     batch_sz = 8
     params = {
         'mode': 'paired_voice_audio',
-        'path': ['Y:\\clips\\books1\\transcribed-w2v.tsv'],
+        'path': ['Y:\\clips\\books1\\transcribed-oco.tsv'],
         'fetcher_mode': ['tsv'],
         'phase': 'train',
         'n_workers': 0,
@@ -26,6 +26,14 @@ def load_wav_to_torch(full_path):
     return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
 
 
+def load_filepaths_and_text_type(filename, type, split="|"):
+    with open(filename, encoding='utf-8') as f:
+        filepaths_and_text = [list(line.strip().split(split)) + [type] for line in f]
+        base = os.path.dirname(filename)
+        for j in range(len(filepaths_and_text)):
+            filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])
+    return filepaths_and_text
+
 def load_filepaths_and_text(filename, split="|"):
     with open(filename, encoding='utf-8') as f:
         filepaths_and_text = [line.strip().split(split) for line in f]
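
load_filepaths_and_text_type applies the same idea to pipe-delimited metadata, where column 0 is the relative audio path; a sketch with an invented row:

import os

filename = '/data/ljspeech/metadata.csv'                 # hypothetical
line = 'wavs/LJ001-0001.wav|Printing, in the only sense'
entry = list(line.strip().split('|')) + [0]              # append the type id
entry[0] = os.path.join(os.path.dirname(filename), entry[0])
# -> ['/data/ljspeech/wavs/LJ001-0001.wav', 'Printing, in the only sense', 0]
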
@@ -238,10 +238,10 @@ class MelEncoder(nn.Module):
 
 class UnifiedVoice(nn.Module):
     def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
-                 mel_length_compression=1024, number_text_tokens=256,
-                 start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
+                 mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False):
+                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False,
+                 types=1):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -252,7 +252,6 @@ class UnifiedVoice(nn.Module):
             max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
             mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
             number_text_tokens:
-            start_text_token:
             stop_text_token:
             number_mel_codes:
             start_mel_token:
@@ -265,8 +264,8 @@ class UnifiedVoice(nn.Module):
         super().__init__()
 
         self.number_text_tokens = number_text_tokens
-        self.start_text_token = start_text_token
-        self.stop_text_token = stop_text_token
+        self.start_text_token = number_text_tokens * types
+        self.stop_text_token = 0
         self.number_mel_codes = number_mel_codes
         self.start_mel_token = start_mel_token
         self.stop_mel_token = stop_mel_token
@@ -279,7 +278,7 @@ class UnifiedVoice(nn.Module):
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
         self.average_conditioning_embeddings = average_conditioning_embeddings
-        self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
+        self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
         if use_mel_codes_as_input:
             self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
         else:
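
Worked through with the defaults (number_text_tokens=256, types=2): the embedding table and text head grow to 513 entries, the start token moves to the top id 512, and the stop token is pinned to 0 for every type:

number_text_tokens, types = 256, 2
vocab = number_text_tokens * types + 1           # 513 rows in text_embedding / text_head
start_text_token = number_text_tokens * types    # 512, the last id in the table
stop_text_token = 0                              # shared across all types
# The largest rescaled data token (id 255 at the highest type) stays in range:
assert (number_text_tokens - 1) * types < start_text_token
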
@@ -294,7 +293,7 @@ class UnifiedVoice(nn.Module):
             self.text_solo_embedding = 0
 
         self.final_norm = nn.LayerNorm(model_dim)
-        self.text_head = nn.Linear(model_dim, self.number_text_tokens)
+        self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
 
         # Initialize the embeddings per the GPT-2 scheme
@@ -366,7 +365,7 @@ class UnifiedVoice(nn.Module):
         else:
             return first_logits
 
-    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False,
+    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
                 return_latent=False, clip_inputs=True):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
@@ -383,6 +382,10 @@ class UnifiedVoice(nn.Module):
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
         If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
         """
+        # Types are expressed by expanding the text embedding space.
+        if types is not None:
+            text_inputs = text_inputs * (1+types).unsqueeze(-1)
+
         if clip_inputs:
             # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
             # chopping the inputs by the maximum actual length.
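
Concretely, the rescaling multiplies each sample's token ids by 1 + its type, so identical text lands on different rows of the enlarged embedding depending on type (a hypothetical trace):

import torch

types = torch.tensor([0, 1])                    # per-sample type ids
text_inputs = torch.tensor([[5, 10], [5, 10]])  # same tokens, two types
print(text_inputs * (1 + types).unsqueeze(-1))
# tensor([[ 5, 10],
#         [10, 20]])  -- type 0 keeps ids; type 1 doubles them
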
@@ -579,10 +582,11 @@ def register_unified_voice2(opt_net, opt):
 
 
 if __name__ == '__main__':
-    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True)
+    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True, types=2)
     l = gpt(torch.randn(2, 3, 80, 800),
             torch.randint(high=256, size=(2,120)),
             torch.tensor([32, 120]),
             torch.randint(high=8192, size=(2,250)),
-            torch.tensor([250*256,195*256]))
+            torch.tensor([250*256,195*256]),
+            types=torch.tensor([0, 1]))
     #gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))