experimental
This commit is contained in: parent 31cfef59c4 · commit 2e6a7625e4

data/text_tokenizer.json: 8696 lines added (new file; diff suppressed because it is too large)
@@ -16,17 +16,55 @@ from tokenizers.trainers import BpeTrainer
 from tokenizers.pre_tokenizers import Whitespace
 from tokenizers.processors import TemplateProcessing
 
-input_metadata = "training/data"
+from vall_e.config import cfg
+from vall_e.utils.io import json_read
+from vall_e.emb.g2p import coerce_to_hiragana
 
-output_file = Path("./training/tokenizer_training_data.json")
+input_metadata = "training/metadata/"
 
+output_file = Path("./training/tokenizer_pretraining_data.json")
 tokenizer_data = []
 
+def pad(num, zeroes):
+	return str(num).zfill(zeroes+1)
+
+def add( dir, type="training", audios=True, texts=True ):
+	name = str(dir)
+	name = name.replace(str(cfg.data_dir), "")
+	speaker_name = name
+	"""
+	if "LibriTTS-R" in speaker_name:
+		speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
+	"""
+
+	metadata_path = cfg.metadata_dir / f'{speaker_name}.json'
+	metadata = json_read( metadata_path, default={} )
+
+	for k, entry in metadata.items():
+		if "text" not in entry:
+			continue
+
+		language = entry.get('language','auto')
+		text = entry['text']
+		tokenizer_data.append( text )
+
 if output_file.exists():
 	tokenizer_data = json.loads(open(str(output_file), "r", encoding="utf-8").read())
 else:
+	# training
+	for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
+		try:
+			add( data_dir, type="training" )
+		except Exception as e:
+			pass
+
+	# validation
+	for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
+		try:
+			add( data_dir, type="validation" )
+		except Exception as e:
+			pass
+	"""
 	for dataset_name in os.listdir(f'./{input_metadata}/'):
 		if not os.path.isdir(f'./{input_metadata}/{dataset_name}/'):
 			continue

@@ -42,17 +80,18 @@ else:
 				metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/{id}')
 				metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
 
-				if "phonemes" not in metadata:
+				if "text" not in metadata:
 					continue
 
-				tokenizer_data.append( f'{"".join(metadata["phonemes"])}' )
+				tokenizer_data.append( f'{"".join(metadata["text"])}' )
 
 	open(output_file, 'w', encoding='utf-8').write(json.dumps(tokenizer_data))
+	"""
 
 unk_token = "<unk>"
 spl_tokens = [unk_token, "<bos>", "</eos>", "<mask>", "<space>"]
 
-trainer = BpeTrainer(special_tokens = spl_tokens, vocab_size = 256)
+trainer = BpeTrainer(special_tokens = spl_tokens, vocab_size = 32768, max_token_length=1, min_frequency=len(tokenizer_data))
 tokenizer = Tokenizer(BPE(unk_token = unk_token))
 tokenizer.pre_tokenizer = Whitespace() # takes 2 hours to process without this, we'll just manually add spaces as a token
 tokenizer.post_processor = TemplateProcessing(
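Note: with these trainer settings the result is effectively a character-level vocabulary, since max_token_length=1 prevents any merged token longer than one character and min_frequency=len(tokenizer_data) raises the merge bar further; vocab_size=32768 then acts only as an upper bound. A self-contained sketch of the same training flow, with a stand-in corpus and a hypothetical save path:

# Sketch only: the corpus and output path are hypothetical stand-ins.
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

corpus = ["hello world", "the quick brown fox"]  # stand-in for tokenizer_data
unk_token = "<unk>"
spl_tokens = [unk_token, "<bos>", "</eos>", "<mask>", "<space>"]

tokenizer = Tokenizer(BPE(unk_token=unk_token))
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(special_tokens=spl_tokens, vocab_size=32768, max_token_length=1, min_frequency=len(corpus))
tokenizer.train_from_iterator(corpus, trainer=trainer)
tokenizer.save("./text_tokenizer.json")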
@@ -267,7 +267,8 @@ class ModelExperimentalSettings:
 	ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs since thats what matters, as the inputs that do have loss calculated upon affects the loss for the entire sequence
 
 	noncausal_masks: bool = False # to correct an oversight with Llama always using causal masks......
-	classifiers_bias: bool = True # ugh
+	classifiers_bias: bool = True # base LLaMAs do not bias the output heads, but my existing weights do
+	max_position_embeddings: int = 70 * 65 * 5 # 5 minutes of audio
 
 	# classifier-free guidance training settings
 	cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
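Note: stock Hugging Face LlamaForCausalLM builds its output head as nn.Linear(hidden_size, vocab_size, bias=False), so classifiers_bias exists to stay loadable against checkpoints whose heads were trained with a bias. On the new default: 70 * 65 * 5 = 22750 positions, slightly above the 75 * 60 * 5 = 22500 previously hard-coded in the model constructors below. A tiny sketch of what the flag toggles, with hypothetical sizes:

import torch.nn as nn

hidden_size, vocab_size = 1024, 1025  # hypothetical sizes
stock_head = nn.Linear(hidden_size, vocab_size, bias=False)   # stock LLaMA-style head
biased_head = nn.Linear(hidden_size, vocab_size, bias=True)   # what the existing weights expect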
@@ -785,6 +786,9 @@ class Config(BaseConfig):
 
 	tokenizer: str | None = None # tokenizer class
 	tokenizer_path: str = "./tokenizer.json" # tokenizer path
 
+	text_tokenizer: str | None = None # tokenizer class
+	text_tokenizer_path: str = "./text_tokenizer.json" # tokenizer path
+
 	sample_rate: int = 24_000 # sample rate the model expects
 	audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac"
@@ -1053,6 +1057,21 @@ class Config(BaseConfig):
 				raise Exception(f'Tokenizer path not found: {tokenizer_path}')
 
 			self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
 
+		if self.tokenizer == "naive":
+			...
+		else:
+			from transformers import PreTrainedTokenizerFast
+
+			text_tokenizer_path = self.rel_path / self.text_tokenizer_path
+			# deduce path if a local copy is not provided
+			if not text_tokenizer_path.exists():
+				text_tokenizer_path = Path("./data/") / self.text_tokenizer_path
+
+			if not self.silent_errors and not text_tokenizer_path.exists():
+				raise Exception(f'Tokenizer path not found: {text_tokenizer_path}')
+
+			self.text_tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(text_tokenizer_path))
 
 		# Preserves the old behavior
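Note: assuming the bundled data/text_tokenizer.json added by this commit, the fallback branch above reduces to roughly this sketch:

from pathlib import Path
from transformers import PreTrainedTokenizerFast

text_tokenizer_path = Path("./data/text_tokenizer.json")  # the deduced fallback location
text_tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(text_tokenizer_path))
print(text_tokenizer.encode("hello world"))  # list of token ids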
@@ -642,6 +642,11 @@ def tokenize( phones ):
 		phones = "".join( phones )
 	return cfg.tokenizer.encode( phones )
 
+def text_tokenize( text ):
+	if isinstance( text, list ):
+		text = "".join( text )
+	return cfg.text_tokenizer.encode( text )
+
 def get_lang_symmap():
 	return {
 		"en": 0,
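Note: a hypothetical round-trip through the new helper, assuming cfg.text_tokenizer was loaded as above (the module path is assumed; tokenize() and the Dataset below appear to live in the same file):

from vall_e.config import cfg
from vall_e.data import text_tokenize  # assumed import path

ids = text_tokenize("hello world")     # list[int] of raw-text token ids
text = cfg.text_tokenizer.decode(ids)  # back to a string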
@@ -677,6 +682,9 @@ def get_task_symmap():
 		"len": 0, # fake
 		"nse": 6, # fake
 		"cse": 6, # fake
+
+		"phn": 0, # fake
+		"un-phn": 0, # fake
 	}
 
 def _replace_file_extension(path, suffix):
@@ -1299,6 +1307,8 @@ class Dataset(_Dataset):
 		text_string = metadata["text"] if "text" in metadata else None
 
 		lang = self.get_language(spkr_group) if not lang else lang.lower()
+
+		raw_text = torch.tensor(text_tokenize(text_string)).to(torch.int16) if text_string else None
 
 		if not tone:
 			tone = "neutral"
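Note: the int16 storage works out exactly: the tokenizer is trained with vocab_size = 32768, so ids fall in 0..32767, the positive range of int16. A sketch with stand-in ids:

import torch

ids = [5, 102, 32767]  # hypothetical output of text_tokenize()
raw_text = torch.tensor(ids).to(torch.int16)
assert raw_text.max().item() <= torch.iinfo(torch.int16).max  # 32767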
@@ -1388,6 +1398,8 @@ class Dataset(_Dataset):
 		elif task == "len":
 			proms = self.sample_prompts(spkr_name, reference=path)
 
+		elif task in ["phn", "un-phn"]:
+			proms = []
 		# noise suppression (<text>? <resp+noise> => <resp>)
 		# speech removal (<text>?<resp+noise> => <noise>)
 		elif task == "ns" or task == "sr":
@@ -1532,6 +1544,7 @@ class Dataset(_Dataset):
 			text=text,
 			proms=proms,
 			resps=resps,
+			raw_text=raw_text,
 
 			metadata=metadata,
 		)
@@ -473,7 +473,7 @@ class TTS():
 			)
 		if model_len is not None:
 			# skip calculating len_list if possible
-			if task in ["ns, sr"]:
+			if task in ["ns", "sr"]:
 				len_list = [ prom[1].shape[0] ]
 			elif vc_utterance is not None:
 				len_list = [ vc_utterance.shape[0] ]
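Note: this one-character fix matters more than it looks: ["ns, sr"] is a single-element list holding the string "ns, sr", so neither task ever matched and the branch was dead. A quick demonstration:

task = "ns"
print(task in ["ns, sr"])    # False: the list contains only the string "ns, sr"
print(task in ["ns", "sr"])  # True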
@@ -48,6 +48,7 @@ class AR_NAR(Base):
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
+		raw_text_list: list[Tensor] | None = None,
 	):
 		# deduce batch_size
 		if text_list is not None:
@@ -194,6 +195,7 @@ class AR_NAR(Base):
 			lang_list=lang_list,
 			tone_list=tone_list,
 			task_list=task_list,
+			raw_text_list=raw_text_list,
 			time_list=timesteps,
 
 			quant_levels=quant_levels,
@@ -822,6 +824,7 @@ class AR_NAR(Base):
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
+		raw_text_list: list[Tensor] | None = None,
 
 		training: bool | None = None,
@@ -860,6 +863,7 @@ class AR_NAR(Base):
 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
+				raw_text_list=raw_text_list,
 			)
 
 		# is NAR
@@ -47,12 +47,14 @@ from ..utils.pattern import DelayedPatternProvider, VALLEPattern
 """
 
 summed_embeddings_task = [ "stt" ]
-special_tasks = [ "len", "stt" ]
+special_tasks = [ "len", "stt", "phn", "un-phn" ]
 non_tokened_names = ["task", "dropout_mask", "classifier_level"]
 task_outputs = {
 	"tts": "resp",
 	"stt": "text",
 	"len": "len",
+	"phn": "text",
+	"un-phn": "raw_text",
 }
 
 # yuck
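Note: reading the table, "phn" is scored against the phoneme stream (the "text" output) and "un-phn" against plain text ("raw_text"); a lookup sketch:

task_outputs = { "tts": "resp", "stt": "text", "len": "len", "phn": "text", "un-phn": "raw_text" }
print(task_outputs["phn"])     # "text": phonemes are the predicted side
print(task_outputs["un-phn"])  # "raw_text": plain text is the predicted side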
@@ -187,7 +189,7 @@ class MultiEmbedding(nn.Module):
 class AudioEmbedding_Old(nn.Module):
 	def __init__(
 		self,
-		l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
+		l_embedding_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
 		token_dim: int, # dimensionality of the embedding
 		levels: int | None = None, # number of RVQ-bins (I don't remember the specifics)
 	):
@@ -195,7 +197,7 @@ class AudioEmbedding_Old(nn.Module):
 		# array of embeddings
 		#	proms are [0, resp_levels]
 		#	resp are split to where [0] is for the AR, and [1:] are reserved for NAR
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
+		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
 		# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
 		self.weight = nn.ParameterList([nn.Parameter( torch.tensor([1]) ) for i in range(levels)]) if levels is not None else None
@@ -217,20 +219,20 @@ class AudioEmbedding_Old(nn.Module):
 class AudioEmbedding(nn.Module):
 	def __init__(
 		self,
-		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
+		l_embedding_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
 		token_dim: int, # dimensionality of the embedding
 		sums: bool = True, # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
-		l_names: list[str] = [], # names to map to indices
+		l_embedding_names: list[str] = [], # names to map to indices
 	):
 		super().__init__()
 		# array of embeddings
 		#	proms are [0, resp_levels]
 		#	resp are split to where [0] is for the AR, and [1:] are reserved for NAR
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
+		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
 		# further experimentation is needed to see if this actually is useful
 		self.sums = sums
 		#
-		self.names = l_names
+		self.names = l_embedding_names
 
 	def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, name: str | None = None, sums = None ) -> Tensor:
 		if sums is None:
@@ -278,14 +280,14 @@ class TimeEmbedding(nn.Module):
 class Classifiers(nn.Module):
 	def __init__(
 		self,
-		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
+		l_embedding_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
 		token_dim: int, # dimensionality of the embedding
-		l_names: list[str] | None = None, # list of names to map to each classifier,
+		l_embedding_names: list[str] | None = None, # list of names to map to each classifier,
 		bias: bool = True,
 	):
 		super().__init__()
-		self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens, bias=bias) for n_tokens in l_tokens])
-		self.names = l_names
+		self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens, bias=bias) for n_tokens in l_embedding_tokens])
+		self.names = l_embedding_names
 
 	def indices(
 		self,
@@ -326,7 +328,7 @@ class Classifiers(nn.Module):
 class Metrics(nn.Module):
 	def __init__(
 		self,
-		l_tokens: int | list[int],
+		l_embedding_tokens: int | list[int],
 		top_k = 10,
 		average="micro",
 		multidim_average="global",
@@ -339,14 +341,14 @@ class Metrics(nn.Module):
 			average=average,
 			multidim_average=multidim_average,
 			ignore_index=ignore_index,
-		) for n_tokens in l_tokens ])
+		) for n_tokens in l_embedding_tokens ])
 		self.precision = nn.ModuleList([ MulticlassPrecision(
 			n_tokens,
 			top_k=top_k,
 			average=average,
 			multidim_average=multidim_average,
 			ignore_index=ignore_index,
-		) for n_tokens in l_tokens ])
+		) for n_tokens in l_embedding_tokens ])
 
 	def calc_accuracy( self, inputs, targets, classifier_levels ):
 		return sum( [ self.accuracy[l]( input[:, :self.accuracy[l].num_classes], target ) for target, input, l in zip( targets, inputs, classifier_levels ) ] ) / len( inputs )
@@ -413,6 +415,7 @@ class Base(nn.Module):
 
 		n_text_tokens: int = 256,
 		n_audio_tokens: int = 1024,
+		n_raw_text_tokens: int = 8575,
 
 		d_model: int = 512,
 		n_heads: int = 8,
@@ -434,6 +437,7 @@ class Base(nn.Module):
 
 		self.n_text_tokens = n_text_tokens
 		self.n_audio_tokens = n_audio_tokens
+		self.n_raw_text_tokens = n_raw_text_tokens
 
 		self.d_model = d_model
 		self.n_heads = n_heads
@@ -477,6 +481,7 @@ class Base(nn.Module):
 		interleave = self.config.experimental.interleave if self.config is not None else False
 		noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
 		classifiers_bias = self.config.experimental.classifiers_bias if self.config is not None else False
+		max_position_embeddings = self.config.experimental.max_position_embeddings if self.config is not None else (75 * 60 * 5)
 
 		masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
 		ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
@@ -493,35 +498,44 @@ class Base(nn.Module):
 		# pure AR
 		if "nar" not in self.capabilities:
 			n_resp_tokens = n_audio_tokens + 1
-			l_tokens = [n_resp_tokens] * self.n_resp_levels
-			resp_l_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
-			classifier_l_tokens = [n_resp_tokens] * self.n_resp_levels
+			l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
+			l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
+			l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels
 		# NAR-len model
 		elif "len" in self.capabilities:
 			# +1 to include the stop or mask token
 			n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
 			if "ar" in self.capabilities:
-				l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
-				classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
-				resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
+				l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
+				l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
+				l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
 			else:
-				l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
-				classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
-				resp_l_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
+				l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+				l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+				l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
 		# AR+NAR model
 		else:
 			# +1 to include the stop or mask token
 			n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
-			l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
-			resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
-			classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+			l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+			l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
+			l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
 
+		l_classifier_names = l_embedding_names
 
-		classifier_l_tokens += [ n_text_tokens ]
-		classifier_l_names = resp_l_names + [ "stt" ]
+		# STT
+		l_classifier_names += [ "stt" ]
+		l_classifier_tokens += [ n_text_tokens ]
 
+		# LEN
 		if "len" in self.capabilities:
-			classifier_l_tokens += [ 11 ]
-			classifier_l_names += ["len"]
+			l_classifier_tokens += [ 11 ]
+			l_classifier_names += ["len"]
 
+		# TEXT => PHN / PHN => TEXT
+		if self.version >= 6:
+			l_classifier_tokens += [ n_raw_text_tokens ]
+			l_classifier_names = l_embedding_names + [ "raw_text" ]
 
 		n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1
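Note: to make the renamed bookkeeping concrete, a hand-worked example for a hypothetical AR+NAR model with n_resp_levels=2, n_audio_tokens=1024, n_text_tokens=256, n_raw_text_tokens=8575, no "len" capability, and version >= 6:

n_resp_tokens = 1024 + 1                   # +1 for the stop token
l_embedding_tokens = [1025, 1024]          # AR level keeps the stop token, NAR does not
l_embedding_names = ['AR:0:0', 'NAR:0:1']
l_classifier_tokens = [1025, 1024]
l_classifier_names = list(l_embedding_names)
l_classifier_names += ["stt"]; l_classifier_tokens += [256]
l_classifier_tokens += [8575]              # version >= 6 adds the raw-text head
l_classifier_names = l_embedding_names + ["raw_text"]

Note that the final assignment replaces the list that already held "stt" (and "len", when present), leaving l_classifier_names one entry shorter than l_classifier_tokens; that looks like an oversight worth flagging.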
@@ -541,6 +555,7 @@ class Base(nn.Module):
 		"""
 
 		self.text_emb = Embedding(n_text_tokens, d_model)
+		self.raw_text_emb = None
 		self.langs_emb = None
 		self.tones_emb = None
 		self.tasks_emb = None
@@ -563,7 +578,7 @@ class Base(nn.Module):
 			)
 			# [1024 + STOP] + [1024] * 8
 			self.resps_emb = AudioEmbedding_Old(
-				l_tokens, d_model,
+				l_embedding_tokens, d_model,
 				levels=self.n_resp_levels if self.version > 3 else None,
 			)
 		else:
@@ -572,9 +587,9 @@ class Base(nn.Module):
 				sums=audio_embedding_sums == "prom" or audio_embedding_sums == True,
 			)
 			self.resps_emb = AudioEmbedding(
-				l_tokens, d_model,
+				l_embedding_tokens, d_model,
 				sums=audio_embedding_sums == "resp" or audio_embedding_sums == True,
-				l_names=resp_l_names,
+				l_embedding_names=l_embedding_names,
 			)
 
 		if self.version >= 3:
@@ -597,6 +612,9 @@ class Base(nn.Module):
 			self.len_emb = Embedding(11, d_model)
 			self.time_emb = None # TimeEmbedding(d_model) # if not masking_ratio else None
 
+		if self.version >= 6:
+			self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
+
 		if attention_backend == "auto":
 			attention_backend = "sdpa"
 			"""
@@ -635,7 +653,7 @@ class Base(nn.Module):
 				self.model = MistralModel(MistralConfig(
 					vocab_size=n_vocab,
 					hidden_size=d_model,
-					max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
+					max_position_embeddings=max_position_embeddings,
 					intermediate_size=d_model*4,
 					num_hidden_layers=n_layers,
 					num_attention_heads=n_heads,
@@ -651,7 +669,7 @@ class Base(nn.Module):
 				self.model = MixtralModel(MixtralConfig(
 					vocab_size =n_resp_tokens,
 					hidden_size=d_model,
-					max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
+					max_position_embeddings=max_position_embeddings,
 					intermediate_size=d_model*4,
 					num_hidden_layers=n_layers,
 					num_attention_heads=n_heads,
@@ -681,7 +699,7 @@ class Base(nn.Module):
 				config = LlamaConfig(
 					vocab_size=n_vocab,
 					hidden_size=d_model,
-					max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
+					max_position_embeddings=max_position_embeddings,
 					intermediate_size=d_model*4,
 					num_hidden_layers=n_layers,
 					num_attention_heads=n_heads,
@@ -703,7 +721,7 @@ class Base(nn.Module):
 				self.model = MixtralModel(MixtralConfig(
 					vocab_size =n_resp_tokens,
 					hidden_size=d_model,
-					max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
+					max_position_embeddings=max_position_embeddings,
 					intermediate_size=d_model*4,
 					num_hidden_layers=n_layers,
 					num_attention_heads=n_heads,
@@ -800,8 +818,8 @@ class Base(nn.Module):
 			self.metrics = None
 		else:
 			self.classifier = None
-			self.classifiers = Classifiers( classifier_l_tokens, d_model, l_names=classifier_l_names, bias=classifiers_bias )
-			self.metrics = Metrics( classifier_l_tokens )
+			self.classifiers = Classifiers( l_classifier_tokens, d_model, l_embedding_names=l_classifier_names, bias=classifiers_bias )
+			self.metrics = Metrics( l_classifier_tokens )
 
 		"""
 		if tie_classifier_to_embedding:
@@ -928,6 +946,7 @@ class Base(nn.Module):
 		len_list: list[Tensor] | None = None,
 		task_list: list[str] | None = None,
 		time_list: list[Tensor] | None = None,
+		raw_text_list: list[Tensor] | None = None,
 
 		quant_levels: int | list[int] | Tensor | None = None
 	):
@@ -1042,6 +1061,34 @@ class Base(nn.Module):
 				inputs[i].append( ( "text", text_list[i] ) )
 
 				inputs[i].append( ("classifier_level", "stt") )
+			# Text phonemizing task
+			# Sequence: <raw_text><sep><lang><sep><phonemes>
+			elif task_type == "phn":
+				# insert the text prompt
+				if raw_text_list is not None and raw_text_list[i] is not None:
+					inputs[i].append( ( "raw_text", raw_text_list[i] ) )
+				# insert lang token if we're trained for it
+				if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
+					inputs[i].append( ( "lang", lang_list[i] ) )
+				# insert the text prompt
+				if text_list is not None and text_list[i] is not None:
+					inputs[i].append( ( "text", text_list[i] ) )
+
+				inputs[i].append( ("classifier_level", "stt") )
+			# Text de-phonemizing task
+			# Sequence: <raw_text><sep><lang><sep><phonemes>
+			elif task_type == "un-phn":
+				# insert the text prompt
+				if text_list is not None and text_list[i] is not None:
+					inputs[i].append( ( "text", text_list[i] ) )
+				# insert lang token if we're trained for it
+				if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
+					inputs[i].append( ( "lang", lang_list[i] ) )
+				# insert the text prompt
+				if raw_text_list is not None and raw_text_list[i] is not None:
+					inputs[i].append( ( "raw_text", raw_text_list[i] ) )
+
+				inputs[i].append( ("classifier_level", "raw_text") )
 			else:
 				raise Exception(f'Unrecognized task: {task_type}')
 		return inputs
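Note: for one hypothetical batch item with task "phn", the per-item plan built above comes out roughly as follows (tensor contents are stand-ins):

import torch

raw_text = torch.tensor([5, 9, 12], dtype=torch.int16)  # tokenized plain text
lang = torch.tensor([0])                                # language id
text = torch.tensor([3, 7, 7, 1])                       # phoneme token ids

entry = [
	("raw_text", raw_text),       # <raw_text>
	("lang", lang),               # only if "lang" is in capabilities
	("text", text),               # <phonemes>, the predicted half
	("classifier_level", "stt"),  # phn reuses the text ("stt") head
]

"un-phn" mirrors this with text first, raw_text last, and classifier_level "raw_text".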
@@ -1149,6 +1196,10 @@ class Base(nn.Module):
 			elif name == "text":
 				embedding = self.text_emb( input )
 
 				device = embedding.device
+			elif name == "raw_text":
+				embedding = self.raw_text_emb( input )
+
+				device = embedding.device
 			elif name == "quant_level" and self.rvq_l_emb is not None:
 				embedding = self.rvq_l_emb( input )
@@ -1628,7 +1679,7 @@ class Base(nn.Module):
 		# needs to be done here as we still have our raw inputs
 		position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
 		classifier_levels = self.get_input( inputs, name="classifier_level" )
-		casual_levels = [ "AR:0:0", "stt", "len" ]
+		casual_levels = [ "AR:0:0", "stt", "len", "phn" ]
 
 		# right now limit to new versions because I need to retrain the model for noncausal masks...
 		is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
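Note: a sketch of how the per-item causality comes out once "phn" joins the causal set; the batch below is hypothetical:

casual_levels = [ "AR:0:0", "stt", "len", "phn" ]
classifier_levels = [ "AR:0:0", "NAR:0:1", "phn", "raw_text" ]  # hypothetical batch
noncausal_masks = True

is_causal = [ l in casual_levels for l in classifier_levels ] if noncausal_masks else [ True for l in classifier_levels ]
print(is_causal)  # [True, False, True, False]

"raw_text" (the un-phn output level) is notably absent from the causal list; this diff alone does not say whether that is deliberate.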
@@ -38,6 +38,7 @@ def train_feeder(engine, batch, teacher=None):
 		lang_list=batch["lang"],
 		tone_list=batch["tone"],
 		task_list=batch["task"],
+		raw_text_list=batch["raw_text"],
 
 		training=True,
 	)