experimental

This commit is contained in:
mrq 2025-01-05 19:05:00 -06:00
parent 2e6a7625e4
commit b445f4abb6
5 changed files with 176 additions and 73 deletions

View File

@ -248,11 +248,6 @@ class ModelExperimentalSettings:
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necesary
unified_position_ids: bool = True # False will generate position IDs partitioned for each section
tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing
# performs token dropout to compensate for errors
token_dropout_error: float = 0.0 # probability to nudge a token by ±1
token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
causal_size: int = 1 # experimental setting to see if I can just do parallel decoding in chunks instead of one-at-a-time without resorting to exotic solutions
# VALL-E 2's approach of "combining token embeddings to group them" sounds terribad for a shared AR/NAR model
@ -270,11 +265,19 @@ class ModelExperimentalSettings:
classifiers_bias: bool = True # base LLaMAs do not bias the output heads, but my existing weights do
max_position_embeddings: int = 70 * 65 * 5 # 5 minutes of audio
# these technically should be as hyperparameters
# performs token dropout to compensate for errors
token_dropout_error: float = 0.0 # probability to nudge a token by ±1
token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
# these technically should be as hyperparameters
# classifier-free guidance training settings
cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input audio prompt during training
cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out input audio prompt during training
use_raw_text_p: float = 0.0 # probability to use raw text as the input prompt instead
# failed experiment
layerskip: bool = False # layerskip compatible model (or training for)
#layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)

View File

@ -23,7 +23,7 @@ from .config import cfg, Config
from .models import get_models
from .models.lora import enable_lora
from .engines import load_engines, deepspeed_available
from .data import get_phone_symmap, get_lang_symmap, tokenize, sentence_split
from .data import get_phone_symmap, get_lang_symmap, tokenize, text_tokenize, sentence_split
from .models import download_model, DEFAULT_MODEL_PATH
if deepspeed_available:
@ -412,7 +412,7 @@ class TTS():
model = model_ar if model_ar is not None else model_nar
if model is not None:
text_list = model(
text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"],
text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=[task],
disable_tqdm=not use_tqdm,
use_lora=use_lora,
**sampling_kwargs,
@ -423,6 +423,35 @@ class TTS():
text_list = [ cfg.tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
return text_list[0]
elif task in ["phn", "un-phn"]:
lang = self.encode_lang( language )
lang = to_device(lang, device=self.device, dtype=torch.uint8)
with torch.autocast(self.device, dtype=dtype, enabled=amp):
model = model_ar if model_ar is not None else model_nar
if task == "phn":
text_list = None
raw_text_list = [ torch.tensor( text_tokenize( text ), device=self.device, dtype=torch.int16) ]
output_tokenizer = cfg.tokenizer
else:
text_list = [ torch.tensor( tokenize( text ), device=self.device, dtype=torch.int16) ]
raw_text_list = None
output_tokenizer = cfg.text_tokenizer
if model is not None:
text_list = model(
text_list=text_list, raw_text_list=raw_text_list, lang_list=[lang], task_list=[task],
disable_tqdm=not use_tqdm,
use_lora=use_lora,
**sampling_kwargs,
)
else:
raise Exception("!")
text_list = [ output_tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
return text_list[0]
# stuff for rolling context
prefix_context = None

View File

@ -29,7 +29,7 @@ from ..utils import get_devices, setup_logging, timer, clamp, convert_kwargs
from .lora import enable_lora
from ..samplers import cfg_logits
text_task = [ "stt" ]
text_task = [ "stt", "phn", "un-phn" ]
class AR_NAR(Base):
# yikes
@ -40,23 +40,28 @@ class AR_NAR(Base):
# a lot of this could be delegated back to the dataloader, but it's just easier to keep the task of the dataloader to provide sufficient data, and the model to process the data for training
def forward_train(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
task_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
):
# deduce batch_size
if text_list is not None:
default_task = "tts"
if text_list:
device = text_list[0].device
batch_size = len(text_list)
else:
default_task = "stt"
elif raw_text_list:
device = raw_text_list[0].device
batch_size = len(raw_text_list)
elif proms_list:
device = proms_list[0].device
batch_size = len(proms_list)
elif resps_list:
device = resps_list[0].device
batch_size = len(resps_list)
@ -74,6 +79,7 @@ class AR_NAR(Base):
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
use_raw_text_p = self.config.experimental.use_raw_text_p if self.config is not None else 0.0
# rate to train RVQ level AR-ly or NAR-ly
masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5
masking_ratio = self.config.experimental.masking_ratio if self.config is not None else "random"
@ -161,10 +167,7 @@ class AR_NAR(Base):
# only apply stop token for RVQ level 0
if quant_level <= 0 and timesteps[i] is None:
# append stop tokens for AR
if task in text_task:
#text_list[i] = torch.cat([ resps, text_stop_sequence ])
...
else:
if task not in text_task:
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
if task == "len":
@ -174,6 +177,7 @@ class AR_NAR(Base):
if task not in text_task + ["len"]:
drop_text = False
drop_audio = False
swap_text = False
if random.random() < cfg_prom_dropout_p:
drop_audio = True
@ -181,6 +185,9 @@ class AR_NAR(Base):
if random.random() < cfg_cond_dropout_p:
drop_audio = True
drop_text = True
if random.random() < use_raw_text_p and raw_text_list[i] is not None:
swap_text = True
if drop_text:
text_list[i] = text_start_stop_sequence
@ -188,6 +195,9 @@ class AR_NAR(Base):
if drop_audio:
proms_list[i] = None
if swap_text and not drop_text:
text_list[i] = None
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
@ -209,14 +219,16 @@ class AR_NAR(Base):
def forward_nar_masked(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
task_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
task_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
disable_tqdm=False,
use_lora=None,
@ -420,14 +432,17 @@ class AR_NAR(Base):
def forward_nar(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
task_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
task_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
disable_tqdm=False,
use_lora=None,
@ -447,12 +462,16 @@ class AR_NAR(Base):
)
# deduce batch_size
if text_list is not None:
default_task = "tts"
if text_list:
device = text_list[0].device
batch_size = len(text_list)
else:
default_task = "stt"
elif raw_text_list:
device = raw_text_list[0].device
batch_size = len(raw_text_list)
elif proms_list:
device = proms_list[0].device
batch_size = len(proms_list)
elif resps_list:
device = resps_list[0].device
batch_size = len(resps_list)
@ -534,25 +553,31 @@ class AR_NAR(Base):
def forward_ar(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
task_list: list[Tensor],
text_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
task_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
disable_tqdm=False,
use_lora=None,
**sampling_kwargs,
):
# deduce batch_size
if text_list is not None:
default_task = "tts"
if text_list:
device = text_list[0].device
batch_size = len(text_list)
else:
default_task = "stt"
elif raw_text_list:
device = raw_text_list[0].device
batch_size = len(raw_text_list)
elif proms_list:
device = proms_list[0].device
batch_size = len(proms_list)
elif resps_list:
device = resps_list[0].device
batch_size = len(resps_list)
@ -590,13 +615,17 @@ class AR_NAR(Base):
len_list = sequence_list
inputs = self.inputs(
task_list=task_list,
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
task_list=task_list,
raw_text_list=raw_text_list,
quant_levels=quant_levels,
)
@ -627,7 +656,6 @@ class AR_NAR(Base):
# convert tokens into int
return [ int("".join([ str(token.item()) for token in r if token != stop_token ])) for r in sequence_list ]
# STT
start_slice = [ 0 for _ in range(batch_size) ]
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool()
@ -684,19 +712,29 @@ class AR_NAR(Base):
# get next in sequence
iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm)
for n in iterator:
# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
if batch_size == 1 and task_list[0] in ["phn", "un-phn"]:
text_list = [ sequence_list[i] if task in ["phn"] else text_list[i] for i, task in enumerate(task_list) ]
raw_text_list = [ sequence_list[i] if task in ["un-phn"] else raw_text_list[i] for i, task in enumerate(task_list) ]
else:
text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
print( text_list, raw_text_list )
quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
inputs = self.inputs(
task_list=task_list,
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
task_list=task_list,
raw_text_list=raw_text_list,
quant_levels=quant_levels,
)
@ -816,11 +854,12 @@ class AR_NAR(Base):
def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
task_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
task_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
@ -833,19 +872,20 @@ class AR_NAR(Base):
**sampling_kwargs,
):
# deduce batch_size
if text_list is not None:
default_task = "tts"
# deduce batch_size
if text_list:
device = text_list[0].device
batch_size = len(text_list)
else:
default_task = "stt"
elif raw_text_list:
device = raw_text_list[0].device
batch_size = len(raw_text_list)
elif proms_list:
device = proms_list[0].device
batch_size = len(proms_list)
elif resps_list:
device = resps_list[0].device
batch_size = len(resps_list)
# generate task list if not provided
if task_list is None:
task_list = [ default_task for _ in range(batch_size) ]
# implicitly set for training
if training is None and text_list is not None and resps_list is not None:
n_levels_set = {r.shape[-1] for r in resps_list}
@ -856,10 +896,12 @@ class AR_NAR(Base):
# is training
if training:
return self.forward_train(
task_list=task_list,
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
task_list=task_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
@ -869,13 +911,17 @@ class AR_NAR(Base):
# is NAR
if (len_list is not None or resps_list is not None) and text_list is not None:
return self.forward_nar(
task_list=task_list,
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
task_list=task_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
**sampling_kwargs,
@ -883,13 +929,17 @@ class AR_NAR(Base):
# is AR
return self.forward_ar(
task_list=task_list,
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
task_list=task_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
**sampling_kwargs,

View File

@ -51,6 +51,8 @@ special_tasks = [ "len", "stt", "phn", "un-phn" ]
non_tokened_names = ["task", "dropout_mask", "classifier_level"]
task_outputs = {
"tts": "resp",
"ns": "resp",
"sr": "resp",
"stt": "text",
"len": "len",
"phn": "text",
@ -937,21 +939,32 @@ class Base(nn.Module):
# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
def inputs(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
text_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
task_list: list[str] | None = None,
time_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
quant_levels: int | list[int] | Tensor | None = None
):
device = text_list[0].device
batch_size = len(text_list)
if text_list and text_list[0] is not None:
device = text_list[0].device
batch_size = len(text_list)
elif raw_text_list and raw_text_list[0] is not None:
device = raw_text_list[0].device
batch_size = len(raw_text_list)
elif proms_list and proms_list[0] is not None:
device = proms_list[0].device
batch_size = len(proms_list)
elif resps_list and resps_list[0] is not None:
device = resps_list[0].device
batch_size = len(resps_list)
inputs = [ [] for _ in range(batch_size) ]
for i in range(batch_size):
@ -973,6 +986,8 @@ class Base(nn.Module):
# insert the text prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
elif raw_text_list is not None and raw_text_list[i] is not None:
inputs[i].append( ( "raw_text", raw_text_list[i] ) )
# insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
@ -1022,6 +1037,8 @@ class Base(nn.Module):
# insert the text prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
elif raw_text_list is not None and raw_text_list[i] is not None:
inputs[i].append( ( "raw_text", raw_text_list[i] ) )
# insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
@ -1070,6 +1087,8 @@ class Base(nn.Module):
# insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
if self.rvq_l_emb is not None and not self.interleave:
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
# insert the text prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
@ -1084,6 +1103,8 @@ class Base(nn.Module):
# insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
if self.rvq_l_emb is not None and not self.interleave:
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
# insert the text prompt
if raw_text_list is not None and raw_text_list[i] is not None:
inputs[i].append( ( "raw_text", raw_text_list[i] ) )
@ -1197,7 +1218,7 @@ class Base(nn.Module):
embedding = self.text_emb( input )
device = embedding.device
elif name == "raw_text":
elif name == "raw_text" and self.raw_text_emb is not None:
embedding = self.raw_text_emb( input )
device = embedding.device
@ -1505,7 +1526,7 @@ class Base(nn.Module):
if self.config.loss_factors:
continue
# fill with ignored out tensor
token = torch.tensor( [ self.ignore_index ] * input.shape[0], device=device, dtype=torch.int16)
token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16)
# perform loss calculation on the individual piece
if self.config.loss_factors:
@ -1643,6 +1664,10 @@ class Base(nn.Module):
if quant_levels is None:
quant_levels = [ x.item() for x in self.get_input( inputs, "quant_level" ) ]
# inputs don't have quant levels added, pure AR
if len(quant_levels) != len(inputs):
quant_levels = [ 0 for _ in range(len(inputs)) ]
x_list = self.inputs_to_embeddings( inputs, quant_levels )
x, mask = list_to_tensor(x_list)
@ -1652,10 +1677,6 @@ class Base(nn.Module):
device = x.device
batch_size = len(x_list)
# pure AR
if quant_levels is None:
quant_levels = [ 0 for _ in range(batch_size) ]
# we only need hidden states if we're training with layerskip
if self.layerskip and training:
output_hidden_states = True

View File

@ -27,6 +27,9 @@ _logger = logging.getLogger(__name__)
mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
def train_feeder(engine, batch, teacher=None):
engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
batch_size = len(batch["text"])
engine.current_batch_size = batch_size
@ -106,9 +109,6 @@ def train_feeder(engine, batch, teacher=None):
stats |= {k: v.item() for k, v in losses.items()}
stats |= {k: v.item() for k, v in stat.items()}
engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])
return loss, stats
@torch.inference_mode()