experimental
parent 2e6a7625e4
commit b445f4abb6
@@ -248,11 +248,6 @@ class ModelExperimentalSettings:
 	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necessary
 	unified_position_ids: bool = True # False will generate position IDs partitioned for each section
 	tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing
 
-	# performs token dropout to compensate for errors
-	token_dropout_error: float = 0.0 # probability to nudge a token by ±1
-	token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
-	token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
-
 	causal_size: int = 1 # experimental setting to see if I can just do parallel decoding in chunks instead of one-at-a-time without resorting to exotic solutions
 	# VALL-E 2's approach of "combining token embeddings to group them" sounds terribad for a shared AR/NAR model
@@ -270,11 +265,19 @@ class ModelExperimentalSettings:
 	classifiers_bias: bool = True # base LLaMAs do not bias the output heads, but my existing weights do
 	max_position_embeddings: int = 70 * 65 * 5 # 5 minutes of audio
 
+	# these technically should be as hyperparameters
+	# performs token dropout to compensate for errors
+	token_dropout_error: float = 0.0 # probability to nudge a token by ±1
+	token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
+	token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
 	# these technically should be as hyperparameters
 	# classifier-free guidance training settings
 	cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
 	cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input text prompt during training
 	cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out input audio prompt during training
+
+	use_raw_text_p: float = 0.0 # probability to use raw text as the input prompt instead
+
 	# failed experiment
 	layerskip: bool = False # layerskip compatible model (or training for)
 	#layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
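The token dropout knobs above read more easily as code. Below is a minimal sketch of what they imply, not the repository's implementation: apply_token_dropout and dropout_token are illustrative names, and the input is assumed to be a [timesteps, rvq_levels] tensor of codebook indices.

import torch

def apply_token_dropout(
	tokens: torch.Tensor,       # [timesteps, rvq_levels] codebook indices
	error_p: float = 0.0,       # token_dropout_error: chance to nudge a token by ±1
	dropout_p: float = 0.0,     # token_dropout_rate: chance to swap in the dropout token
	levels: tuple = (1, 8),     # token_dropout_rvq_levels: inclusive range to touch
	dropout_token: int = 1024,  # hypothetical id for the special dropout value
) -> torch.Tensor:
	tokens = tokens.clone()
	for level in range(tokens.shape[-1]):
		# per the comment above, RVQ level 0 is excluded by default
		if not (levels[0] <= level <= levels[1]):
			continue
		column = tokens[:, level]
		# nudge a token's id by ±1 to mimic quantizer error
		nudge = torch.rand(column.shape) < error_p
		offsets = torch.randint(0, 2, column.shape) * 2 - 1
		column[nudge] = (column[nudge] + offsets[nudge]).clamp(min=0)
		# replace a token outright with the special dropout value
		drop = torch.rand(column.shape) < dropout_p
		column[drop] = dropout_token
	return tokens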
@@ -23,7 +23,7 @@ from .config import cfg, Config
 from .models import get_models
 from .models.lora import enable_lora
 from .engines import load_engines, deepspeed_available
-from .data import get_phone_symmap, get_lang_symmap, tokenize, sentence_split
+from .data import get_phone_symmap, get_lang_symmap, tokenize, text_tokenize, sentence_split
 from .models import download_model, DEFAULT_MODEL_PATH
 
 if deepspeed_available:
@@ -412,7 +412,7 @@ class TTS():
 				model = model_ar if model_ar is not None else model_nar
 				if model is not None:
 					text_list = model(
-						text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"],
+						text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=[task],
 						disable_tqdm=not use_tqdm,
 						use_lora=use_lora,
 						**sampling_kwargs,
@@ -423,6 +423,35 @@ class TTS():
 				text_list = [ cfg.tokenizer.decode( text ).replace("   ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
 
 			return text_list[0]
+		elif task in ["phn", "un-phn"]:
+			lang = self.encode_lang( language )
+			lang = to_device(lang, device=self.device, dtype=torch.uint8)
+
+			with torch.autocast(self.device, dtype=dtype, enabled=amp):
+				model = model_ar if model_ar is not None else model_nar
+				if task == "phn":
+					text_list = None
+					raw_text_list = [ torch.tensor( text_tokenize( text ), device=self.device, dtype=torch.int16) ]
+					output_tokenizer = cfg.tokenizer
+				else:
+					text_list = [ torch.tensor( tokenize( text ), device=self.device, dtype=torch.int16) ]
+					raw_text_list = None
+					output_tokenizer = cfg.text_tokenizer
+
+				if model is not None:
+					text_list = model(
+						text_list=text_list, raw_text_list=raw_text_list, lang_list=[lang], task_list=[task],
+						disable_tqdm=not use_tqdm,
+						use_lora=use_lora,
+						**sampling_kwargs,
+					)
+				else:
+					raise Exception("!")
+
+				text_list = [ output_tokenizer.decode( text ).replace("   ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
+
+			return text_list[0]
 
 		# stuff for rolling context
 		prefix_context = None
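The .replace() chain after decoding is dense, and the extraction here collapsed its whitespace: the first replace is assumed to target a run of three spaces. A self-contained illustration of why the chain works, assuming a tokenizer that decodes with one space between tokens, so a real word boundary widens to three spaces:

decoded = "h e l l o   w o r l d"  # tokenizer output: tokens space-separated
text = decoded.replace("   ", "_").replace(" ", "").replace("_", " ")
assert text == "hello world"       # word boundaries survive, filler spaces do not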
@@ -29,7 +29,7 @@ from ..utils import get_devices, setup_logging, timer, clamp, convert_kwargs
 from .lora import enable_lora
 from ..samplers import cfg_logits
 
-text_task = [ "stt" ]
+text_task = [ "stt", "phn", "un-phn" ]
 
 class AR_NAR(Base):
 	# yikes
@@ -40,23 +40,28 @@ class AR_NAR(Base):
 	# a lot of this could be delegated back to the dataloader, but it's just easier to keep the task of the dataloader to provide sufficient data, and the model to process the data for training
 	def forward_train(
 		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
-		resps_list: list[Tensor],
 
 		task_list: list[Tensor] | None = None,
 
+		text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
+		resps_list: list[Tensor] | None = None,
+
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
+		raw_text_list: list[Tensor] | None = None,
 	):
 		# deduce batch_size
-		if text_list is not None:
-			default_task = "tts"
+		if text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		else:
-			default_task = "stt"
+		elif raw_text_list:
+			device = raw_text_list[0].device
+			batch_size = len(raw_text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
+		elif resps_list:
+			device = resps_list[0].device
+			batch_size = len(resps_list)
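This deduce-batch pattern now recurs in forward_train, forward_nar, forward_ar, and forward: the first non-empty list determines both the device and the batch size. Factored out as a sketch (a hypothetical helper, not code from the repo):

def deduce_batch(*candidates):
	# first non-empty list of tensors decides device and batch size
	for tensor_list in candidates:
		if tensor_list:
			return tensor_list[0].device, len(tensor_list)
	raise ValueError("at least one input list must be provided")

# device, batch_size = deduce_batch(text_list, raw_text_list, proms_list, resps_list)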
@@ -74,6 +79,7 @@ class AR_NAR(Base):
 		cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
 		cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
 		cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
+		use_raw_text_p = self.config.experimental.use_raw_text_p if self.config is not None else 0.0
 		# rate to train RVQ level AR-ly or NAR-ly
 		masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5
 		masking_ratio = self.config.experimental.masking_ratio if self.config is not None else "random"
@@ -161,10 +167,7 @@ class AR_NAR(Base):
 			# only apply stop token for RVQ level 0
 			if quant_level <= 0 and timesteps[i] is None:
 				# append stop tokens for AR
-				if task in text_task:
-					#text_list[i] = torch.cat([ resps, text_stop_sequence ])
-					...
-				else:
+				if task not in text_task:
 					resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
 
 			if task == "len":
@@ -174,6 +177,7 @@ class AR_NAR(Base):
 			if task not in text_task + ["len"]:
 				drop_text = False
 				drop_audio = False
+				swap_text = False
 
 				if random.random() < cfg_prom_dropout_p:
 					drop_audio = True
@@ -181,6 +185,9 @@ class AR_NAR(Base):
 				if random.random() < cfg_cond_dropout_p:
 					drop_audio = True
 					drop_text = True
 
+				if random.random() < use_raw_text_p and raw_text_list[i] is not None:
+					swap_text = True
+
 				if drop_text:
 					text_list[i] = text_start_stop_sequence
@@ -188,6 +195,9 @@ class AR_NAR(Base):
 				if drop_audio:
 					proms_list[i] = None
 
+				if swap_text and not drop_text:
+					text_list[i] = None
+
 		inputs = self.inputs(
 			text_list=text_list,
 			proms_list=proms_list,
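Read together, the three hunks above implement classifier-free-guidance-style conditioning dropout during training; the cfg_logits import earlier is the inference-side counterpart that recombines conditional and unconditional logits. A consolidated sketch of the per-sample logic, names following the diff (text_start_stop_sequence is the sentinel-only text prompt):

import random

def cfg_dropout_sample(i, text_list, proms_list, raw_text_list, text_start_stop_sequence,
                       cfg_prom_dropout_p, cfg_cond_dropout_p, use_raw_text_p):
	drop_text = False
	drop_audio = False
	swap_text = False

	if random.random() < cfg_prom_dropout_p:   # drop only the audio prompt
		drop_audio = True
	if random.random() < cfg_cond_dropout_p:   # drop text and audio together
		drop_audio = True
		drop_text = True
	if random.random() < use_raw_text_p and raw_text_list[i] is not None:
		swap_text = True                       # train this sample on raw text instead

	if drop_text:
		text_list[i] = text_start_stop_sequence
	if drop_audio:
		proms_list[i] = None
	if swap_text and not drop_text:
		text_list[i] = None                    # inputs() then falls back to raw_text_list[i]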
@@ -209,14 +219,16 @@ class AR_NAR(Base):
 	def forward_nar_masked(
 		self,
 
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
-		task_list: list[Tensor] | None = None,
+		text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
+		resps_list: list[Tensor] | None = None,
+
+		task_list: list[Tensor] | None = None,
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
 		raw_text_list: list[Tensor] | None = None,
 
 		disable_tqdm=False,
 		use_lora=None,
@@ -420,14 +432,17 @@ class AR_NAR(Base):
 
 	def forward_nar(
 		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
-		task_list: list[Tensor] | None = None,
+		text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
+		resps_list: list[Tensor] | None = None,
+
+		task_list: list[Tensor] | None = None,
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
 
+		raw_text_list: list[Tensor] | None = None,
+
 		disable_tqdm=False,
 		use_lora=None,
@@ -447,12 +462,16 @@ class AR_NAR(Base):
 		)
 
 		# deduce batch_size
-		if text_list is not None:
-			default_task = "tts"
+		if text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		else:
-			default_task = "stt"
+		elif raw_text_list:
+			device = raw_text_list[0].device
+			batch_size = len(raw_text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
+		elif resps_list:
+			device = resps_list[0].device
+			batch_size = len(resps_list)
 
@@ -534,25 +553,31 @@ class AR_NAR(Base):
 	def forward_ar(
 		self,
 
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
-		task_list: list[Tensor],
+		text_list: list[Tensor] | None = None,
+		raw_text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
+		resps_list: list[Tensor] | None = None,
+
+		task_list: list[Tensor] | None = None,
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
 
 		disable_tqdm=False,
 		use_lora=None,
 		**sampling_kwargs,
 	):
 		# deduce batch_size
-		if text_list is not None:
-			default_task = "tts"
+		if text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		else:
-			default_task = "stt"
+		elif raw_text_list:
+			device = raw_text_list[0].device
+			batch_size = len(raw_text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
+		elif resps_list:
+			device = resps_list[0].device
+			batch_size = len(resps_list)
@@ -590,13 +615,17 @@ class AR_NAR(Base):
 			len_list = sequence_list
 
 		inputs = self.inputs(
+			task_list=task_list,
+
 			text_list=text_list,
 			proms_list=proms_list,
 			resps_list=resps_list,
+
 			lang_list=lang_list,
 			tone_list=tone_list,
 			len_list=len_list,
-			task_list=task_list,
+			raw_text_list=raw_text_list,
+
 			quant_levels=quant_levels,
 		)
 
@@ -627,7 +656,6 @@ class AR_NAR(Base):
 			# convert tokens into int
 			return [ int("".join([ str(token.item()) for token in r if token != stop_token ])) for r in sequence_list ]
-
 		# STT
 		start_slice = [ 0 for _ in range(batch_size) ]
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
 		stopped = torch.zeros(batch_size, device=device).bool()
@@ -684,19 +712,29 @@ class AR_NAR(Base):
 		# get next in sequence
 		iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm)
 		for n in iterator:
 			# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
-			text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
-			resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
+			if batch_size == 1 and task_list[0] in ["phn", "un-phn"]:
+				text_list = [ sequence_list[i] if task in ["phn"] else text_list[i] for i, task in enumerate(task_list) ]
+				raw_text_list = [ sequence_list[i] if task in ["un-phn"] else raw_text_list[i] for i, task in enumerate(task_list) ]
+			else:
+				text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
+				resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
+
+			print( text_list, raw_text_list )
 
 			quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
 
 			inputs = self.inputs(
+				task_list=task_list,
+
 				text_list=text_list,
 				proms_list=proms_list,
 				resps_list=resps_list,
+
 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
-				task_list=task_list,
+				raw_text_list=raw_text_list,
+
 				quant_levels=quant_levels,
 			)
 
@@ -816,11 +854,12 @@ class AR_NAR(Base):
 
 	def forward(
 		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
-		task_list: list[Tensor] | None = None,
 
+		text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
+		resps_list: list[Tensor] | None = None,
+
+		task_list: list[Tensor] | None = None,
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
@@ -833,19 +872,20 @@ class AR_NAR(Base):
 		**sampling_kwargs,
 	):
-		# deduce batch_size
 		if text_list is not None:
 			default_task = "tts"
+
+		# deduce batch_size
+		if text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		else:
-			default_task = "stt"
+		elif raw_text_list:
+			device = raw_text_list[0].device
+			batch_size = len(raw_text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
+		elif resps_list:
+			device = resps_list[0].device
+			batch_size = len(resps_list)
 
 		# generate task list if not provided
 		if task_list is None:
 			task_list = [ default_task for _ in range(batch_size) ]
 
 		# implicitly set for training
 		if training is None and text_list is not None and resps_list is not None:
 			n_levels_set = {r.shape[-1] for r in resps_list}
@@ -856,10 +896,12 @@ class AR_NAR(Base):
 		# is training
 		if training:
 			return self.forward_train(
+				task_list=task_list,
+
 				text_list=text_list,
 				proms_list=proms_list,
 				resps_list=resps_list,
-				task_list=task_list,
 
 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
@@ -869,13 +911,17 @@ class AR_NAR(Base):
 		# is NAR
 		if (len_list is not None or resps_list is not None) and text_list is not None:
 			return self.forward_nar(
+				task_list=task_list,
+
 				text_list=text_list,
 				proms_list=proms_list,
 				resps_list=resps_list,
-				task_list=task_list,
 
 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
+				raw_text_list=raw_text_list,
 
 				disable_tqdm=disable_tqdm,
 				use_lora=use_lora,
 				**sampling_kwargs,
@@ -883,13 +929,17 @@ class AR_NAR(Base):
 
 		# is AR
 		return self.forward_ar(
+			task_list=task_list,
+
 			text_list=text_list,
 			proms_list=proms_list,
 			resps_list=resps_list,
-			task_list=task_list,
 
 			lang_list=lang_list,
 			tone_list=tone_list,
 			len_list=len_list,
+			raw_text_list=raw_text_list,
 
 			disable_tqdm=disable_tqdm,
 			use_lora=use_lora,
 			**sampling_kwargs,
 
@@ -51,6 +51,8 @@ special_tasks = [ "len", "stt", "phn", "un-phn" ]
 non_tokened_names = ["task", "dropout_mask", "classifier_level"]
 task_outputs = {
 	"tts": "resp",
 	"ns": "resp",
 	"sr": "resp",
 	"stt": "text",
+	"len": "len",
+	"phn": "text",
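task_outputs maps each task name to the kind of sequence the model is expected to emit, which is what the output/decode paths key on. A small sketch of how the table reads (hypothetical usage, not repo code):

task_outputs = {
	"tts": "resp",
	"ns": "resp",
	"sr": "resp",
	"stt": "text",
	"len": "len",
	"phn": "text",
}
assert task_outputs["phn"] == "text"  # phonemizing emits (phoneme) text, like "stt"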
@@ -937,21 +939,32 @@ class Base(nn.Module):
 	# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
 	def inputs(
 		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
-		resps_list: list[Tensor],
+		text_list: list[Tensor] | None = None,
+		raw_text_list: list[Tensor] | None = None,
+
+		proms_list: list[Tensor] | None = None,
+		resps_list: list[Tensor] | None = None,
 
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
 		task_list: list[str] | None = None,
 		time_list: list[Tensor] | None = None,
-		raw_text_list: list[Tensor] | None = None,
 
 		quant_levels: int | list[int] | Tensor | None = None
 	):
-		device = text_list[0].device
-		batch_size = len(text_list)
+		if text_list and text_list[0] is not None:
+			device = text_list[0].device
+			batch_size = len(text_list)
+		elif raw_text_list and raw_text_list[0] is not None:
+			device = raw_text_list[0].device
+			batch_size = len(raw_text_list)
+		elif proms_list and proms_list[0] is not None:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
+		elif resps_list and resps_list[0] is not None:
+			device = resps_list[0].device
+			batch_size = len(resps_list)
 
 		inputs = [ [] for _ in range(batch_size) ]
 		for i in range(batch_size):
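Unlike the deduction in ar_nar.py, these guards also check the first element, since a list can be present yet hold None for a sample (e.g. after CFG dropout sets proms_list[i] = None). A tiny illustration of the fall-through:

proms_list = [None]  # present, but the sample's prompt was dropped
assert not (proms_list and proms_list[0] is not None)  # guard fails; the next elif decides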
@@ -973,6 +986,8 @@ class Base(nn.Module):
 			# insert the text prompt
 			if text_list is not None and text_list[i] is not None:
 				inputs[i].append( ( "text", text_list[i] ) )
+			elif raw_text_list is not None and raw_text_list[i] is not None:
+				inputs[i].append( ( "raw_text", raw_text_list[i] ) )
 			# insert lang token if we're trained for it
 			if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
 				inputs[i].append( ( "lang", lang_list[i] ) )
@@ -1022,6 +1037,8 @@ class Base(nn.Module):
 			# insert the text prompt
 			if text_list is not None and text_list[i] is not None:
 				inputs[i].append( ( "text", text_list[i] ) )
+			elif raw_text_list is not None and raw_text_list[i] is not None:
+				inputs[i].append( ( "raw_text", raw_text_list[i] ) )
 			# insert lang token if we're trained for it
 			if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
 				inputs[i].append( ( "lang", lang_list[i] ) )
@@ -1070,6 +1087,8 @@ class Base(nn.Module):
 			# insert lang token if we're trained for it
 			if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
 				inputs[i].append( ( "lang", lang_list[i] ) )
+			if self.rvq_l_emb is not None and not self.interleave:
+				inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
 			# insert the text prompt
 			if text_list is not None and text_list[i] is not None:
 				inputs[i].append( ( "text", text_list[i] ) )
@@ -1084,6 +1103,8 @@ class Base(nn.Module):
 			# insert lang token if we're trained for it
 			if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
 				inputs[i].append( ( "lang", lang_list[i] ) )
+			if self.rvq_l_emb is not None and not self.interleave:
+				inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
 			# insert the text prompt
 			if raw_text_list is not None and raw_text_list[i] is not None:
 				inputs[i].append( ( "raw_text", raw_text_list[i] ) )
@@ -1197,7 +1218,7 @@ class Base(nn.Module):
 				embedding = self.text_emb( input )
 
 				device = embedding.device
-			elif name == "raw_text":
+			elif name == "raw_text" and self.raw_text_emb is not None:
 				embedding = self.raw_text_emb( input )
 
 				device = embedding.device
@@ -1505,7 +1526,7 @@ class Base(nn.Module):
 				if self.config.loss_factors:
 					continue
 				# fill with ignored out tensor
-				token = torch.tensor( [ self.ignore_index ] * input.shape[0], device=device, dtype=torch.int16)
+				token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16)
 
 			# perform loss calculation on the individual piece
 			if self.config.loss_factors:
@@ -1643,6 +1664,10 @@ class Base(nn.Module):
 		if quant_levels is None:
 			quant_levels = [ x.item() for x in self.get_input( inputs, "quant_level" ) ]
 
+		# inputs don't have quant levels added, pure AR
+		if len(quant_levels) != len(inputs):
+			quant_levels = [ 0 for _ in range(len(inputs)) ]
+
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 
 		x, mask = list_to_tensor(x_list)
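The new length check covers pure-AR batches whose inputs carry no per-sample "quant_level" entry: the gathered list comes back shorter than the batch, so it is padded with level 0, replacing the later fallback that the next hunk removes. A sketch of the guard's behavior:

inputs = [object(), object()]  # stand-ins for two samples without "quant_level"
quant_levels = []              # nothing gathered from the inputs
if len(quant_levels) != len(inputs):
	quant_levels = [ 0 for _ in range(len(inputs)) ]
assert quant_levels == [0, 0]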
@@ -1652,10 +1677,6 @@ class Base(nn.Module):
 		device = x.device
 		batch_size = len(x_list)
 
-		# pure AR
-		if quant_levels is None:
-			quant_levels = [ 0 for _ in range(batch_size) ]
-
 		# we only need hidden states if we're training with layerskip
 		if self.layerskip and training:
 			output_hidden_states = True
 
@@ -27,6 +27,9 @@ _logger = logging.getLogger(__name__)
 mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
 
 def train_feeder(engine, batch, teacher=None):
+	engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
+	engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])
+
 	with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
 		batch_size = len(batch["text"])
 		engine.current_batch_size = batch_size
@@ -106,9 +109,6 @@ def train_feeder(engine, batch, teacher=None):
 		stats |= {k: v.item() for k, v in losses.items()}
 		stats |= {k: v.item() for k, v in stat.items()}
 
-	engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
-	engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])
-
 	return loss, stats
 
 @torch.inference_mode()