redid loss calculation to be cleaner, and position ID generation, and other things (I might need to train the NAR-len from scratch and not resume from an existing checkpoint.........)
parent ef05c951ff
commit 39096f8ff3
@@ -67,6 +67,8 @@ I'm uncertain on how to remedy this, as my options are:
* train an AR, and train a NAR, if the codec itself is at fault
* use an SSM like Mamba, if transformers entirely cannot model the codec
* train a separate model that simply converts from EnCodec to DAC
+* train *all* NAR levels as independent masking sequences.
+* if this works, then it means that there's little to no mappable relation between DAC's RVQ levels

## `transcribe.py`

@@ -48,17 +48,11 @@ One problem exhibited from a NAR is producing artifacts ("crust") in the final w
The pure NAR (`nar-len`) model is a model-type that inferences audio tokens purely non-autoregressively. Despite being called a pure NAR, duration is then inferred by autoregressively decoding for its length (as the AR+NAR model shows that you can mix both types).

However, having a pure NAR is challenging, as you need to both explicitly provide the duration and provide a "good enough" starting sequence of tokens for the initial sequence.
-* The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively.
+* The former problem is easily "solved" by training a `len` classification task.
-* The latter however proves to be challenging, as generating tokens from nothing in one step is not possible.
+* The latter however proves to be challenging, as generating tokens from nothing in one step is not possible (but can be done in multiple steps).
-* diffusion solves this, but requires additional steps at best and a separate model at worse, just for one RVQ level.
+* diffusion solves this, but requires a different underlying model architecture
-* the normal NAR (RVQ level 1+) does not face this problem, as it's already given a sufficient initial sequence of tokens to work with, and thus only requires one step.
+* masking to emulate diffusion noising is the best working solution, but has a lot of training challenges.
+* existing solutions like Muse (text to image) and MaskGCT (text to speech) do this
-The implemented solution follows a similar paradigm to diffusion, but with masking instead of noise.
-* incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this in the use of a NAR transformer for image generation
-
-The reference model provided has *some* NAR demasking (mock diffusion) aware training to faciliate a pure NAR model, but:
-* Sampling absolutely requires rep pen, or the output degenerates.
-* Output isn't so great, as there's artifacting from either an underbaked model or a naive sampler.

To-do: fill this out more when it works. Getting this to work is a huge pain.
* Some masked transformers do not "inject" any timestep information (Text-To-Image Muse, as far as I can tell)

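(Editorial aside, not part of the commit: a minimal sketch of the masked-demasking ("mock diffusion") loop the notes above describe, reusing the cosine schedule `noise_p = cos(t·π/2)` and the "re-mask the worst-scoring tokens" idea that appear later in this commit. The `model` call, the confidence scoring, and the step count are assumptions for illustration, not the repository's actual API.)

```python
import math
import torch

def demask_sample(model, cond, seq_len: int, mask_id: int, steps: int = 25) -> torch.Tensor:
    """Toy pure-NAR sampler: start fully masked, then repeatedly re-mask the
    lowest-confidence positions per a cosine schedule and re-predict them."""
    tokens = torch.full((seq_len,), mask_id, dtype=torch.long)
    scores = torch.zeros(seq_len)  # confidence of the committed tokens so far

    for timestep in torch.linspace(0.0, 1.0, steps):
        # cosine schedule: fraction of positions that stay masked this step
        noise_p = math.cos(timestep.item() * math.pi * 0.5)
        num_masked = max(1, int(noise_p * seq_len))

        # pick the worst-scoring tokens to mask off
        masked_idx = scores.topk(num_masked, largest=False).indices
        tokens[masked_idx] = mask_id

        # predict every position in parallel, commit only the masked ones
        probs = model(cond, tokens).softmax(dim=-1)   # stand-in: returns [seq_len, vocab]
        sampled = probs.argmax(dim=-1)
        confidence = probs.max(dim=-1).values
        tokens[masked_idx] = sampled[masked_idx]
        scores[masked_idx] = confidence[masked_idx]

    return tokens
```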
@@ -67,22 +61,28 @@ To-do: fill this out more when it works. Getting this to work is a huge pain.
* MaskGCT does it both pre and post
* the test trainer actually degrades the output immensely when doing this
* I'm sure I've seen a masked transformer not have CFG, but most of them seem to.
-* This helps the base AR+NAR tasks and provides CFG sampling for such tasks anyways.
+* ***Extreme*** care is required.

-## Embeddings
+## Embeddings (and Classifiers)

-The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings. This is necessary as each piece of a sequence is fundamentally different, but a HF-compatible model can geta way with treating each sequence as separate ranges within a total token sequence.
+The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings. This is necessary as each piece of a sequence is fundamentally different, but a HF-compatible model can get away with treating each sequence as separate ranges within a total token sequence.

While embeddings *can* be tied to the output head, testing showed that the model ***really*** does not like to do this, although my implementation could very well be flawed.

With attention-based transformers, most embeddings can serve as tokens themselves and have the attention mechanism attend to them. Theoretically, there should be little to no functional difference between "tokenizing" an embedding and summing a modifying embedding, but experimentation is needed for this assertion.

+### Classifiers
+
+Classifiers are the final output heads / projection layers that process the last hidden states of a model into a probability distribution for each token.
+
+Out of paranoia, each head is split per macro-task (RVQ level), with auxiliary heads for the `stt` and `len` tasks, even though the core half of the model's training was done with a single output head.
+
### Text Embeddings

The input text phonemes (or output text for STT) are passed through an embedding head (`text`), similar to how a normal text LLM would. Nothing fancy is required, as it's very straightforward.

-Technically, due to how the audio embeddings are implemented, it's possible to offer "language specific" embeddings, rather than one unified IPA-based embedding + a language embedding (`lang`).
+Technically, due to how the audio embeddings are implemented, it's possible to offer "language specific" text embeddings, rather than one unified IPA-based embedding + a language embedding (`lang`).
-* Such an implementation *could* in fact inference from normal text rather than IPA phonemes.
+* Such an implementation *could* in fact inference from normal text rather than IPA phonemes, as language-specific pure text embeddings can be trained.

These embeddings *could* instead be added on top of the input prompt embedding instead of serving as additional tasks (similar to injecting position embeddings), but additional experimentation is required to see if the model can both work under this and/or benefit from it.

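(Editorial aside: a rough sketch of the split classifier heads described above — one head per RVQ level plus auxiliary `stt`/`len` heads. The class name, head sizes, and `d_model` are illustrative assumptions, not the repository's actual `Classifiers` implementation; 256 text tokens and the 0–10 `len` digit alphabet are taken from elsewhere in this commit.)

```python
import torch
from torch import nn

class SplitClassifiers(nn.Module):
    """One output head per sequence level / task, selected by name."""
    def __init__(self, token_counts: list[int], d_model: int, names: list[str]):
        super().__init__()
        self.names = names
        self.heads = nn.ModuleList([nn.Linear(d_model, n) for n in token_counts])

    def forward(self, hidden: torch.Tensor, name: str) -> torch.Tensor:
        # pick the head matching this batch item's classifier level
        return self.heads[self.names.index(name)](hidden)

# e.g. one head per RVQ level plus auxiliary heads for "stt" and "len"
names = ["AR:0:0"] + [f"NAR:{i}:{i+1}" for i in range(7)] + ["NAR:0:0", "stt", "len"]
heads = SplitClassifiers([1025] * 9 + [256, 11], d_model=1024, names=names)
logits = heads(torch.randn(100, 1024), "NAR:0:0")  # -> [100, 1025]
```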
@@ -115,16 +115,15 @@ As EnCodec encodes audio across eight codebooks (and DAC's 44Khz audio under nin
For the `prom` embedding, we can simply use each embedding for each layer. Each embedding level maps to its respective RVQ level.

However, the `resp` requires some extra care, as the model needs to decode tokens both causally (AR) and in parallel (NAR).
-* The first embedding level pertains to RVQ level 0 for the AR.
+* The first embedding level pertains to RVQ level 0 for the AR (`AR:0:0`).
-* The remaining embedding levels maps to RVQ level 0 + n for the NAR.
+* This embedding predicts tokens within its own embedding.
+* The remaining embedding levels map to RVQ level 0 + n for the NAR (`NAR:L-1:L`).
* In other words, embedding level 1 => RVQ level 0, embedding level 2 => RVQ level 1, etc...
-* I believe this is because the model needs to "know" whether to predict ~~the next token in the sequence, or the token in the same position of the next RVQ level~~ which tokens of a given embedding.
+* I believe this is required because the model encodes which task to perform (rather than the attention heads), and which tokens to predict (rather than the classifiers).
-* In other words, the AR's RVQ level 0 embedding predicts itself, while the NAR's embeddings predict the next level's embeddings.
+* In other words, each embedding needs to be separated based on what tokens it does predict.
-* This is evident on how RVQ level 0 can be trained causally and in parallel with its own embeddings, rather than having limiting issues when reusing the embedding across the two.
-* Unfortunately, providing a token for the current/target RVQ level within the input sequence doesn't seem to help? I don't remember if I experimented with this or not, but testing of a "sane" `resp` embedding proved to be unfruitful.

The `prom` and `resp` are split since, in theory, it helps the model know better what audio to source from, and what audio is part of the output sequence. In theory.
-* I have yet to conduct tests with interchanging the `prom` and `resp`, and the model definitely exhibits being able to map from the `prom` directly, and being able to inference from the `prom` being prefixed in the `resp`.
+* The `text` embedding's robustness, not only for reuse across each RVQ level but for the STT task as well, is a mystery.

Finally, the model *may* then sum each embedding level back down to one sequence, as defined under `cfg.model.experimental.audio_embedding_sums`.
* The resultant sum is not normalized by the length.

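(Editorial aside: a condensed sketch of what `audio_embedding_sums` amounts to, mirroring the `sum( embeddings[k + offset]( xi[:, k] ) ... )` expression shown further down in this commit. Token counts, model width, and the 75 frames/second figure are illustrative assumptions.)

```python
import torch
from torch import nn

class SummedAudioEmbedding(nn.Module):
    """One embedding table per RVQ level; optionally sum levels into one sequence."""
    def __init__(self, n_tokens: int, d_model: int, n_levels: int, sums: bool = True):
        super().__init__()
        self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, d_model) for _ in range(n_levels)])
        self.sums = sums

    def forward(self, xi: torch.Tensor, quant_level: int) -> torch.Tensor:
        # xi: [seq_len] for RVQ level 0, or [seq_len, levels] for the NAR
        if self.sums and quant_level > 0:
            # sum levels 0 .. quant_level-1 into a single sequence of embeddings
            return sum(self.embeddings[k](xi[:, k]) for k in range(quant_level))
        k = 0 if xi.dim() == 1 else quant_level
        return self.embeddings[k](xi if xi.dim() == 1 else xi[:, k])

emb = SummedAudioEmbedding(n_tokens=1025, d_model=1024, n_levels=8)
codes = torch.randint(0, 1024, (75, 8))   # ~one second of EnCodec codes (assumed 75 fps)
x = emb(codes, quant_level=2)             # [75, 1024], sum of levels 0 and 1
```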
@@ -27,8 +27,20 @@ Some additional flags can be passed as well:

A training paradigm that works for me is:
* setting the dataloader to sort by duration, then training one epoch, so the model starts with small utterances then trains to larger ones.
+* the daring can wait until coherent speech emerges, then move on to the next step
* some additional training using a shuffled dataloader, as the model will be fixated towards whatever duration range it was trained under.
* additional training for sampling per speaker, to better help diversify how well it can perform for a range of speakers, rather than just speaking itself
+* I don't think this is crucial, but speaker-based sampling seems to be a huge placebo if anything.

+I don't remember the exact numbers off the top of my head, but a good loss/accuracy/gradient norm to look out for when coherent speech emerges are:
+* loss <3.0
+* acc >0.7
+* grad_norm <0.2
+
+Training under `float16` should be fairly simple, but care is required to keep the loss scaling factor above 8K, and probably even 16K.
+* At the very least for pre-trained models, low enough loss scales will irreparably fry the model, and no amount of training afterwards seems to "fix" it.
+* The current DeepSpeed configuration should keep the loss scale capped to 32K, but this so far is only validated for pre-trained models.
+* Training under `bfloat16` does not have to worry about this as there's no need for loss scaling, but I feel the model performs better when trained under `float16`+AMP rather than `bfloat16` (with or without AMP).

## Try Me

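(Editorial aside: a toy model of the dynamic loss scaling the notes above warn about — halve on overflow, raise after `loss_scale_window` consecutive good steps, clamp at `min_loss_scale`. This is a simplification for intuition only, not DeepSpeed's actual scaler implementation.)

```python
class ToyLossScaler:
    """Simplified dynamic loss scaler: back off on overflow, raise after a
    window of good steps, and never drop below a configured floor."""
    def __init__(self, init_scale=65536.0, loss_scale_window=100, min_loss_scale=32768.0):
        self.scale = init_scale
        self.window = loss_scale_window
        self.min_scale = min_loss_scale
        self.good_steps = 0

    def update(self, overflowed: bool) -> float:
        if overflowed:
            # skip the step and halve the scale, but never below the floor
            self.scale = max(self.scale / 2.0, self.min_scale)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps % self.window == 0:
                self.scale *= 2.0  # raise after `window` consecutive good steps
        return self.scale

scaler = ToyLossScaler()
for overflowed in [False] * 100 + [True, True]:
    scale = scaler.update(overflowed)
# with min_loss_scale=32768 the scale bottoms out at 32K instead of drifting toward 8K
```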
@@ -260,7 +260,9 @@ class ModelExperimentalSettings:

masking_train_p: float = 0.0 # odds of training with masking
masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
-masking_separate_embeddings: bool = False # to-do: explain
+masking_ratio_fixed: bool = False
+ignore_inputs_for_loss: bool = False

# classifier-free guidance shit
cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training

@@ -301,7 +303,7 @@ class Model:
return [ self ] if not name or self.name == name else []

def loss_factor(self, k):
-return self.loss_factors[k] if k in self.loss_factors else 1.0
+return self.loss_factors.get(k, 0.0)

@property
def max_levels(self):

@@ -508,6 +510,9 @@ class DeepSpeed:

amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)

+loss_scale_window: int = 100
+min_loss_scale: float = 8192.0
+
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config

@cached_property

@@ -558,8 +563,8 @@ class DeepSpeed:
"fp16": {
"enabled": cfg.trainer.weight_dtype.lower() == "float16",
"auto_cast": True, # ???
-"loss_scale_window": 100, # raise every 100 consecutive good steps
+"loss_scale_window": self.loss_scale_window, # raise every 100 consecutive good steps
-"min_loss_scale": 32768.0, # loss scale hitting 8K fries the model, 16K is fine but 32K is comfy
+"min_loss_scale": self.min_loss_scale, # loss scale hitting 8K fries the model, 16K is fine but 32K is comfy
"loss_scale": 0.0 if cfg.trainer.scale_loss else 1.0,
},
"bf16": {

@@ -1761,12 +1761,12 @@ if __name__ == "__main__":
elif args.action == "metadata":
create_dataset_metadata()
elif args.action == "sample":
-train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
+train_dl, val_dl = create_train_val_dataloader()

samples = {
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
-"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
+#"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
-#"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
+"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
}

Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)

@@ -1784,8 +1784,8 @@ if __name__ == "__main__":
decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
except Exception as e:
_logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}")
-v[i]['proms'][j] = v[i]['proms'][j].shape
+#v[i]['proms'][j] = v[i]['proms'][j].shape
-v[i]['resps'][j] = v[i]['resps'][j].shape
+#v[i]['resps'][j] = v[i]['resps'][j].shape

for k, v in samples.items():
for i in range(len(v)):

@@ -269,7 +269,6 @@ class TTS():
)
elif model_len is not None:
len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_steps": 5} ) # don't need more than that

kwargs = {}
# nasty hardcode to load a reference file and have that as the input target
if load_from_artifact and load_from_artifact.exists():

@@ -148,7 +148,7 @@ class AR_NAR(Base):
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1

# only apply stop token for RVQ level 0
-if quant_level <= 0 and timesteps[i] is not None:
+if quant_level <= 0 and timesteps[i] is None:
# append stop tokens for AR
if task in text_task:
#text_list[i] = torch.cat([ resps, text_stop_sequence ])

@@ -254,9 +254,8 @@ class AR_NAR(Base):
null_prom = [ None for _ in range(batch_size) ]
prev_list = resps_list

-for timestep, steps_until_x0 in tqdm(zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))), desc="NAR Masked", disable=disable_tqdm, total=max_steps):
+for timestep in tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm):
-#for noise_p, annealed_temperature, temperature, cfg_strength in zip( manual_ratios, manual_temp, manual_samp_temp, manual_cfg ):
-annealing = (steps_until_x0 / max_steps)
+annealing = 1.0 - timestep
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
# pick the worst scoring tokens to mask off

@@ -268,6 +267,13 @@ class AR_NAR(Base):
# timestep inputs
time_list = [ timestep for _ in range(batch_size) ]

+"""
+sampling_temperature = temperature * annealing
+sampling_cfg = cfg_strength * timestep
+"""
+sampling_temperature = temperature
+sampling_cfg = cfg_strength
+
# setup inputs
inputs = super().inputs(
text_list=text_list,

@@ -302,7 +308,7 @@ class AR_NAR(Base):
#layer_skip_variables=sampling_layer_skip_variables,
)
for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
-logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * (cfg_strength * timestep)
+logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * sampling_cfg

# sample with sampler settings
filtered_sampled = super().sample(

@@ -310,7 +316,7 @@ class AR_NAR(Base):
prev_list=prev_list,
quant_levels=quant_levels,

-temperature=temperature * annealing,
+temperature=sampling_temperature,
**sampling_kwargs,
)

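(Editorial aside: the classifier-free guidance arithmetic used in the hunks above, restated on its own; tensor shapes and the guidance strength are illustrative values, not defaults from the repository.)

```python
import torch

def apply_cfg(cond_logits: torch.Tensor, null_logits: torch.Tensor, cfg_strength: float) -> torch.Tensor:
    """Classifier-free guidance: push the conditioned logits away from the
    unconditioned ("null" text/prompt) logits by cfg_strength."""
    return null_logits + (cond_logits - null_logits) * cfg_strength

cond = torch.randn(128, 1025)   # logits with text + prompt conditioning
null = torch.randn(128, 1025)   # logits with the conditioning dropped out
guided = apply_cfg(cond, null, cfg_strength=2.5)
# cfg_strength == 1.0 recovers the conditioned logits unchanged; this commit
# switches from annealing the strength by `timestep` to a constant `sampling_cfg`.
```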
@@ -884,7 +890,7 @@ def example_usage():
import numpy as np
import re

-cfg.model.experimental.masking_train_p = 0.5
+# cfg.model.experimental.masking_train_p = 0.5
cfg.hyperparameters.batch_size = 1
cfg.hyperparameters.gradient_accumulation_steps = 1

|
@ -903,7 +909,7 @@ def example_usage():
|
||||||
|
|
||||||
text_list = [ text ] * batch_size
|
text_list = [ text ] * batch_size
|
||||||
proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size
|
proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size
|
||||||
resps_list = [ audio ] * batch_size
|
resps_list = [ audio[:cfg.dataset.frames_per_second * 4, :] ] * batch_size
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
'n_text_tokens': 256,
|
'n_text_tokens': 256,
|
||||||
|
@@ -922,10 +928,10 @@ def example_usage():
}

bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
-available_tasks = ["tts-ar", "tts-nar"]
+available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else [])

model = AR_NAR(**kwargs).to(cfg.device)
-steps = 500 // batch_size
+steps = 1000 // batch_size

optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""

@@ -1035,7 +1041,7 @@ def example_usage():
if task == "stt":
prom = [ task ]
else:
-task = "tts" if random.random() > 0.1 else "len"
+task = "tts" if random.random() > 0.1 or "len" not in cfg.model.capabilities else "len"

texts.append( text )
proms.append( prom )

@@ -1053,7 +1059,7 @@ def example_usage():
if task == "tts-nar":
len_list = engine(text_list, proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
len_list = [ resp_list[0].shape[0] for l in len_list ]
-resps_list = engine( text_list, proms_list, len_list=len_list, temperature=0.0 )
+resps_list = engine( text_list, proms_list, len_list=len_list )
else:
resps_list = engine( text_list, proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
resps_list = engine( text_list, proms_list, resps_list=resps_list, temperature=0.0 )

@@ -48,6 +48,12 @@ from ..utils.pattern import DelayedPatternProvider, VALLEPattern

summed_embeddings_task = [ "stt" ]
special_tasks = [ "len", "stt" ]
+non_tokened_names = ["task", "dropout_mask", "classifier_level"]
+task_outputs = {
+"tts": "resp",
+"stt": "text",
+"len": "len",
+}

def _dropout_mask( input, p=None ):
# cosine scheduling

|
@ -200,13 +206,14 @@ class AudioEmbedding(nn.Module):
|
||||||
def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, name: str | None = None, sums = None ) -> Tensor:
|
def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, name: str | None = None, sums = None ) -> Tensor:
|
||||||
if sums is None:
|
if sums is None:
|
||||||
sums = self.sums
|
sums = self.sums
|
||||||
|
|
||||||
|
if quant_level is None:
|
||||||
|
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
||||||
|
|
||||||
# handle mapping from name
|
# handle mapping from name
|
||||||
if name in self.names:
|
if name in self.names:
|
||||||
offset = self.names.index( name )
|
offset = self.names.index( name )
|
||||||
|
offset -= quant_level # offset by quant level since it'll iterate up that many levels
|
||||||
if quant_level is None:
|
|
||||||
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
|
||||||
|
|
||||||
if self.sums and quant_level > 0:
|
if self.sums and quant_level > 0:
|
||||||
x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
|
x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
|
||||||
|
@@ -322,7 +329,7 @@ class Base(nn.Module):
def loss_factor(self, k):
if self.config is None:
return 1.0
-return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
+return self.config.loss_factor(k)

def _prune(self, l: Tensor, stop = None):
if stop is None:

@@ -429,13 +436,14 @@ class Base(nn.Module):
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
interleave = self.config.experimental.interleave if self.config is not None else False

+masking_ratio_fixed = self.config.experimental.masking_ratio_fixed if self.config is not None else False
+ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
+
layerskip = self.config.experimental.layerskip if self.config is not None else False
layerskip_r = self.config.experimental.layerskip_r if self.config is not None else 2
layerskip_p_max = self.config.experimental.layerskip_p_max if self.config is not None else 0.1
layerskip_e_scale = self.config.experimental.layerskip_e_scale if self.config is not None else 0.1

-masking_separate_embeddings = self.config.experimental.masking_separate_embeddings if self.config is not None else False
-
n_tasks = self.config.tasks if self.config is not None else 8
n_langs = self.config.langs if self.config is not None else 2
n_tones = self.config.tones if self.config is not None else 1

@@ -446,15 +454,15 @@ class Base(nn.Module):
l_tokens = [n_resp_tokens] * self.n_resp_levels
resp_l_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
# NAR-len model
-elif "len" in self.capabilities and masking_separate_embeddings:
+elif "len" in self.capabilities:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
-l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
-resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
-if masking_separate_embeddings:
-l_tokens += [n_resp_tokens]
-resp_l_names += ['NAR:0:0']
+if "ar" in self.capabilities:
+l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
+resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
+else:
+l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+resp_l_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
# AR+NAR model
else:
# +1 to include the stop or mask token

|
||||||
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
|
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
|
||||||
resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
|
resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
|
||||||
|
|
||||||
classifier_l_names = resp_l_names + ["stt"]
|
classifier_l_tokens = l_tokens + [ n_text_tokens ]
|
||||||
|
classifier_l_names = resp_l_names + [ "len" ]
|
||||||
|
|
||||||
|
if "len" in self.capabilities and False:
|
||||||
|
classifier_l_tokens += [ n_text_tokens ]
|
||||||
|
classifier_l_names += ["len"]
|
||||||
|
|
||||||
self.unified_position_ids = unified_position_ids
|
self.unified_position_ids = unified_position_ids
|
||||||
self.interleave = interleave
|
self.interleave = interleave
|
||||||
self.layerskip = layerskip
|
self.layerskip = layerskip
|
||||||
self.inject_timestep_embedding = False # results in bad output
|
self.inject_timestep_embedding = False # results in bad output
|
||||||
self.masking_separate_embeddings = masking_separate_embeddings
|
self.masking_ratio_fixed = masking_ratio_fixed
|
||||||
|
self.ignore_inputs_for_loss = ignore_inputs_for_loss
|
||||||
|
|
||||||
self.text_emb = Embedding(n_text_tokens, d_model)
|
self.text_emb = Embedding(n_text_tokens, d_model)
|
||||||
self.langs_emb = None
|
self.langs_emb = None
|
||||||
|
@@ -805,10 +819,10 @@ class Base(nn.Module):
self.metrics = None
else:
self.classifier = None
-self.classifiers = Classifiers( l_tokens + [ n_text_tokens ], d_model, l_names=classifier_l_names )
+self.classifiers = Classifiers( classifier_l_tokens, d_model, l_names=classifier_l_names )
self.accuracy_metric = None
self.precision_metric = None
-self.metrics = Metrics( l_tokens + [ n_text_tokens ] )
+self.metrics = Metrics( classifier_l_tokens )

"""
if tie_classifier_to_embedding:

|
||||||
inputs[i].append( ( "tone", tone_list[i] ) )
|
inputs[i].append( ( "tone", tone_list[i] ) )
|
||||||
# insert timestep token
|
# insert timestep token
|
||||||
if timestep is not None:
|
if timestep is not None:
|
||||||
# store timestep information
|
|
||||||
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
|
|
||||||
# force set to use this classifier level
|
# force set to use this classifier level
|
||||||
classifier_level = "NAR:0:0" if self.masking_separate_embeddings else "AR:0:0"
|
classifier_level = "NAR:0:0"
|
||||||
|
# a paper said to use a fixed masking ratio for training
|
||||||
|
p = 0.8
|
||||||
|
# store timestep information
|
||||||
|
if not self.masking_ratio_fixed:
|
||||||
|
# cosine scheduled timestep => masking ratio
|
||||||
|
p = math.cos(timestep * math.pi * 0.5)
|
||||||
|
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
|
||||||
|
# store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
|
||||||
|
if self.training:
|
||||||
|
dropout_mask = _dropout_mask( resps_list[i], p )
|
||||||
|
inputs[i].append( ("dropout_mask", dropout_mask ) )
|
||||||
# insert the current output response
|
# insert the current output response
|
||||||
if resps_list is not None and resps_list[i] is not None:
|
if resps_list is not None and resps_list[i] is not None:
|
||||||
inputs[i].append( ( "resp", resps_list[i] ) )
|
inputs[i].append( ( "resp", resps_list[i] ) )
|
||||||
|
|
||||||
# store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
|
|
||||||
if timestep is not None and self.training:
|
|
||||||
"""
|
|
||||||
# a paper said to use a fixed masking ratio for training
|
|
||||||
p = 0.8
|
|
||||||
"""
|
|
||||||
# cosine scheduled timestep => masking ratio
|
|
||||||
p = math.cos(timestep * math.pi * 0.5)
|
|
||||||
dropout_mask = _dropout_mask( resps_list[i], p )
|
|
||||||
inputs[i].append( ("dropout_mask", dropout_mask ) )
|
|
||||||
|
|
||||||
inputs[i].append( ("classifier_level", classifier_level) )
|
inputs[i].append( ("classifier_level", classifier_level) )
|
||||||
# Audio length prediction task
|
# Audio length prediction task
|
||||||
|
@ -1047,7 +1059,7 @@ class Base(nn.Module):
|
||||||
# yes this could be encoded better
|
# yes this could be encoded better
|
||||||
inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) )
|
inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) )
|
||||||
|
|
||||||
inputs[i].append( ("classifier_level", "stt") )
|
inputs[i].append( ("classifier_level", "len") )
|
||||||
# Speech-to-Text prediction task
|
# Speech-to-Text prediction task
|
||||||
# Sequence: <resp><sep><rvq lvl><sep><text>
|
# Sequence: <resp><sep><rvq lvl><sep><text>
|
||||||
elif task_type == "stt":
|
elif task_type == "stt":
|
||||||
|
@@ -1125,6 +1137,7 @@ class Base(nn.Module):
for name, input in batch_input:
# technically can provide a map for input_name => embedding, but some embedding requires additional processing
embedding = None

# is already an embedding
if name == "task":
# noop

|
||||||
embedding = self.resps_emb(
|
embedding = self.resps_emb(
|
||||||
# if masked use masked token, else original token
|
# if masked use masked token, else original token
|
||||||
torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
|
torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
|
||||||
#offset = -1 if self.masking_separate_embeddings else 0, # pick last
|
|
||||||
#quant_level = 0,
|
#quant_level = 0,
|
||||||
name = classifier_level,
|
name = classifier_level,
|
||||||
)
|
)
|
||||||
|
@ -1174,7 +1186,6 @@ class Base(nn.Module):
|
||||||
embedding = self.resps_emb(
|
embedding = self.resps_emb(
|
||||||
# if masked use masked token, else original token
|
# if masked use masked token, else original token
|
||||||
input if input.dim() == 1 else input[:, 0],
|
input if input.dim() == 1 else input[:, 0],
|
||||||
#offset = -1 if self.masking_separate_embeddings else 0, # pick last
|
|
||||||
#quant_level = 0,
|
#quant_level = 0,
|
||||||
name = classifier_level,
|
name = classifier_level,
|
||||||
)
|
)
|
||||||
|
@ -1212,7 +1223,8 @@ class Base(nn.Module):
|
||||||
|
|
||||||
embedding = self.resps_emb(
|
embedding = self.resps_emb(
|
||||||
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
|
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
|
||||||
offset = 0 if classifier_level.startswith("AR:") else 1,
|
#offset = 0 if classifier_level.startswith("AR:") else 1,
|
||||||
|
name = classifier_level,
|
||||||
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
|
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -1277,29 +1289,33 @@ class Base(nn.Module):
# there's a better way
if not self.unified_position_ids:
x_list = []
-non_tokens = ["task", "dropout_mask", "classifier_level"]
-last_input = ["resp", "len"]

-def get_input_token_length( name, input ):
+def get_input_token_length( name, input, task ):
# task token
if isinstance(input, str):
return 1

# list of tokens
if not isinstance(input, torch.Tensor):
-return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
+return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] )

# interleaved model
if self.interleave and name == "resp":
return input.shape[0] * input.shape[1]

# ending input will not have a separator later
-return input.shape[0] + (0 if name in last_input else 1)
+return input.shape[0]

for batch_index, batch_input in enumerate(inputs):
+# pre-iterate
+task = "tts"
+for name, input in batch_input:
+if name == "task":
+task = input
+
batch = torch.cat( [
-torch.tensor([*range(get_input_token_length(name, input))], device=device, dtype=torch.int32)
+torch.tensor([*range(get_input_token_length(name, input, task) + (1 if name != task_outputs.get(task, name) else 0))], device=device, dtype=torch.int32)
-for name, input in batch_input if name not in non_tokens
+for name, input in batch_input if name not in non_tokened_names
] )

delta = ids[batch_index].shape[0] - batch.shape[0]

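(Editorial aside: the new non-unified position-ID rule restated outside the diff — each named input segment restarts its count at 0 and gets one extra position for its separator, unless it is the task's output segment per `task_outputs`. Segment names and lengths below are illustrative.)

```python
import torch

task_outputs = { "tts": "resp", "stt": "text", "len": "len" }

def position_ids_for(segments: list[tuple[str, int]], task: str) -> torch.Tensor:
    """segments: ordered (name, token_length) pairs for one batch item."""
    ids = []
    for name, length in segments:
        # every segment except the task's output is followed by a separator token
        length += 1 if name != task_outputs.get(task, name) else 0
        ids.append(torch.arange(length, dtype=torch.int32))
    return torch.cat(ids)

# a TTS sample: text phonemes, acoustic prompt, then the output response
ids = position_ids_for([("text", 12), ("prom", 150), ("resp", 300)], task="tts")
# -> 0..12 for text (incl. separator), 0..150 for prom (incl. separator), 0..299 for resp
```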
@@ -1319,8 +1335,8 @@ class Base(nn.Module):

quant_levels: list[int] | None = None,
):
-loss = dict(ce = dict())
+loss = {}
-stats = dict(acc = dict())
+stats = {}

device = logits[0].device
batch_size = len(logits)

@@ -1328,6 +1344,12 @@ class Base(nn.Module):

# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
+"""
+if isinstance(input, str):
+return torch.tensor( [ self.ignore_index ], device=device, dtype=torch.int16)
+
+return torch.tensor( [ self.ignore_index ] * input.shape[0], device=device, dtype=torch.int16)
+"""
if isinstance(input, str):
return torch.tensor( [ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16)

@@ -1336,192 +1358,143 @@ class Base(nn.Module):
return torch.full_like(input[..., 0], self.ignore_index)

return input if input.dim() == 1 else input[:, quant_level]

-# old, "naive" way, no loss factoring
-if not self.config.loss_factors:
-target_list = []
-task_list = []
-is_causal = []
+for batch_index, batch in enumerate(inputs):
+quant_level = quant_levels[batch_index]
+target = []
+causal = True
+task_type = "tts"

-for batch_index, batch in enumerate(inputs):
-quant_level = quant_levels[batch_index]
-target = []
-task_type = "tts"
-causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
-dropout_mask = None
-for name, input in batch:
-if name == "dropout_mask":
-dropout_mask = input
-for name, input in batch:
-if name == "task":
-task_type = input
-task_list.append( input )
-if task_type in special_tasks:
-causal = True
-elif name == "prom":
-proms = [ input ] if isinstance(input, torch.Tensor) else input
-target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) )
-elif name == "resp":
-# mask found, apply it
-if dropout_mask is not None:
-# if mask use original token, else ignore
-causal = False
-target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
-elif self.interleave:
-target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
-elif task_type in summed_embeddings_task:
-target.append( torch.full_like(input[..., 0], self.ignore_index) )
-else:
-target.append( input if input.dim() == 1 else input[:, quant_level] )
-elif name == "timestep":
-target.append( torch.tensor([self.ignore_index], device=input.device) )
-elif name in ["text", "quant_level", "lang", "tone", "len"]:
-target.append( input )
-is_causal.append( causal )
-target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
-batch_size = len(target_list)
-# modify only causal sequences so it can properly behave like a transformer
-for i in range(batch_size):
-quant_level = quant_levels[i]
-task_name = task_list[i]
-causal = is_causal[i]
-if causal:
-l = self.causal_size
-logits[i] = logits[i][..., :-l, :] # shift the target so that token n...
-target_list[i] = target_list[i][..., l:] # predicts token n + 1
-for batch_index, target in enumerate( target_list ):
-logit = logits[batch_index]
-max_classes = logit.shape[-1]
-max_token = torch.max( target ).item()
-if max_token > max_classes:
-task = self.get_input(inputs, "task", at=batch_index)
-print( batch_index, task, target, max_token, max_classes, inputs[batch_index] )
-# see comments for the split-loss calc cross_entropy call
-if False:
-target = torch.cat( target_list )
-inputs = torch.cat( logits )
-loss = dict(
-nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
-)
-stats = self.metrics( inputs, targets, classifier_levels ) if self.metrics is not None else dict(
-acc = self.accuracy_metric( inputs, target ),
-# precision = self.precision_metric( inputs, target ),
-)
-else:
-# nll being natural log likelihood :)))) (I don't know why this completely escaped me originally with thinking it meant something else)
-loss = dict(
-nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
-)
-stats = self.metrics( logits, target_list, self.classifiers.indices( classifier_levels ) ) if self.metrics is not None else dict(
-acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
-)
-return LossStats(loss, stats)
-"""
-# considerations:
-# * split losses does not maintain the entire sequence
-# * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
-# + the other way at least should keep it intact this way
-# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
-# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
-"""
-# to-do: use NAR-len training and better causal-awareness
-info = {}
-batch_size = len( inputs )
-for i, batch in enumerate( inputs ):
-it = 0
-quant_level = quant_levels[i]
-task_name = None
-causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
dropout_mask = None
-for name, input in batch:
-if name == "dropout_mask":
-dropout_mask = input
+classifier_level = None

for name, input in batch:
-# meta-input, no corresponding token at the moment
if name == "task":
-task_name = input
-if task_type in special_tasks:
-causal = True
+task_type = input
+elif name == "dropout_mask":
+dropout_mask = input
+elif name == "classifier_level":
+classifier_level = input
+
+# autoregressive, causal
+if classifier_level.startswith("AR:"):
+causal = True
+# nonautoregressive, parallel
+elif classifier_level.startswith("NAR:"):
+causal = False
+
+it = 0
+for name, input in batch:
+token = None
+ignored = False
+
+# non-tokened tasks
+if name in non_tokened_names:
continue
-# do not use resp as-is
-if name == "resp":
+# prom can either be a tensor itself or a list of tensors and strings
+if name == "prom":
+# expand to list if not a list
+proms = [ input ] if isinstance(input, torch.Tensor) else input
+# iterate over the list to inject their tokens
+token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
+elif name == "resp":
+# mask found, apply it
if dropout_mask is not None:
# if mask use original token, else ignore
-causal = False
-target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
+token = torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index )
+# flatten
elif self.interleave:
-input = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
-elif task_type in summed_embeddings_task:
-input = torch.full_like(input[..., 0], self.ignore_index)
+token = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
+# use resps as-is
else:
-input = input if input.dim() == 1 else input[:, quant_level]
-# select prom level
-elif name == "prom":
-proms = [ input ] if isinstance(input, torch.Tensor) else input
-input = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] )
-seq_len = input.shape[0]
-logit = logits[i][it:it+seq_len]
-it += seq_len + 1 # +1 to incorporate the separator
-# for the AR, shift sequence so that it predicts the next token
-# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
-if causal and seq_len > 1:
-l = self.causal_size
-logit = logit[..., :-l, :]
-input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
-if name not in info:
-info[name] = {
-"targets": [],
-"logits": [],
-}
-# modeling_llama.py has some comment about requiring .contiguous() but I feel it's a spook since that incurs a memory allocation
-info[name]["targets"].append( input.long() )
-info[name]["logits"].append( logit )
-for name, batch in info.items():
-loss_factor = self.loss_factor(name)
-if name not in ["text", "prom", "resp", "len"]:
-continue
-if loss_factor == 0.0:
-continue
-# "faster" if cross_entropy has speedups for processing an entire batch, but torch.cat allocates new tensors
-# to-do: set this to a var
-if False:
-targets = torch.cat( batch["targets"] ).long()
-inputs = torch.cat( batch["logits"] )
-loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
-stats["acc"][name] = self.accuracy_metric( inputs, targets )
-# probably consumes less memory due to not having to allocate memory
-# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
-else:
-loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
-if self.metrics is not None:
-metrics = self.metrics( batch["logits"], batch["targets"], self.classifiers.indices( classifier_levels ) )
-stats["acc"][name] = metrics["acc"]
+token = input if input.dim() == 1 else input[:, quant_level]
+# not a special input, inject as-is
else:
-stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
+token = input
+
+if not isinstance(token, torch.Tensor):
+continue
+
+if token.is_floating_point():
+ignored = True
+
+# grab range of our logits for later
+seq_len = token.shape[0]
+start, end = it, it+seq_len
+it += seq_len + 1 # +1 to incorporate the separator
+
+# deduce if a name for a task is an input or output
+if self.ignore_inputs_for_loss and name != task_outputs.get(task_type, name):
+ignored = True
+
+if ignored:
+# pruned
+if self.config.loss_factors:
+continue
+# fill with ignored out tensor
+token = torch.tensor( [ self.ignore_index ] * input.shape[0], device=device, dtype=torch.int16)
+
+# perform loss calculation on the individual piece
+if self.config.loss_factors:
+loss_factor = self.loss_factor(name)
+
+if loss_factor == 0.0:
+continue
+
+logit = logits[batch_index][start:end]
+
+if causal and seq_len > 1:
+l = self.causal_size
+logit = logit[..., :-l, :]
+token = token[..., l:] # shift sequence to the right by one (or causal chunk size)
+
+if f'{name}.nll' not in loss:
+loss[f'{name}.nll'] = []
+
+if f'{name}.acc' not in stats:
+stats[f'{name}.acc'] = []
+
+nll = F.cross_entropy( logit, token.long(), ignore_index=self.ignore_index ) * loss_factor
+if self.metrics is not None:
+metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
+else:
+metrics = self.accuracy_metric( logit, token )
+
+loss[f'{name}.nll'].append( nll )
+stats[f'{name}.acc'].append( metrics )
+# add to list
+else:
+target.append( token )
+
+# perform loss calculation on the entire sequence
+if not self.config.loss_factors:
+target = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
+logit = logits[batch_index]
+
+# shift if causal
+if causal:
+l = self.causal_size
+logit = logit[..., :-l, :] # shift the target so that token n...
+target = target[..., l:] # ...predicts token n + 1
+
+nll = F.cross_entropy( logit, target, ignore_index=self.ignore_index )
+
+if self.metrics is not None:
+metrics = self.metrics.calc_accuracy( [ logit ], [ target ], self.classifiers.indices([ classifier_level ]) )
+else:
+metrics = self.accuracy_metric( logit, target )
+
+if 'nll' not in loss:
+loss['nll'] = []
+
+if 'acc' not in stats:
+stats['acc'] = []
+
+loss["nll"].append( nll )
+stats["acc"].append( metrics )
+
+# average
+loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
+stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }

return LossStats(loss, stats)

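(Editorial aside: the rewritten loss calculation yields a flat dictionary of averaged scalars — keyed per input name (`text.nll`, `resp.nll`, ...) when `loss_factors` is set, or as a single `nll`/`acc` otherwise. A tiny restatement of that final averaging step, with illustrative values.)

```python
from collections import defaultdict

def average_buckets(buckets: dict[str, list[float]]) -> dict[str, float]:
    """Mirror of the final averaging step: each bucket of per-piece values
    collapses to its mean."""
    return { name: sum(values) / len(values) for name, values in buckets.items() }

# with cfg.model.loss_factors set, losses are bucketed per input name:
loss = defaultdict(list)
loss["text.nll"] += [0.8]           # illustrative values
loss["resp.nll"] += [2.9, 3.1]      # e.g. two batch items
print(average_buckets(loss))        # {'text.nll': 0.8, 'resp.nll': 3.0}

# without loss factors, everything lands in a single "nll" / "acc" bucket instead.
```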