redid loss calculation to be cleaner, and position ID generation, and other things (I might need to train the NAR-len from scratch and not resume from an existing checkpoint.........)

This commit is contained in:
mrq 2024-11-14 22:17:47 -06:00
parent ef05c951ff
commit 39096f8ff3
8 changed files with 261 additions and 265 deletions

View File

@ -67,6 +67,8 @@ I'm uncertain on how to remedy this, as my options are:
* train an AR, and train a NAR, if the codec itself is at fault
* use an SSM like Mamba, if transformers entirely cannot model the codec
* train a separate model that simply converts from EnCodec to DAC
* train *all* NAR levels as independent masking sequences.
* if this works, then it means that there's little to no mappable relation between DAC's RVQ levels
## `transcribe.py`

View File

@ -48,17 +48,11 @@ One problem exhibited from a NAR is producing artifacts ("crust") in the final w
The pure NAR (`nar-len`) model is a model-type that inferences audio tokens purely non-autoregressively. Despite being called a pure NAR, duration is then inferred by autoregressively decoding for its length (as the AR+NAR model shows that you can mix both types).
However, having a pure NAR is challenging, as you need to both explicitly provide the duration and provide a "good enough" starting sequence of tokens for the initial sequence.
* The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively.
* The latter however proves to be challenging, as generating tokens from nothing in one step is not possible.
* diffusion solves this, but requires additional steps at best and a separate model at worst, just for one RVQ level.
* the normal NAR (RVQ level 1+) does not face this problem, as it's already given a sufficient initial sequence of tokens to work with, and thus only requires one step.
The implemented solution follows a similar paradigm to diffusion, but with masking instead of noise.
* incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this in the use of a NAR transformer for image generation
The reference model provided has *some* NAR demasking (mock diffusion) aware training to facilitate a pure NAR model, but:
* Sampling absolutely requires rep pen, or the output degenerates.
* Output isn't so great, as there's artifacting from either an underbaked model or a naive sampler.
* The former problem is easily "solved" by training a `len` classification task.
* The latter however proves to be challenging, as generating tokens from nothing in one step is not possible (but can be done in multiple steps).
* diffusion solves this, but requires a different underlying model architecture
* masking to emulate diffusion noising is the best working solution so far, but it has a lot of training challenges (a sketch of the sampling loop is provided at the end of this section).
* existing solutions like Muse (text to image) and MaskGCT (text to speech) do this
To-do: fill out this more when it works. Getting this to work is a huge pain.
* Some masked transformers do not "inject" any timestep information (Text-To-Image Muse as far as I can tell)
@ -67,22 +61,28 @@ To-do: fill out this more when it works. Getting this to work is a huge pain.
* MaskGCT does it both pre and post
* the test trainer actually degrades the output immensely when doing this
* I'm sure I've seen a masked transformer not have CFG, but most of them seem to use it.
* This helps the base AR+NAR tasks and provides CFG sampling for such tasks anyways.
* ***Extreme*** care is required.
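
To make the above concrete, below is a minimal sketch of the cosine-scheduled demasking sampler described in this section. It assumes a hypothetical `model(cond, resps, timestep)` callable that returns per-position logits (and accepts `cond=None` for the null/CFG pass); it is not the reference implementation.

```python
# Minimal sketch of iterative demasking with a cosine schedule + CFG.
# `model`, `cond`, and the token IDs here are illustrative placeholders.
import math
import torch

def demask_sample(model, cond, seq_len, max_steps=25, mask_token=1024, cfg_strength=2.5):
    resps = torch.full((seq_len,), mask_token, dtype=torch.long)  # start fully masked
    scores = torch.zeros(seq_len)  # confidence per position; low scores get re-masked

    for timestep in torch.linspace(0.0, 1.0, max_steps):
        # cosine schedule: fraction of the sequence that stays masked this step
        noise_p = math.cos(timestep.item() * math.pi * 0.5)
        masked = torch.topk(scores, k=max(1, int(seq_len * noise_p)), largest=False).indices

        resps_in = resps.clone()
        resps_in[masked] = mask_token

        # classifier-free guidance: lerp away from the null-conditioned logits,
        # with the guidance strength annealed by the timestep
        logits = model(cond, resps_in, timestep=timestep)       # [seq_len, n_tokens]
        null_logits = model(None, resps_in, timestep=timestep)
        logits = null_logits + (logits - null_logits) * (cfg_strength * timestep)

        probs = logits.softmax(dim=-1)
        sampled = probs.argmax(dim=-1)                          # greedy for brevity
        confidence = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)

        # only the positions that were masked this step get overwritten
        resps[masked] = sampled[masked]
        scores[masked] = confidence[masked]
    return resps
```

The actual sampler additionally (optionally) anneals temperature alongside the timestep and applies rep pen, but the skeleton is the same: mask the lowest-confidence positions, predict everything in parallel, keep the most confident predictions, and repeat with a shrinking masking ratio.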
## Embeddings
## Embeddings (and Classifiers)
The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings. This is necessary as each piece of a sequence is fundamentally different, but a HF-compatible model can geta way with treating each sequence as separate ranges within a total token sequence.
The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings. This is necessary as each piece of a sequence is fundamentally different, but a HF-compatible model can get away with treating each sequence as separate ranges within a total token sequence.
While embeddings *can* be tied to the output head, testing showed that the model ***really*** does not like to do this, although my implementation could very well be flawed.
With attention-based transformers, most embeddings can serve as a token itself and have the attention mechanism attend to it. Theoretically, there should be little to no functional differences between "tokenizing" an embedding, and summing a modifying embedding, but experimentation is needed for this assertion.
### Classifiers
Classifiers are the final output head / projection layer that processes the last hidden states of a model into a probability distribution for each token.
Out of paranoia, each head is split per macro-task (RVQ level), with auxiliary heads for the `stt` and `len` tasks, even though the core half of the model's training was with a single output head.
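
For illustration, a per-level head ensemble can be as simple as a list of linear projections selected by name; the sketch below uses made-up names and is not the repo's exact `Classifiers` class.

```python
# Minimal sketch of split output heads: one projection per RVQ level / macro-task,
# plus auxiliary heads (e.g. for `stt` and `len`). Names are illustrative.
import torch
from torch import nn

class LevelClassifiers(nn.Module):
    def __init__(self, n_tokens_per_level: list[int], d_model: int, names: list[str]):
        super().__init__()
        self.names = names
        self.heads = nn.ModuleList([nn.Linear(d_model, n) for n in n_tokens_per_level])

    def forward(self, hidden: torch.Tensor, name: str) -> torch.Tensor:
        # project the last hidden states with the head for the requested level
        return self.heads[self.names.index(name)](hidden)

# e.g. names = ["AR:0:0", "NAR:0:1", ..., "NAR:6:7", "NAR:0:0", "stt", "len"]
```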
### Text Embeddings
The input text phonemes (or output for STT) are passed through an embedding head (`text`), similar to how a normal text LLM would. Nothing fancy is required, as it's very straightforward.
Technically, due to how the audio embeddings are implemented, it's possible to offer "language specific" embeddings, rather than one unified IPA-based embedding + a language embedding (`lang`).
* Such an implementation *could* in fact inference from normal text rather than IPA phonemes.
Technically, due to how the audio embeddings are implemented, it's possible to offer "language specific" text embeddings, rather than one unified IPA-based embedding + a language embedding (`lang`).
* Such an implementation *could* in fact inference from normal text rather than IPA phonemes, as language-specific pure text embeddings can be trained.
These embeddings *could* instead be added on top of the input prompt embedding instead of serving as additional tasks (similar to injecting position embeddings), but additional experimentation is required to see if the model both can work under this and/or benefits from this.
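
As a toy comparison of the two placements (a separate attended token versus summing on top of the input embeddings), consider the sketch below; all names (`text_emb`, `lang_emb`, etc.) are illustrative and not the repo's API.

```python
# Toy comparison: conditioning embedding as its own token vs. summed onto the inputs.
import torch
from torch import nn

d_model = 1024
text_emb = nn.Embedding(256, d_model)    # IPA phoneme embedding
lang_emb = nn.Embedding(4, d_model)      # language embedding

phonemes = torch.randint(0, 256, (32,))  # a phoneme sequence
lang_id = torch.tensor([1])

# (a) as an additional token: the language embedding occupies its own position
#     and the attention mechanism attends to it
as_token = torch.cat([lang_emb(lang_id), text_emb(phonemes)], dim=0)  # [33, d_model]

# (b) summed on top of the text embeddings, similar to injecting position embeddings
as_sum = text_emb(phonemes) + lang_emb(lang_id)                       # [32, d_model]
```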
@ -115,16 +115,15 @@ As EnCodec encodes audio across eight codebooks (and DAC's 44Khz audio under nin
For the `prom` embedding, we can simply use each embedding for each layer. Each embedding level maps to its respective RVQ level.
However, the `resp` requires some extra care, as the model needs to decode tokens both causally (AR) and in parallel (NAR).
* The first embedding level pertains to RVQ level 0 for the AR.
* The remaining embedding levels map to RVQ level 0 + n for the NAR.
* The first embedding level pertains to RVQ level 0 for the AR (`AR:0:0`).
* This embedding predicts tokens within its own embedding.
* The remaining embedding levels map to RVQ level 0 + n for the NAR (`NAR:L-1:L`).
* In other words, embedding level 1 => RVQ level 0, embedding level 2 => RVQ level 1, etc...
* I believe this is because the model needs to "know" whether to predict ~~the next token in the sequence, or the token in the same position of the next RVQ level~~ which tokens of a given embedding.
* In other words, the AR's RVQ level 0 embedding predicts itself, while the NAR's embeddings predict the next level's embeddings.
* This is evident in how RVQ level 0 can be trained both causally and in parallel with its own embedding, rather than running into limitations when reusing the same embedding across the two.
* Unfortunately, providing a token for the current/target RVQ level within the input sequence doesn't seem to help? I don't remember if I experimented with this or not, but testing of a "sane" `resp` embedding proved to be unfruitful.
* I believe this is required because the model encodes which task to perform (rather than the attention heads), and which tokens to predict (rather than the classifiers)
* In other words, each embedding needs to be separated based on which tokens it predicts.
The `prom` and `resp` are split since, in theory, it helps the model know better what audio to source from, and what audio is part of the output sequence. In theory.
* I have yet to conduct tests with interchanging the `prom` and `resp`, and the model definitely exhibits being able to map from the `prom` directly, and being able to inference with the `prom` prefixed in the `resp`.
* Why the `text` embedding is robust enough not only to be reused for each RVQ level, but also for the STT task, remains a mystery.
Finally, the model *may* then sum each embedding level back down to one sequence, as defined under `cfg.model.experimental.audio_embedding_sums`.
* The resultant sum is not normalized by the length.
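
Below is a minimal sketch of that per-level embedding-and-sum step, simplified from the repo's `AudioEmbedding`; the class and argument names here are not drop-in replacements.

```python
# Minimal sketch of summing per-RVQ-level embeddings down to one sequence,
# gated by something like cfg.model.experimental.audio_embedding_sums.
import torch
from torch import nn

class SummedAudioEmbedding(nn.Module):
    def __init__(self, n_tokens_per_level: list[int], d_model: int, sums: bool = True):
        super().__init__()
        # one embedding table per RVQ level
        self.embeddings = nn.ModuleList([nn.Embedding(n, d_model) for n in n_tokens_per_level])
        self.sums = sums

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        # codes: [seq_len, n_levels] of RVQ codebook indices
        if self.sums:
            # sum every level's embedding back down to one [seq_len, d_model] sequence
            # (the resultant sum is not normalized by the number of levels)
            return sum(self.embeddings[l](codes[:, l]) for l in range(codes.shape[-1]))
        # otherwise, embed only the highest provided level
        return self.embeddings[codes.shape[-1] - 1](codes[:, -1])
```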

View File

@ -27,8 +27,20 @@ Some additional flags can be passed as well:
A training paradigm that works for me is:
* setting the dataloader to sort by duration, then training one epoch, so the model starts with small utterances then trains to larger ones.
* the daring can wait until coherent speech emerges, then move to the next step
* some additional training using a shuffled dataloader, as the model will be fixated on whatever duration range it was trained under.
* additional training with per-speaker sampling, to better diversify how well it performs across a range of speakers, rather than just its overall ability to speak
* I don't think this is crucial, but speaker-based sampling seems to be a huge placebo if anything.
I don't remember the exact numbers off the top of my head, but good loss/accuracy/gradient-norm values to look out for when coherent speech emerges are:
* loss <3.0
* acc >0.7
* grad_norm <0.2
Training under `float16` should be fairly simple, but care is required to keep the loss scaling factor above 8K, and probably even 16K.
* At the very least for pre-trained models, low enough loss scales will irreparably fry the model, and no amount of training afterwards seems to "fix" it.
* The current DeepSpeed configuration should keep the loss scale capped to 32K, but this so far is only validated for pre-trained models.
* Training under `bfloat16` does not have to worry about this as there's no need for loss scaling, but I feel the model performs better when trained under `float16`+AMP rather than `bfloat16` (with or without AMP).
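
For reference, the DeepSpeed `fp16` block that the diff below wires up looks roughly like this (values taken from the diff; adapt to your own config plumbing):

```python
# Sketch of the fp16 settings that keep the loss scale from collapsing.
ds_config = {
    "fp16": {
        "enabled": True,            # training under float16 + AMP
        "auto_cast": True,
        "loss_scale": 0.0,          # 0.0 = dynamic loss scaling
        "loss_scale_window": 100,   # raise the scale after 100 consecutive good steps
        "min_loss_scale": 32768.0,  # 8K fries the model, 16K is fine, 32K is comfy
    },
}
```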
## Try Me

View File

@ -260,7 +260,9 @@ class ModelExperimentalSettings:
masking_train_p: float = 0.0 # odds of training with masking
masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
masking_separate_embeddings: bool = False # to-do: explain
masking_ratio_fixed: bool = False
ignore_inputs_for_loss: bool = False
# classifier-free guidance shit
cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
@ -301,7 +303,7 @@ class Model:
return [ self ] if not name or self.name == name else []
def loss_factor(self, k):
return self.loss_factors[k] if k in self.loss_factors else 1.0
return self.loss_factors.get(k, 0.0)
@property
def max_levels(self):
@ -508,6 +510,9 @@ class DeepSpeed:
amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)
loss_scale_window: int = 100
min_loss_scale: float = 8192.0
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
@cached_property
@ -558,8 +563,8 @@ class DeepSpeed:
"fp16": {
"enabled": cfg.trainer.weight_dtype.lower() == "float16",
"auto_cast": True, # ???
"loss_scale_window": 100, # raise every 100 consecutive good steps
"min_loss_scale": 32768.0, # loss scale hitting 8K fries the model, 16K is fine but 32K is comfy
"loss_scale_window": self.loss_scale_window, # raise every 100 consecutive good steps
"min_loss_scale": self.min_loss_scale, # loss scale hitting 8K fries the model, 16K is fine but 32K is comfy
"loss_scale": 0.0 if cfg.trainer.scale_loss else 1.0,
},
"bf16": {

View File

@ -1761,12 +1761,12 @@ if __name__ == "__main__":
elif args.action == "metadata":
create_dataset_metadata()
elif args.action == "sample":
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
train_dl, val_dl = create_train_val_dataloader()
samples = {
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
#"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
#"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
}
Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
@ -1784,8 +1784,8 @@ if __name__ == "__main__":
decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
except Exception as e:
_logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}")
v[i]['proms'][j] = v[i]['proms'][j].shape
v[i]['resps'][j] = v[i]['resps'][j].shape
#v[i]['proms'][j] = v[i]['proms'][j].shape
#v[i]['resps'][j] = v[i]['resps'][j].shape
for k, v in samples.items():
for i in range(len(v)):

View File

@ -269,7 +269,6 @@ class TTS():
)
elif model_len is not None:
len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_steps": 5} ) # don't need more than that
kwargs = {}
# nasty hardcode to load a reference file and have that as the input target
if load_from_artifact and load_from_artifact.exists():

View File

@ -148,7 +148,7 @@ class AR_NAR(Base):
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
# only apply stop token for RVQ level 0
if quant_level <= 0 and timesteps[i] is not None:
if quant_level <= 0 and timesteps[i] is None:
# append stop tokens for AR
if task in text_task:
#text_list[i] = torch.cat([ resps, text_stop_sequence ])
@ -254,9 +254,8 @@ class AR_NAR(Base):
null_prom = [ None for _ in range(batch_size) ]
prev_list = resps_list
for timestep, steps_until_x0 in tqdm(zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))), desc="NAR Masked", disable=disable_tqdm, total=max_steps):
#for noise_p, annealed_temperature, temperature, cfg_strength in zip( manual_ratios, manual_temp, manual_samp_temp, manual_cfg ):
annealing = (steps_until_x0 / max_steps)
for timestep in tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm):
annealing = 1.0 - timestep
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
# pick the worst scoring tokens to mask off
@ -268,6 +267,13 @@ class AR_NAR(Base):
# timestep inputs
time_list = [ timestep for _ in range(batch_size) ]
"""
sampling_temperature = temperature * annealing
sampling_cfg = cfg_strength * timestep
"""
sampling_temperature = temperature
sampling_cfg = cfg_strength
# setup inputs
inputs = super().inputs(
text_list=text_list,
@ -302,7 +308,7 @@ class AR_NAR(Base):
#layer_skip_variables=sampling_layer_skip_variables,
)
for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * (cfg_strength * timestep)
logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * sampling_cfg
# sample with sampler settings
filtered_sampled = super().sample(
@ -310,7 +316,7 @@ class AR_NAR(Base):
prev_list=prev_list,
quant_levels=quant_levels,
temperature=temperature * annealing,
temperature=sampling_temperature,
**sampling_kwargs,
)
@ -884,7 +890,7 @@ def example_usage():
import numpy as np
import re
cfg.model.experimental.masking_train_p = 0.5
# cfg.model.experimental.masking_train_p = 0.5
cfg.hyperparameters.batch_size = 1
cfg.hyperparameters.gradient_accumulation_steps = 1
@ -903,7 +909,7 @@ def example_usage():
text_list = [ text ] * batch_size
proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size
resps_list = [ audio ] * batch_size
resps_list = [ audio[:cfg.dataset.frames_per_second * 4, :] ] * batch_size
kwargs = {
'n_text_tokens': 256,
@ -922,10 +928,10 @@ def example_usage():
}
bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
available_tasks = ["tts-ar", "tts-nar"]
available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else [])
model = AR_NAR(**kwargs).to(cfg.device)
steps = 500 // batch_size
steps = 1000 // batch_size
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@ -1035,7 +1041,7 @@ def example_usage():
if task == "stt":
prom = [ task ]
else:
task = "tts" if random.random() > 0.1 else "len"
task = "tts" if random.random() > 0.1 or "len" not in cfg.model.capabilities else "len"
texts.append( text )
proms.append( prom )
@ -1053,7 +1059,7 @@ def example_usage():
if task == "tts-nar":
len_list = engine(text_list, proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
len_list = [ resp_list[0].shape[0] for l in len_list ]
resps_list = engine( text_list, proms_list, len_list=len_list, temperature=0.0 )
resps_list = engine( text_list, proms_list, len_list=len_list )
else:
resps_list = engine( text_list, proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
resps_list = engine( text_list, proms_list, resps_list=resps_list, temperature=0.0 )

View File

@ -48,6 +48,12 @@ from ..utils.pattern import DelayedPatternProvider, VALLEPattern
summed_embeddings_task = [ "stt" ]
special_tasks = [ "len", "stt" ]
non_tokened_names = ["task", "dropout_mask", "classifier_level"]
task_outputs = {
"tts": "resp",
"stt": "text",
"len": "len",
}
def _dropout_mask( input, p=None ):
# cosine scheduling
@ -200,13 +206,14 @@ class AudioEmbedding(nn.Module):
def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, name: str | None = None, sums = None ) -> Tensor:
if sums is None:
sums = self.sums
if quant_level is None:
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
# handle mapping from name
if name in self.names:
offset = self.names.index( name )
if quant_level is None:
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
offset -= quant_level # offset by quant level since it'll iterate up that many levels
if self.sums and quant_level > 0:
x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
@ -322,7 +329,7 @@ class Base(nn.Module):
def loss_factor(self, k):
if self.config is None:
return 1.0
return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
return self.config.loss_factor(k)
def _prune(self, l: Tensor, stop = None):
if stop is None:
@ -429,13 +436,14 @@ class Base(nn.Module):
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
interleave = self.config.experimental.interleave if self.config is not None else False
masking_ratio_fixed = self.config.experimental.masking_ratio_fixed if self.config is not None else False
ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
layerskip = self.config.experimental.layerskip if self.config is not None else False
layerskip_r = self.config.experimental.layerskip_r if self.config is not None else 2
layerskip_p_max = self.config.experimental.layerskip_p_max if self.config is not None else 0.1
layerskip_e_scale = self.config.experimental.layerskip_e_scale if self.config is not None else 0.1
masking_separate_embeddings = self.config.experimental.masking_separate_embeddings if self.config is not None else False
n_tasks = self.config.tasks if self.config is not None else 8
n_langs = self.config.langs if self.config is not None else 2
n_tones = self.config.tones if self.config is not None else 1
@ -446,15 +454,15 @@ class Base(nn.Module):
l_tokens = [n_resp_tokens] * self.n_resp_levels
resp_l_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
# NAR-len model
elif "len" in self.capabilities and masking_separate_embeddings:
elif "len" in self.capabilities:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
if masking_separate_embeddings:
l_tokens += [n_resp_tokens]
resp_l_names += ['NAR:0:0']
if "ar" in self.capabilities:
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
else:
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
resp_l_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
# AR+NAR model
else:
# +1 to include the stop or mask token
@ -462,13 +470,19 @@ class Base(nn.Module):
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
classifier_l_names = resp_l_names + ["stt"]
classifier_l_tokens = l_tokens + [ n_text_tokens ]
classifier_l_names = resp_l_names + [ "len" ]
if "len" in self.capabilities and False:
classifier_l_tokens += [ n_text_tokens ]
classifier_l_names += ["len"]
self.unified_position_ids = unified_position_ids
self.interleave = interleave
self.layerskip = layerskip
self.inject_timestep_embedding = False # results in bad output
self.masking_separate_embeddings = masking_separate_embeddings
self.masking_ratio_fixed = masking_ratio_fixed
self.ignore_inputs_for_loss = ignore_inputs_for_loss
self.text_emb = Embedding(n_text_tokens, d_model)
self.langs_emb = None
@ -805,10 +819,10 @@ class Base(nn.Module):
self.metrics = None
else:
self.classifier = None
self.classifiers = Classifiers( l_tokens + [ n_text_tokens ], d_model, l_names=classifier_l_names )
self.classifiers = Classifiers( classifier_l_tokens, d_model, l_names=classifier_l_names )
self.accuracy_metric = None
self.precision_metric = None
self.metrics = Metrics( l_tokens + [ n_text_tokens ] )
self.metrics = Metrics( classifier_l_tokens )
"""
if tie_classifier_to_embedding:
@ -996,24 +1010,22 @@ class Base(nn.Module):
inputs[i].append( ( "tone", tone_list[i] ) )
# insert timestep token
if timestep is not None:
# store timestep information
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
# force set to use this classifier level
classifier_level = "NAR:0:0" if self.masking_separate_embeddings else "AR:0:0"
classifier_level = "NAR:0:0"
# a paper said to use a fixed masking ratio for training
p = 0.8
# store timestep information
if not self.masking_ratio_fixed:
# cosine scheduled timestep => masking ratio
p = math.cos(timestep * math.pi * 0.5)
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
# store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
if self.training:
dropout_mask = _dropout_mask( resps_list[i], p )
inputs[i].append( ("dropout_mask", dropout_mask ) )
# insert the current output response
if resps_list is not None and resps_list[i] is not None:
inputs[i].append( ( "resp", resps_list[i] ) )
# store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
if timestep is not None and self.training:
"""
# a paper said to use a fixed masking ratio for training
p = 0.8
"""
# cosine scheduled timestep => masking ratio
p = math.cos(timestep * math.pi * 0.5)
dropout_mask = _dropout_mask( resps_list[i], p )
inputs[i].append( ("dropout_mask", dropout_mask ) )
inputs[i].append( ("classifier_level", classifier_level) )
# Audio length prediction task
@ -1047,7 +1059,7 @@ class Base(nn.Module):
# yes this could be encoded better
inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) )
inputs[i].append( ("classifier_level", "stt") )
inputs[i].append( ("classifier_level", "len") )
# Speech-to-Text prediction task
# Sequence: <resp><sep><rvq lvl><sep><text>
elif task_type == "stt":
@ -1125,6 +1137,7 @@ class Base(nn.Module):
for name, input in batch_input:
# technically can provide a map for input_name => embedding, but some embedding requires additional processing
embedding = None
# is already an embedding
if name == "task":
# noop
@ -1165,7 +1178,6 @@ class Base(nn.Module):
embedding = self.resps_emb(
# if masked use masked token, else original token
torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
#offset = -1 if self.masking_separate_embeddings else 0, # pick last
#quant_level = 0,
name = classifier_level,
)
@ -1174,7 +1186,6 @@ class Base(nn.Module):
embedding = self.resps_emb(
# if masked use masked token, else original token
input if input.dim() == 1 else input[:, 0],
#offset = -1 if self.masking_separate_embeddings else 0, # pick last
#quant_level = 0,
name = classifier_level,
)
@ -1212,7 +1223,8 @@ class Base(nn.Module):
embedding = self.resps_emb(
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
offset = 0 if classifier_level.startswith("AR:") else 1,
#offset = 0 if classifier_level.startswith("AR:") else 1,
name = classifier_level,
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
)
@ -1277,29 +1289,33 @@ class Base(nn.Module):
# there's a better way
if not self.unified_position_ids:
x_list = []
non_tokens = ["task", "dropout_mask", "classifier_level"]
last_input = ["resp", "len"]
def get_input_token_length( name, input ):
def get_input_token_length( name, input, task ):
# task token
if isinstance(input, str):
return 1
# list of tokens
if not isinstance(input, torch.Tensor):
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] )
# interleaved model
if self.interleave and name == "resp":
return input.shape[0] * input.shape[1]
# ending input will not have a separator later
return input.shape[0] + (0 if name in last_input else 1)
return input.shape[0]
for batch_index, batch_input in enumerate(inputs):
# pre-iterate
task = "tts"
for name, input in batch_input:
if name == "task":
task = input
batch = torch.cat( [
torch.tensor([*range(get_input_token_length(name, input))], device=device, dtype=torch.int32)
for name, input in batch_input if name not in non_tokens
torch.tensor([*range(get_input_token_length(name, input, task) + (1 if name != task_outputs.get(task, name) else 0))], device=device, dtype=torch.int32)
for name, input in batch_input if name not in non_tokened_names
] )
delta = ids[batch_index].shape[0] - batch.shape[0]
@ -1319,8 +1335,8 @@ class Base(nn.Module):
quant_levels: list[int] | None = None,
):
loss = dict(ce = dict())
stats = dict(acc = dict())
loss = {}
stats = {}
device = logits[0].device
batch_size = len(logits)
@ -1328,6 +1344,12 @@ class Base(nn.Module):
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
"""
if isinstance(input, str):
return torch.tensor( [ self.ignore_index ], device=device, dtype=torch.int16)
return torch.tensor( [ self.ignore_index ] * input.shape[0], device=device, dtype=torch.int16)
"""
if isinstance(input, str):
return torch.tensor( [ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16)
@ -1336,192 +1358,143 @@ class Base(nn.Module):
return torch.full_like(input[..., 0], self.ignore_index)
return input if input.dim() == 1 else input[:, quant_level]
# old, "naive" way, no loss factoring
if not self.config.loss_factors:
target_list = []
task_list = []
is_causal = []
for batch_index, batch in enumerate(inputs):
quant_level = quant_levels[batch_index]
target = []
task_type = "tts"
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
dropout_mask = None
for name, input in batch:
if name == "dropout_mask":
dropout_mask = input
for name, input in batch:
if name == "task":
task_type = input
task_list.append( input )
if task_type in special_tasks:
causal = True
elif name == "prom":
proms = [ input ] if isinstance(input, torch.Tensor) else input
target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) )
elif name == "resp":
# mask found, apply it
if dropout_mask is not None:
# if mask use original token, else ignore
causal = False
target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
elif self.interleave:
target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
elif task_type in summed_embeddings_task:
target.append( torch.full_like(input[..., 0], self.ignore_index) )
else:
target.append( input if input.dim() == 1 else input[:, quant_level] )
elif name == "timestep":
target.append( torch.tensor([self.ignore_index], device=input.device) )
elif name in ["text", "quant_level", "lang", "tone", "len"]:
target.append( input )
is_causal.append( causal )
target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
batch_size = len(target_list)
# modify only causal sequences so it can properly behave like a transformer
for i in range(batch_size):
quant_level = quant_levels[i]
task_name = task_list[i]
causal = is_causal[i]
if causal:
l = self.causal_size
logits[i] = logits[i][..., :-l, :] # shift the target so that token n...
target_list[i] = target_list[i][..., l:] # predicts token n + 1
for batch_index, target in enumerate( target_list ):
logit = logits[batch_index]
max_classes = logit.shape[-1]
max_token = torch.max( target ).item()
if max_token > max_classes:
task = self.get_input(inputs, "task", at=batch_index)
print( batch_index, task, target, max_token, max_classes, inputs[batch_index] )
# see comments for the split-loss calc cross_entropy call
if False:
target = torch.cat( target_list )
inputs = torch.cat( logits )
loss = dict(
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
)
stats = self.metrics( inputs, targets, classifier_levels ) if self.metrics is not None else dict(
acc = self.accuracy_metric( inputs, target ),
# precision = self.precision_metric( inputs, target ),
)
else:
# nll being natural log likelihood :)))) (I don't know why this completely escaped me originally with thinking it meant something else)
loss = dict(
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
)
stats = self.metrics( logits, target_list, self.classifiers.indices( classifier_levels ) ) if self.metrics is not None else dict(
acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
)
return LossStats(loss, stats)
"""
# considerations:
# * split losses does not maintain the entire sequence
# * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
# + the other way at least should keep it intact this way
# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
"""
# to-do: use NAR-len training and better causal-awareness
info = {}
batch_size = len( inputs )
for i, batch in enumerate( inputs ):
it = 0
quant_level = quant_levels[i]
task_name = None
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
for batch_index, batch in enumerate(inputs):
quant_level = quant_levels[batch_index]
target = []
causal = True
task_type = "tts"
dropout_mask = None
for name, input in batch:
if name == "dropout_mask":
dropout_mask = input
classifier_level = None
for name, input in batch:
# meta-input, no corresponding token at the moment
if name == "task":
task_name = input
if task_type in special_tasks:
causal = True
task_type = input
elif name == "dropout_mask":
dropout_mask = input
elif name == "classifier_level":
classifier_level = input
# autoregressive, causal
if classifier_level.startswith("AR:"):
causal = True
# nonautoregressive, parallel
elif classifier_level.startswith("NAR:"):
causal = False
it = 0
for name, input in batch:
token = None
ignored = False
# non-tokened tasks
if name in non_tokened_names:
continue
# do not use resp as-is
if name == "resp":
# prom can either be a tensor itself or a list of tensors and strings
if name == "prom":
# expand to list if not a list
proms = [ input ] if isinstance(input, torch.Tensor) else input
# iterate over the list to inject their tokens
token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
elif name == "resp":
# mask found, apply it
if dropout_mask is not None:
# if mask use original token, else ignore
causal = False
target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
token = torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index )
# flatten
elif self.interleave:
input = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
elif task_type in summed_embeddings_task:
input = torch.full_like(input[..., 0], self.ignore_index)
token = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
# use resps as-is
else:
input = input if input.dim() == 1 else input[:, quant_level]
# select prom level
elif name == "prom":
proms = [ input ] if isinstance(input, torch.Tensor) else input
input = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] )
seq_len = input.shape[0]
logit = logits[i][it:it+seq_len]
it += seq_len + 1 # +1 to incorporate the separator
# for the AR, shift sequence so that it predicts the next token
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
if causal and seq_len > 1:
l = self.causal_size
logit = logit[..., :-l, :]
input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
if name not in info:
info[name] = {
"targets": [],
"logits": [],
}
# modeling_llama.py has some comment about requiring .contiguous() but I feel it's a spook since that incurs a memory allocation
info[name]["targets"].append( input.long() )
info[name]["logits"].append( logit )
for name, batch in info.items():
loss_factor = self.loss_factor(name)
if name not in ["text", "prom", "resp", "len"]:
continue
if loss_factor == 0.0:
continue
# "faster" if cross_entropy has speedups for processing an entire batch, but torch.cat allocates new tensors
# to-do: set this to a var
if False:
targets = torch.cat( batch["targets"] ).long()
inputs = torch.cat( batch["logits"] )
loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
stats["acc"][name] = self.accuracy_metric( inputs, targets )
# probably consumes less memory due to not having to allocate memory
# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
else:
loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
if self.metrics is not None:
metrics = self.metrics( batch["logits"], batch["targets"], self.classifiers.indices( classifier_levels ) )
stats["acc"][name] = metrics["acc"]
token = input if input.dim() == 1 else input[:, quant_level]
# not a special input, inject as-is
else:
stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
token = input
if not isinstance(token, torch.Tensor):
continue
if token.is_floating_point():
ignored = True
# grab range of our logits for later
seq_len = token.shape[0]
start, end = it, it+seq_len
it += seq_len + 1 # +1 to incorporate the separator
# deduce if a name for a task is an input or output
if self.ignore_inputs_for_loss and name != task_outputs.get(task_type, name):
ignored = True
if ignored:
# pruned
if self.config.loss_factors:
continue
# fill with ignored out tensor
token = torch.tensor( [ self.ignore_index ] * input.shape[0], device=device, dtype=torch.int16)
# perform loss calculation on the individual piece
if self.config.loss_factors:
loss_factor = self.loss_factor(name)
if loss_factor == 0.0:
continue
logit = logits[batch_index][start:end]
if causal and seq_len > 1:
l = self.causal_size
logit = logit[..., :-l, :]
token = token[..., l:] # shift sequence to the right by one (or causal chunk size)
if f'{name}.nll' not in loss:
loss[f'{name}.nll'] = []
if f'{name}.acc' not in stats:
stats[f'{name}.acc'] = []
nll = F.cross_entropy( logit, token.long(), ignore_index=self.ignore_index ) * loss_factor
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
else:
metrics = self.accuracy_metric( logit, token )
loss[f'{name}.nll'].append( nll )
stats[f'{name}.acc'].append( metrics )
# add to list
else:
target.append( token )
# perform loss calculation on the entire sequence
if not self.config.loss_factors:
target = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
logit = logits[batch_index]
# shift if causal
if causal:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
target = target[..., l:] # ...predicts token n + 1
nll = F.cross_entropy( logit, target, ignore_index=self.ignore_index )
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ target ], self.classifiers.indices([ classifier_level ]) )
else:
metrics = self.accuracy_metric( logit, target )
if 'nll' not in loss:
loss['nll'] = []
if 'acc' not in stats:
stats['acc'] = []
loss["nll"].append( nll )
stats["acc"].append( metrics )
# average
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
return LossStats(loss, stats)