commit db6cf4c969
parent 2e6a7625e4

    experimental
@@ -249,11 +249,6 @@ class ModelExperimentalSettings:
 	unified_position_ids: bool = True # False will generate position IDs partitioned for each section
 	tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings; this did not seem to do anything good in testing
 
-	# performs token dropout to compensate for errors
-	token_dropout_error: float = 0.0 # probability to nudge a token by ±1
-	token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
-	token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
-
 	causal_size: int = 1 # experimental setting to see if I can just do parallel decoding in chunks instead of one-at-a-time without resorting to exotic solutions
 	# VALL-E 2's approach of "combining token embeddings to group them" sounds terribad for a shared AR/NAR model
 	# however, introducing partial parallel decoding for the AR might, just MAYBE, help unify the AR/NAR tasks better
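`causal_size` feeds the AR loop's step count further down in this commit (`max_duration // max(1, self.causal_size)`). A rough, minimal sketch of the idea, assuming a hypothetical `model` whose last `causal_size` logits predict the next `causal_size` tokens; this is not this repo's actual interface:

```python
import torch

def decode_chunked(model, prompt: torch.Tensor, causal_size: int = 4,
                   max_len: int = 256, stop_token: int = 0) -> torch.Tensor:
	# Commit `causal_size` tokens per forward pass instead of one-at-a-time.
	seq = prompt.clone()
	while seq.shape[0] < max_len:
		logits = model(seq.unsqueeze(0))[0]           # [T, vocab]; hypothetical signature
		chunk = logits[-causal_size:].argmax(dim=-1)  # greedy-pick a whole chunk at once
		seq = torch.cat([seq, chunk])
		if (chunk == stop_token).any():               # a stop token anywhere in the chunk ends decoding
			break
	return seq
```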
@@ -270,11 +265,19 @@ class ModelExperimentalSettings:
 	classifiers_bias: bool = True # base LLaMAs do not bias the output heads, but my existing weights do
 	max_position_embeddings: int = 70 * 65 * 5 # 5 minutes of audio
 
+	# these technically should be hyperparameters
+	# performs token dropout to compensate for errors
+	token_dropout_error: float = 0.0 # probability to nudge a token by ±1
+	token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
+	token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
+	# these technically should be hyperparameters
 	# classifier-free guidance training settings
 	cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
 	cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input text during training
 	cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out input audio prompt during training
 
+	use_raw_text_p: float = 0.0 # probability to use raw text as the input prompt instead
+
 	# failed experiment
 	layerskip: bool = False # layerskip compatible model (or training for)
 	#layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
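For context on the relocated knobs, a sketch of how `token_dropout_error` / `token_dropout_rate` / `token_dropout_rvq_levels` could act on a `[T, n_rvq_levels]` code tensor during training. The helper is illustrative rather than this repo's implementation, and it treats `[1, 8]` as a half-open level range:

```python
import torch

def apply_token_dropout(codes: torch.Tensor, error_p: float, rate_p: float,
                        rvq_levels: list, dropout_token: int) -> torch.Tensor:
	codes = codes.clone()
	for level in range(rvq_levels[0], rvq_levels[1]):     # default [1,8] leaves RVQ level 0 untouched
		toks = codes[:, level]
		nudge = torch.rand(toks.shape) < error_p          # token_dropout_error: nudge a token by ±1
		sign = (torch.randint(0, 2, (int(nudge.sum()),)) * 2 - 1).to(toks.dtype)
		toks[nudge] = toks[nudge] + sign
		drop = torch.rand(toks.shape) < rate_p            # token_dropout_rate: swap in the dropout value
		toks[drop] = dropout_token
		codes[:, level] = toks
	return codes
```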
@@ -23,7 +23,7 @@ from .config import cfg, Config
 from .models import get_models
 from .models.lora import enable_lora
 from .engines import load_engines, deepspeed_available
-from .data import get_phone_symmap, get_lang_symmap, tokenize, sentence_split
+from .data import get_phone_symmap, get_lang_symmap, tokenize, text_tokenize, sentence_split
 from .models import download_model, DEFAULT_MODEL_PATH
 
 if deepspeed_available:
@@ -412,7 +412,7 @@ class TTS():
 				model = model_ar if model_ar is not None else model_nar
 				if model is not None:
 					text_list = model(
-						text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"],
+						text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=[task],
 						disable_tqdm=not use_tqdm,
 						use_lora=use_lora,
 						**sampling_kwargs,
@@ -423,6 +423,35 @@ class TTS():
 				text_list = [ cfg.tokenizer.decode( text ).replace("   ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
 
 				return text_list[0]
+		elif task in ["phn", "un-phn"]:
+			lang = self.encode_lang( language )
+			lang = to_device(lang, device=self.device, dtype=torch.uint8)
+
+			with torch.autocast(self.device, dtype=dtype, enabled=amp):
+				model = model_ar if model_ar is not None else model_nar
+				if task == "phn":
+					text_list = None
+					raw_text_list = [ torch.tensor( text_tokenize( text ), device=self.device, dtype=torch.int16) ]
+					output_tokenizer = cfg.tokenizer
+				else:
+					text_list = [ torch.tensor( tokenize( text ), device=self.device, dtype=torch.int16) ]
+					raw_text_list = None
+					output_tokenizer = cfg.text_tokenizer
+
+				if model is not None:
+					text_list = model(
+						text_list=text_list, raw_text_list=raw_text_list, lang_list=[lang], task_list=[task],
+						disable_tqdm=not use_tqdm,
+						use_lora=use_lora,
+						**sampling_kwargs,
+					)
+				else:
+					raise Exception("!")
+
+				text_list = [ output_tokenizer.decode( text ).replace("   ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
+
+				return text_list[0]
+
 
 		# stuff for rolling context
 		prefix_context = None
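The new branch is symmetric between the two tokenizers: `phn` encodes raw text in and decodes phonemes out, `un-phn` the reverse. A standalone sketch of just that routing, with stand-in tokenizer objects in place of `cfg.text_tokenizer` and `cfg.tokenizer`:

```python
def route_tokenizers(task: str, text: str, text_tokenizer, phoneme_tokenizer):
	# "phn": raw text in -> phoneme tokens out; "un-phn": the reverse.
	if task == "phn":
		input_ids = text_tokenizer.encode(text)     # cf. text_tokenize() in the diff
		output_tokenizer = phoneme_tokenizer        # cf. cfg.tokenizer
	else:
		input_ids = phoneme_tokenizer.encode(text)  # cf. tokenize()
		output_tokenizer = text_tokenizer           # cf. cfg.text_tokenizer
	return input_ids, output_tokenizer
```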
@@ -29,7 +29,7 @@ from ..utils import get_devices, setup_logging, timer, clamp, convert_kwargs
 from .lora import enable_lora
 from ..samplers import cfg_logits
 
-text_task = [ "stt" ]
+text_task = [ "stt", "phn", "un-phn" ]
 
 class AR_NAR(Base):
 	# yikes
@@ -40,23 +40,28 @@ class AR_NAR(Base):
 	# a lot of this could be delegated back to the dataloader, but it's just easier to keep the dataloader's job to providing sufficient data, and the model's to processing that data for training
 	def forward_train(
 		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
-		resps_list: list[Tensor],
 
 		task_list: list[Tensor] | None = None,
 
+		text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
+		resps_list: list[Tensor] | None = None,
+
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
 		raw_text_list: list[Tensor] | None = None,
 	):
 		# deduce batch_size
-		if text_list is not None:
-			default_task = "tts"
+		if text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		else:
-			default_task = "stt"
+		elif raw_text_list:
+			device = raw_text_list[0].device
+			batch_size = len(raw_text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
+		elif resps_list:
 			device = resps_list[0].device
 			batch_size = len(resps_list)
 
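Note the guard switched from `is not None` to plain truthiness, so an empty list also falls through to the next candidate input, and `default_task` is gone because `task_list` now always arrives from the caller. The fall-through in isolation:

```python
def deduce_batch_size(*candidates):
	# None and [] are both falsy, so either one falls through to the next candidate.
	for xs in candidates:
		if xs:
			return len(xs)
	raise ValueError("cannot deduce batch_size: every input list was None or empty")

assert deduce_batch_size(None, [], ["prom_0", "prom_1"]) == 2
```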
@@ -161,10 +166,7 @@ class AR_NAR(Base):
 				# only apply stop token for RVQ level 0
 				if quant_level <= 0 and timesteps[i] is None:
 					# append stop tokens for AR
-					if task in text_task:
-						#text_list[i] = torch.cat([ resps, text_stop_sequence ])
-						...
-					else:
+					if task not in text_task:
 						resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
 
 			if task == "len":
@@ -174,6 +176,7 @@ class AR_NAR(Base):
 			if task not in text_task + ["len"]:
 				drop_text = False
 				drop_audio = False
+				swap_text = False
 
 				if random.random() < cfg_prom_dropout_p:
 					drop_audio = True
@@ -182,12 +185,18 @@ class AR_NAR(Base):
 					drop_audio = True
 					drop_text = True
 
+				if random.random() < use_raw_text_p and raw_text_list[i] is not None:
+					swap_text = True
+
 				if drop_text:
 					text_list[i] = text_start_stop_sequence
 
 				if drop_audio:
 					proms_list[i] = None
+
+				if swap_text and not drop_text:
+					text_list[i] = None
 
 		inputs = self.inputs(
 			text_list=text_list,
 			proms_list=proms_list,
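Putting this hunk together, the per-sample training dropout now has three outcomes: unconditional text, unconditional audio prompt, or a swap from phoneme text to raw text. A condensed, illustrative version, not the repo's code:

```python
import random

def condition_dropout(i, text_list, raw_text_list, proms_list,
                      cfg_text_dropout_p=0.0, cfg_prom_dropout_p=0.0, use_raw_text_p=0.0,
                      text_start_stop_sequence="<bos><eos>"):
	drop_text = random.random() < cfg_text_dropout_p
	drop_audio = random.random() < cfg_prom_dropout_p
	swap_text = random.random() < use_raw_text_p and raw_text_list[i] is not None

	if drop_text:
		text_list[i] = text_start_stop_sequence  # train the unconditional-text path
	if drop_audio:
		proms_list[i] = None                     # train the unconditional-audio path
	if swap_text and not drop_text:
		text_list[i] = None                      # let raw_text_list[i] drive the prompt instead
```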
@@ -209,14 +218,16 @@ class AR_NAR(Base):
 	def forward_nar_masked(
 		self,
 
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
+		task_list: list[Tensor] | None = None,
+
+		text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
 		resps_list: list[Tensor] | None = None,
 
-		task_list: list[Tensor] | None = None,
-
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
+		raw_text_list: list[Tensor] | None = None,
+
 		disable_tqdm=False,
 		use_lora=None,
@@ -420,15 +431,18 @@ class AR_NAR(Base):
 
 	def forward_nar(
 		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
+		task_list: list[Tensor] | None = None,
+
+		text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
 		resps_list: list[Tensor] | None = None,
 
-		task_list: list[Tensor] | None = None,
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
 
+		raw_text_list: list[Tensor] | None = None,
+
 		disable_tqdm=False,
 		use_lora=None,
 		**sampling_kwargs,
@@ -447,12 +461,16 @@ class AR_NAR(Base):
 		)
 
 		# deduce batch_size
-		if text_list is not None:
-			default_task = "tts"
+		if text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		else:
-			default_task = "stt"
+		elif raw_text_list:
+			device = raw_text_list[0].device
+			batch_size = len(raw_text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
+		elif resps_list:
 			device = resps_list[0].device
 			batch_size = len(resps_list)
 
@@ -534,25 +552,31 @@ class AR_NAR(Base):
 	def forward_ar(
 		self,
 
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
-		resps_list: list[Tensor] | None = None,
+		task_list: list[Tensor],
 
-		task_list: list[Tensor] | None = None,
+		text_list: list[Tensor] | None = None,
+		raw_text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
+		resps_list: list[Tensor] | None = None,
+
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
 
 		disable_tqdm=False,
 		use_lora=None,
 		**sampling_kwargs,
 	):
 		# deduce batch_size
-		if text_list is not None:
-			default_task = "tts"
+		if text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		else:
-			default_task = "stt"
+		elif raw_text_list:
+			device = raw_text_list[0].device
+			batch_size = len(raw_text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
+		elif resps_list:
 			device = resps_list[0].device
 			batch_size = len(resps_list)
@@ -590,13 +614,17 @@ class AR_NAR(Base):
 			len_list = sequence_list
 
 		inputs = self.inputs(
+			task_list=task_list,
+
 			text_list=text_list,
 			proms_list=proms_list,
 			resps_list=resps_list,
+
 			lang_list=lang_list,
 			tone_list=tone_list,
 			len_list=len_list,
-			task_list=task_list,
+			raw_text_list=raw_text_list,
+
 			quant_levels=quant_levels,
 		)
 
@@ -627,7 +655,6 @@ class AR_NAR(Base):
 			# convert tokens into int
 			return [ int("".join([ str(token.item()) for token in r if token != stop_token ])) for r in sequence_list ]
 
-		# STT
 		start_slice = [ 0 for _ in range(batch_size) ]
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
 		stopped = torch.zeros(batch_size, device=device).bool()
@@ -684,19 +711,29 @@ class AR_NAR(Base):
 		# get next in sequence
 		iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm)
 		for n in iterator:
-			# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
-			text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
-			resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
+			if batch_size == 1 and task_list[0] in ["phn", "un-phn"]:
+				text_list = [ sequence_list[i] if task in ["phn"] else text_list[i] for i, task in enumerate(task_list) ]
+				raw_text_list = [ sequence_list[i] if task in ["un-phn"] else raw_text_list[i] for i, task in enumerate(task_list) ]
+			else:
+				text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
+				resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
+
+			print( text_list, raw_text_list )
+
 			quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
 
 			inputs = self.inputs(
+				task_list=task_list,
+
 				text_list=text_list,
 				proms_list=proms_list,
 				resps_list=resps_list,
 
 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
-				task_list=task_list,
+				raw_text_list=raw_text_list,
+
 				quant_levels=quant_levels,
 			)
 
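Each AR step writes the partially decoded sequence back into whichever stream the task is generating: phoneme text for `phn`, raw text for `un-phn`, text for `stt`, audio codes otherwise (the diff gates the first two on `batch_size == 1`). Condensed into a standalone helper:

```python
def route_sequences(task_list, sequence_list, text_list, raw_text_list, resps_list,
                    text_task=("stt", "phn", "un-phn")):
	for i, task in enumerate(task_list):
		if task == "phn":
			text_list[i] = sequence_list[i]       # generating phoneme tokens
		elif task == "un-phn":
			raw_text_list[i] = sequence_list[i]   # generating raw-text tokens
		elif task in text_task:
			text_list[i] = sequence_list[i]       # stt also decodes into text
		else:
			resps_list[i] = sequence_list[i]      # tts decodes into audio codes
	return text_list, raw_text_list, resps_list
```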
@@ -816,11 +853,12 @@ class AR_NAR(Base):
 
 	def forward(
 		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
+		task_list: list[Tensor] | None = None,
+
+		text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
 		resps_list: list[Tensor] | None = None,
 
-		task_list: list[Tensor] | None = None,
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
@@ -833,19 +871,20 @@ class AR_NAR(Base):
 		**sampling_kwargs,
 	):
 		# deduce batch_size
-		if text_list is not None:
-			default_task = "tts"
+		# deduce batch_size
+		if text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		else:
-			default_task = "stt"
+		elif raw_text_list:
+			device = raw_text_list[0].device
+			batch_size = len(raw_text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
+		elif resps_list:
 			device = resps_list[0].device
 			batch_size = len(resps_list)
 
-		# generate task list if not provided
-		if task_list is None:
-			task_list = [ default_task for _ in range(batch_size) ]
-
 		# implicitly set for training
 		if training is None and text_list is not None and resps_list is not None:
 			n_levels_set = {r.shape[-1] for r in resps_list}
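With `default_task` and the `task_list is None` fallback deleted, `forward()` now relies on the caller to pass `task_list` explicitly (the wrappers in this commit all do). For comparison, the removed behavior as a standalone stub:

```python
def old_default_task_list(text_list, resps_list, task_list=None):
	# Pre-commit behavior: infer a default task from which inputs were given,
	# then fabricate a task list when the caller omitted one.
	default_task = "tts" if text_list is not None else "stt"
	if task_list is None:
		task_list = [default_task for _ in range(len(resps_list or text_list))]
	return task_list

assert old_default_task_list(["<phonemes>"], ["<codes>"]) == ["tts"]
```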
@@ -856,10 +895,12 @@ class AR_NAR(Base):
 		# is training
 		if training:
 			return self.forward_train(
+				task_list=task_list,
+
 				text_list=text_list,
 				proms_list=proms_list,
 				resps_list=resps_list,
-				task_list=task_list,
+
 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
@@ -869,13 +910,17 @@ class AR_NAR(Base):
 		# is NAR
 		if (len_list is not None or resps_list is not None) and text_list is not None:
 			return self.forward_nar(
+				task_list=task_list,
+
 				text_list=text_list,
 				proms_list=proms_list,
 				resps_list=resps_list,
-				task_list=task_list,
+
 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
+				raw_text_list=raw_text_list,
+
 				disable_tqdm=disable_tqdm,
 				use_lora=use_lora,
 				**sampling_kwargs,
@@ -883,13 +928,17 @@ class AR_NAR(Base):
 
 		# is AR
 		return self.forward_ar(
+			task_list=task_list,
+
 			text_list=text_list,
 			proms_list=proms_list,
 			resps_list=resps_list,
-			task_list=task_list,
+
 			lang_list=lang_list,
 			tone_list=tone_list,
 			len_list=len_list,
+			raw_text_list=raw_text_list,
+
 			disable_tqdm=disable_tqdm,
 			use_lora=use_lora,
 			**sampling_kwargs,
@@ -937,21 +937,32 @@ class Base(nn.Module):
 	# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
 	def inputs(
 		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
-		resps_list: list[Tensor],
+		text_list: list[Tensor] | None = None,
+		raw_text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
+		resps_list: list[Tensor] | None = None,
 
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
 		task_list: list[str] | None = None,
 		time_list: list[Tensor] | None = None,
-		raw_text_list: list[Tensor] | None = None,
 
 		quant_levels: int | list[int] | Tensor | None = None
 	):
-		device = text_list[0].device
-		batch_size = len(text_list)
+		if text_list:
+			device = text_list[0].device
+			batch_size = len(text_list)
+		elif raw_text_list:
+			device = raw_text_list[0].device
+			batch_size = len(raw_text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
+		elif resps_list:
+			device = resps_list[0].device
+			batch_size = len(resps_list)
 
 		inputs = [ [] for _ in range(batch_size) ]
 		for i in range(batch_size):
@@ -973,6 +984,8 @@ class Base(nn.Module):
 			# insert the text prompt
 			if text_list is not None and text_list[i] is not None:
 				inputs[i].append( ( "text", text_list[i] ) )
+			elif raw_text_list is not None and raw_text_list[i] is not None:
+				inputs[i].append( ( "raw_text", raw_text_list[i] ) )
 			# insert lang token if we're trained for it
 			if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
 				inputs[i].append( ( "lang", lang_list[i] ) )
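`inputs()` emits an ordered list of `(name, tensor)` tuples per sample, and raw text now slots in only when phoneme text is absent. The selection pattern, reduced to a standalone sketch with stand-in values:

```python
def assemble(i, text_list=None, raw_text_list=None, lang_list=None):
	seq = []
	if text_list is not None and text_list[i] is not None:
		seq.append(("text", text_list[i]))          # phonemized prompt takes precedence
	elif raw_text_list is not None and raw_text_list[i] is not None:
		seq.append(("raw_text", raw_text_list[i]))  # otherwise fall back to raw text
	if lang_list is not None and lang_list[i] is not None:
		seq.append(("lang", lang_list[i]))
	return seq

assert assemble(0, text_list=[None], raw_text_list=["t,e,x,t"]) == [("raw_text", "t,e,x,t")]
```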
@@ -1022,6 +1035,8 @@ class Base(nn.Module):
 			# insert the text prompt
 			if text_list is not None and text_list[i] is not None:
 				inputs[i].append( ( "text", text_list[i] ) )
+			elif raw_text_list is not None and raw_text_list[i] is not None:
+				inputs[i].append( ( "raw_text", raw_text_list[i] ) )
 			# insert lang token if we're trained for it
 			if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
 				inputs[i].append( ( "lang", lang_list[i] ) )
@@ -1070,6 +1085,8 @@ class Base(nn.Module):
 			# insert lang token if we're trained for it
 			if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
 				inputs[i].append( ( "lang", lang_list[i] ) )
+			if self.rvq_l_emb is not None and not self.interleave:
+				inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
 			# insert the text prompt
 			if text_list is not None and text_list[i] is not None:
 				inputs[i].append( ( "text", text_list[i] ) )
@@ -1084,6 +1101,8 @@ class Base(nn.Module):
 			# insert lang token if we're trained for it
 			if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
 				inputs[i].append( ( "lang", lang_list[i] ) )
+			if self.rvq_l_emb is not None and not self.interleave:
+				inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
 			# insert the text prompt
 			if raw_text_list is not None and raw_text_list[i] is not None:
 				inputs[i].append( ( "raw_text", raw_text_list[i] ) )
@@ -1197,7 +1216,7 @@ class Base(nn.Module):
 				embedding = self.text_emb( input )
 
 				device = embedding.device
-			elif name == "raw_text":
+			elif name == "raw_text" and self.raw_text_emb is not None:
 				embedding = self.raw_text_emb( input )
 
 				device = embedding.device
@@ -1643,6 +1662,10 @@ class Base(nn.Module):
 		if quant_levels is None:
 			quant_levels = [ x.item() for x in self.get_input( inputs, "quant_level" ) ]
 
+			# inputs don't have quant levels added, pure AR
+			if len(quant_levels) != len(inputs):
+				quant_levels = [ 0 for _ in range(len(inputs)) ]
+
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 
 		x, mask = list_to_tensor(x_list)
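The guard moves ahead of `inputs_to_embeddings` because embedding selection depends on the quant level; when the assembled inputs carried no `quant_level` entries, the batch is treated as pure AR. In isolation:

```python
inputs = [object(), object()]  # two batch entries (stand-ins)
quant_levels = []              # pretend get_input(inputs, "quant_level") found none

# inputs don't have quant levels added: pure AR, RVQ level 0 everywhere
if len(quant_levels) != len(inputs):
	quant_levels = [0 for _ in range(len(inputs))]

assert quant_levels == [0, 0]
```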
@@ -1652,10 +1675,6 @@ class Base(nn.Module):
 		device = x.device
 		batch_size = len(x_list)
 
-		# pure AR
-		if quant_levels is None:
-			quant_levels = [ 0 for _ in range(batch_size) ]
-
 		# we only need hidden states if we're training with layerskip
 		if self.layerskip and training:
 			output_hidden_states = True