redid loss calculation to be cleaner, and position ID generation, and other things (I might need to train the NAR-len from scratch and not resume from an existing checkpoint.........)
This commit is contained in:
parent
ef05c951ff
commit
39096f8ff3
|
@ -67,6 +67,8 @@ I'm uncertain on how to remedy this, as my options are:
|
|||
* train an AR, and train a NAR, if the codec itself is at fault
|
||||
* use an SSM like Mamba, if transformers entirely cannot model the codec
|
||||
* train a separate model that simply converts from EnCodec to DAC
|
||||
* train *all* NAR levels as independent masking sequences.
|
||||
* if this works, then it means that there's little to no mappable relation between DAC's RVQ levels
|
||||
|
||||
## `transcribe.py`
|
||||
|
||||
|
|
|
@ -48,17 +48,11 @@ One problem exhibited from a NAR is producing arfifacts ("crust") in the final w
|
|||
The pure NAR (`nar-len`) model is a model-type that inferences audio tokens purely non-autoregressively. Despite being called a pure NAR, duration is then inferred by autoregressively decoding for its length (as the AR+NAR model shows that you can mix both types).
|
||||
|
||||
However, having a pure NAR is challenging, as you need to both explicitly provide the duration and provide a "good enough" starting sequence of tokens for the initial sequence.
|
||||
* The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively.
|
||||
* The latter however proves to be challenging, as generating tokens from nothing in one step is not possible.
|
||||
* diffusion solves this, but requires additional steps at best and a separate model at worse, just for one RVQ level.
|
||||
* the normal NAR (RVQ level 1+) does not face this problem, as it's already given a sufficient initial sequence of tokens to work with, and thus only requires one step.
|
||||
|
||||
The implemented solution follows a similar paradigm to diffusion, but with masking instead of noise.
|
||||
* incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this in the use of a NAR transformer for image generation
|
||||
|
||||
The reference model provided has *some* NAR demasking (mock diffusion) aware training to faciliate a pure NAR model, but:
|
||||
* Sampling absolutely requires rep pen, or the output degenerates.
|
||||
* Output isn't so great, as there's artifacting from either an underbaked model or a naive sampler.
|
||||
* The former problem is easily "solved" by training a `len` classification task.
|
||||
* The latter however proves to be challenging, as generating tokens from nothing in one step is not possible (but can be done in multiple steps).
|
||||
* diffusion solves this, but requires a different underliny model architecture
|
||||
* masking to emulate diffusion noising is best working solution, but has a lot of training challenges.
|
||||
* existing solutions like Muse (text to image) and MaskGCT (text to speech) do this
|
||||
|
||||
To-do: fill out this more when it works. Getting this to work is a huge pain.
|
||||
* Some masked transformers do not "inject" any timestep information (Text-To-Image Muse as far as I can tell)
|
||||
|
@ -67,22 +61,28 @@ To-do: fill out this more when it works. Getting this to work is a huge pain.
|
|||
* MaskGCT does it both pre and post
|
||||
* the test trainier actually degrades the output immensely when doing this
|
||||
* I'm sure I've seen a masked transformer not have CFG, but most of them seem to do.
|
||||
* This helps the base AR+NAR tasks and provides CFG sampling for such tasks anyways.
|
||||
* ***Extreme*** care is required.
|
||||
|
||||
## Embeddings
|
||||
## Embeddings (and Classifiers)
|
||||
|
||||
The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings. This is necessary as each piece of a sequence is fundamentally different, but a HF-compatible model can geta way with treating each sequence as separate ranges within a total token sequence.
|
||||
The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings. This is necessary as each piece of a sequence is fundamentally different, but a HF-compatible model can get away with treating each sequence as separate ranges within a total token sequence.
|
||||
|
||||
While embeddings *can* be tied to the output head, testing showed that the model ***really*** does not like to do this, although my implementation could very well be flawed.
|
||||
|
||||
With attention-based transformers, most embeddings can serve as a token itself and have the attention mechanism attend to it. Theoretically, there should be little to no functional differences between "tokenizing" an embedding, and summing a modifying embedding, but experimentation is needed for this assertion.
|
||||
|
||||
### Classifiers
|
||||
|
||||
Classifiers are the final output head / projection layer that processes the last hidden states of a model into a probability distribution for each token.
|
||||
|
||||
Out of paranoia, each head is split for each macro-task (RVQ level) and an auxiliary head for tasks `stt` and `len`, even though the core half of the model's training was with a single output head.
|
||||
|
||||
### Text Embeddings
|
||||
|
||||
The input text phonemes (or output for STT) are passed through an embedding head (`text`), similar to how a normal text LLM would. Nothing fancy is required, as it's very straightforward.
|
||||
|
||||
Technically, due to how the audio embeddings are implemented, it's possible to offer "language specific" embeddings, rather than one unified IPA-based embedding + a language embedding (`lang`).
|
||||
* Such an implementation *could* in fact inference from normal text rather than IPA phonemes.
|
||||
Technically, due to how the audio embeddings are implemented, it's possible to offer "language specific" text embeddings, rather than one unified IPA-based embedding + a language embedding (`lang`).
|
||||
* Such an implementation *could* in fact inference from normal text rather than IPA phonemes, as language-specific pure text embeddings can be trained.
|
||||
|
||||
These embeddings *could* instead be added on top of the input prompt embedding instead of serving as additional tasks (similar to injecting position embeddings), but additional experimentation is required to see if the model both can work under this and/or benefits from this.
|
||||
|
||||
|
@ -115,16 +115,15 @@ As EnCodec encodes audio across eight codebooks (and DAC's 44Khz audio under nin
|
|||
For the `prom` embedding, we can simply use each embedding for each layer. Each embedding level maps to its respective RVQ level.
|
||||
|
||||
Howver, the `resp` requires some extra care, as the model needs to both causally (AR) and parallel-ly (NAR) decode tokens.
|
||||
* The first embedding level pertains to RVQ level 0 for the AR.
|
||||
* The remaining embedding levels maps to RVQ level 0 + n for the NAR.
|
||||
* The first embedding level pertains to RVQ level 0 for the AR (`AR:0:0`).
|
||||
* This embedding predicts tokens within its own embedding.
|
||||
* The remaining embedding levels maps to RVQ level 0 + n for the NAR (`NAR:L-1:L`).
|
||||
* In other words, embedding level 1 => RVQ level 0, embedding level 2 => RVQ level 1, etc...
|
||||
* I believe this is because the model needs to "know" whether to predict ~~the next token in the sequence, or the token in the same position of the next RVQ level~~ which tokens of a given embedding.
|
||||
* In other words, the AR's RVQ level 0 embedding predicts itself, while the NAR's embeddings predict the next level's embeddings.
|
||||
* This is evident on how RVQ level 0 can be trained causally and in parallel with its own embeddings, rather than having limiting issues when reusing the embedding across the two.
|
||||
* Unfortunately, providing a token for the current/target RVQ level within the input sequence doesn't seem to help? I don't remember if I experimented with this or not, but testing of a "sane" `resp` embedding proved to be unfruitful.
|
||||
* I believe this is required because the model encodes which task to perform (rather than the attention heads), and which tokens to predict (rather than the classifiers)
|
||||
* In other words, each embedding needs to be separated based on what tokens they do predict.
|
||||
|
||||
The `prom` and `resp` are split since, in theory, it helps the model know better what audio to source from, and what audio is part of the output sequence. In theory.
|
||||
* I have yet to conduct tests with interchanging the `prom` and `resp`, and the model definitely exhibits being able to map from the `prom` directly, and being able to inference from the `prom` being prefixed in the `resp`.
|
||||
* The `text` embedding's robustness not only for reusing between each RVQ level, but as STT task as well is a mystery.
|
||||
|
||||
Finally, the model *may* then sum each embedding level back down to one sequence, as defined under `cfg.model.experimental.audio_embedding_sums`.
|
||||
* The resulant sum is not normalized by the length.
|
||||
|
|
|
@ -27,8 +27,20 @@ Some additional flags can be passed as well:
|
|||
|
||||
A training paradigm that works for me is:
|
||||
* setting the dataloader to sort by duration, then training one epoch, so the model starts with small utterances then trains to larger ones.
|
||||
* the daring can wait until coherent speech emerges, then move to the next step
|
||||
* some additional training using a shuffled dataloader, as the model will be fixated towards whatever duration range it was trained under.
|
||||
* additional training for sampling per speaker, to better help diversify how well it can perform for a range of speakers, rather than just speaking itself
|
||||
* I don't think this is crucial, but speaker-based sampling seems to be a huge placebo if anything.
|
||||
|
||||
I don't remember the exact numbers off the top of my head, but a good loss/accuracy/gradient norm to look out for when coherent speech emergies are:
|
||||
* loss <3.0
|
||||
* acc >0.7
|
||||
* grad_norm <0.2
|
||||
|
||||
Training under `float16` should be fairly simple, but care is required to keep the loss scaling factor above 8K, and probably even 16K.
|
||||
* At the very least for pre-trained models, low enough loss scales will irreparably fry the model, and no amount of training afterwards seems to "fix" it.
|
||||
* The current DeepSpeed configuration should keep the loss scale capped to 32K, but this so far is only validated for pre-trained models.
|
||||
* Training under `bfloat16` does not have to worry about this as there's no need for loss scaling, but I feel the model performs better when trained under `float16`+AMP rather than `bfloat16` (with or without AMP).
|
||||
|
||||
## Try Me
|
||||
|
||||
|
|
|
@ -260,7 +260,9 @@ class ModelExperimentalSettings:
|
|||
|
||||
masking_train_p: float = 0.0 # odds of training with masking
|
||||
masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
|
||||
masking_separate_embeddings: bool = False # to-do: explain
|
||||
|
||||
masking_ratio_fixed: bool = False
|
||||
ignore_inputs_for_loss: bool = False
|
||||
|
||||
# classifier-free guidance shit
|
||||
cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
|
||||
|
@ -301,7 +303,7 @@ class Model:
|
|||
return [ self ] if not name or self.name == name else []
|
||||
|
||||
def loss_factor(self, k):
|
||||
return self.loss_factors[k] if k in self.loss_factors else 1.0
|
||||
return self.loss_factors.get(k, 0.0)
|
||||
|
||||
@property
|
||||
def max_levels(self):
|
||||
|
@ -508,6 +510,9 @@ class DeepSpeed:
|
|||
|
||||
amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)
|
||||
|
||||
loss_scale_window: int = 100
|
||||
min_loss_scale: float = 8192.0
|
||||
|
||||
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||
|
||||
@cached_property
|
||||
|
@ -558,8 +563,8 @@ class DeepSpeed:
|
|||
"fp16": {
|
||||
"enabled": cfg.trainer.weight_dtype.lower() == "float16",
|
||||
"auto_cast": True, # ???
|
||||
"loss_scale_window": 100, # raise every 100 consecutive good steps
|
||||
"min_loss_scale": 32768.0, # loss scale hitting 8K fries the model, 16K is fine but 32K is comfy
|
||||
"loss_scale_window": self.loss_scale_window, # raise every 100 consecutive good steps
|
||||
"min_loss_scale": self.min_loss_scale, # loss scale hitting 8K fries the model, 16K is fine but 32K is comfy
|
||||
"loss_scale": 0.0 if cfg.trainer.scale_loss else 1.0,
|
||||
},
|
||||
"bf16": {
|
||||
|
|
|
@ -1761,12 +1761,12 @@ if __name__ == "__main__":
|
|||
elif args.action == "metadata":
|
||||
create_dataset_metadata()
|
||||
elif args.action == "sample":
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
train_dl, val_dl = create_train_val_dataloader()
|
||||
|
||||
samples = {
|
||||
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
|
||||
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
|
||||
#"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
|
||||
#"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
|
||||
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
|
||||
}
|
||||
|
||||
Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
|
||||
|
@ -1784,8 +1784,8 @@ if __name__ == "__main__":
|
|||
decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
|
||||
except Exception as e:
|
||||
_logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}")
|
||||
v[i]['proms'][j] = v[i]['proms'][j].shape
|
||||
v[i]['resps'][j] = v[i]['resps'][j].shape
|
||||
#v[i]['proms'][j] = v[i]['proms'][j].shape
|
||||
#v[i]['resps'][j] = v[i]['resps'][j].shape
|
||||
|
||||
for k, v in samples.items():
|
||||
for i in range(len(v)):
|
||||
|
|
|
@ -269,7 +269,6 @@ class TTS():
|
|||
)
|
||||
elif model_len is not None:
|
||||
len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_steps": 5} ) # don't need more than that
|
||||
|
||||
kwargs = {}
|
||||
# nasty hardcode to load a reference file and have that as the input target
|
||||
if load_from_artifact and load_from_artifact.exists():
|
||||
|
|
|
@ -148,7 +148,7 @@ class AR_NAR(Base):
|
|||
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
|
||||
|
||||
# only apply stop token for RVQ level 0
|
||||
if quant_level <= 0 and timesteps[i] is not None:
|
||||
if quant_level <= 0 and timesteps[i] is None:
|
||||
# append stop tokens for AR
|
||||
if task in text_task:
|
||||
#text_list[i] = torch.cat([ resps, text_stop_sequence ])
|
||||
|
@ -254,9 +254,8 @@ class AR_NAR(Base):
|
|||
null_prom = [ None for _ in range(batch_size) ]
|
||||
prev_list = resps_list
|
||||
|
||||
for timestep, steps_until_x0 in tqdm(zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))), desc="NAR Masked", disable=disable_tqdm, total=max_steps):
|
||||
#for noise_p, annealed_temperature, temperature, cfg_strength in zip( manual_ratios, manual_temp, manual_samp_temp, manual_cfg ):
|
||||
annealing = (steps_until_x0 / max_steps)
|
||||
for timestep in tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm):
|
||||
annealing = 1.0 - timestep
|
||||
# get noise level, per cosine scheduling
|
||||
noise_p = math.cos( timestep * math.pi * 0.5 )
|
||||
# pick the worst scoring tokens to mask off
|
||||
|
@ -268,6 +267,13 @@ class AR_NAR(Base):
|
|||
# timestep inputs
|
||||
time_list = [ timestep for _ in range(batch_size) ]
|
||||
|
||||
"""
|
||||
sampling_temperature = temperature * annealing
|
||||
sampling_cfg = cfg_strength * timestep
|
||||
"""
|
||||
sampling_temperature = temperature
|
||||
sampling_cfg = cfg_strength
|
||||
|
||||
# setup inputs
|
||||
inputs = super().inputs(
|
||||
text_list=text_list,
|
||||
|
@ -302,7 +308,7 @@ class AR_NAR(Base):
|
|||
#layer_skip_variables=sampling_layer_skip_variables,
|
||||
)
|
||||
for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
|
||||
logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * (cfg_strength * timestep)
|
||||
logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * sampling_cfg
|
||||
|
||||
# sample with sampler settings
|
||||
filtered_sampled = super().sample(
|
||||
|
@ -310,7 +316,7 @@ class AR_NAR(Base):
|
|||
prev_list=prev_list,
|
||||
quant_levels=quant_levels,
|
||||
|
||||
temperature=temperature * annealing,
|
||||
temperature=sampling_temperature,
|
||||
**sampling_kwargs,
|
||||
)
|
||||
|
||||
|
@ -884,7 +890,7 @@ def example_usage():
|
|||
import numpy as np
|
||||
import re
|
||||
|
||||
cfg.model.experimental.masking_train_p = 0.5
|
||||
# cfg.model.experimental.masking_train_p = 0.5
|
||||
cfg.hyperparameters.batch_size = 1
|
||||
cfg.hyperparameters.gradient_accumulation_steps = 1
|
||||
|
||||
|
@ -903,7 +909,7 @@ def example_usage():
|
|||
|
||||
text_list = [ text ] * batch_size
|
||||
proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size
|
||||
resps_list = [ audio ] * batch_size
|
||||
resps_list = [ audio[:cfg.dataset.frames_per_second * 4, :] ] * batch_size
|
||||
|
||||
kwargs = {
|
||||
'n_text_tokens': 256,
|
||||
|
@ -922,10 +928,10 @@ def example_usage():
|
|||
}
|
||||
|
||||
bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
|
||||
available_tasks = ["tts-ar", "tts-nar"]
|
||||
available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else [])
|
||||
|
||||
model = AR_NAR(**kwargs).to(cfg.device)
|
||||
steps = 500 // batch_size
|
||||
steps = 1000 // batch_size
|
||||
|
||||
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
|
||||
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
|
||||
|
@ -1035,7 +1041,7 @@ def example_usage():
|
|||
if task == "stt":
|
||||
prom = [ task ]
|
||||
else:
|
||||
task = "tts" if random.random() > 0.1 else "len"
|
||||
task = "tts" if random.random() > 0.1 or "len" not in cfg.model.capabilities else "len"
|
||||
|
||||
texts.append( text )
|
||||
proms.append( prom )
|
||||
|
@ -1053,7 +1059,7 @@ def example_usage():
|
|||
if task == "tts-nar":
|
||||
len_list = engine(text_list, proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
|
||||
len_list = [ resp_list[0].shape[0] for l in len_list ]
|
||||
resps_list = engine( text_list, proms_list, len_list=len_list, temperature=0.0 )
|
||||
resps_list = engine( text_list, proms_list, len_list=len_list )
|
||||
else:
|
||||
resps_list = engine( text_list, proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
|
||||
resps_list = engine( text_list, proms_list, resps_list=resps_list, temperature=0.0 )
|
||||
|
|
|
@ -48,6 +48,12 @@ from ..utils.pattern import DelayedPatternProvider, VALLEPattern
|
|||
|
||||
summed_embeddings_task = [ "stt" ]
|
||||
special_tasks = [ "len", "stt" ]
|
||||
non_tokened_names = ["task", "dropout_mask", "classifier_level"]
|
||||
task_outputs = {
|
||||
"tts": "resp",
|
||||
"stt": "text",
|
||||
"len": "len",
|
||||
}
|
||||
|
||||
def _dropout_mask( input, p=None ):
|
||||
# cosine scheduling
|
||||
|
@ -200,13 +206,14 @@ class AudioEmbedding(nn.Module):
|
|||
def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, name: str | None = None, sums = None ) -> Tensor:
|
||||
if sums is None:
|
||||
sums = self.sums
|
||||
|
||||
if quant_level is None:
|
||||
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
||||
|
||||
# handle mapping from name
|
||||
if name in self.names:
|
||||
offset = self.names.index( name )
|
||||
|
||||
if quant_level is None:
|
||||
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
||||
offset -= quant_level # offset by quant level since it'll iterate up that many levels
|
||||
|
||||
if self.sums and quant_level > 0:
|
||||
x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
|
||||
|
@ -322,7 +329,7 @@ class Base(nn.Module):
|
|||
def loss_factor(self, k):
|
||||
if self.config is None:
|
||||
return 1.0
|
||||
return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
|
||||
return self.config.loss_factor(k)
|
||||
|
||||
def _prune(self, l: Tensor, stop = None):
|
||||
if stop is None:
|
||||
|
@ -429,13 +436,14 @@ class Base(nn.Module):
|
|||
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
|
||||
interleave = self.config.experimental.interleave if self.config is not None else False
|
||||
|
||||
masking_ratio_fixed = self.config.experimental.masking_ratio_fixed if self.config is not None else False
|
||||
ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
|
||||
|
||||
layerskip = self.config.experimental.layerskip if self.config is not None else False
|
||||
layerskip_r = self.config.experimental.layerskip_r if self.config is not None else 2
|
||||
layerskip_p_max = self.config.experimental.layerskip_p_max if self.config is not None else 0.1
|
||||
layerskip_e_scale = self.config.experimental.layerskip_e_scale if self.config is not None else 0.1
|
||||
|
||||
masking_separate_embeddings = self.config.experimental.masking_separate_embeddings if self.config is not None else False
|
||||
|
||||
n_tasks = self.config.tasks if self.config is not None else 8
|
||||
n_langs = self.config.langs if self.config is not None else 2
|
||||
n_tones = self.config.tones if self.config is not None else 1
|
||||
|
@ -446,15 +454,15 @@ class Base(nn.Module):
|
|||
l_tokens = [n_resp_tokens] * self.n_resp_levels
|
||||
resp_l_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
|
||||
# NAR-len model
|
||||
elif "len" in self.capabilities and masking_separate_embeddings:
|
||||
elif "len" in self.capabilities:
|
||||
# +1 to include the stop or mask token
|
||||
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
|
||||
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
|
||||
resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
|
||||
|
||||
if masking_separate_embeddings:
|
||||
l_tokens += [n_resp_tokens]
|
||||
resp_l_names += ['NAR:0:0']
|
||||
if "ar" in self.capabilities:
|
||||
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
|
||||
resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
|
||||
else:
|
||||
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
|
||||
resp_l_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
|
||||
# AR+NAR model
|
||||
else:
|
||||
# +1 to include the stop or mask token
|
||||
|
@ -462,13 +470,19 @@ class Base(nn.Module):
|
|||
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
|
||||
resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
|
||||
|
||||
classifier_l_names = resp_l_names + ["stt"]
|
||||
classifier_l_tokens = l_tokens + [ n_text_tokens ]
|
||||
classifier_l_names = resp_l_names + [ "len" ]
|
||||
|
||||
if "len" in self.capabilities and False:
|
||||
classifier_l_tokens += [ n_text_tokens ]
|
||||
classifier_l_names += ["len"]
|
||||
|
||||
self.unified_position_ids = unified_position_ids
|
||||
self.interleave = interleave
|
||||
self.layerskip = layerskip
|
||||
self.inject_timestep_embedding = False # results in bad output
|
||||
self.masking_separate_embeddings = masking_separate_embeddings
|
||||
self.masking_ratio_fixed = masking_ratio_fixed
|
||||
self.ignore_inputs_for_loss = ignore_inputs_for_loss
|
||||
|
||||
self.text_emb = Embedding(n_text_tokens, d_model)
|
||||
self.langs_emb = None
|
||||
|
@ -805,10 +819,10 @@ class Base(nn.Module):
|
|||
self.metrics = None
|
||||
else:
|
||||
self.classifier = None
|
||||
self.classifiers = Classifiers( l_tokens + [ n_text_tokens ], d_model, l_names=classifier_l_names )
|
||||
self.classifiers = Classifiers( classifier_l_tokens, d_model, l_names=classifier_l_names )
|
||||
self.accuracy_metric = None
|
||||
self.precision_metric = None
|
||||
self.metrics = Metrics( l_tokens + [ n_text_tokens ] )
|
||||
self.metrics = Metrics( classifier_l_tokens )
|
||||
|
||||
"""
|
||||
if tie_classifier_to_embedding:
|
||||
|
@ -996,24 +1010,22 @@ class Base(nn.Module):
|
|||
inputs[i].append( ( "tone", tone_list[i] ) )
|
||||
# insert timestep token
|
||||
if timestep is not None:
|
||||
# store timestep information
|
||||
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
|
||||
# force set to use this classifier level
|
||||
classifier_level = "NAR:0:0" if self.masking_separate_embeddings else "AR:0:0"
|
||||
classifier_level = "NAR:0:0"
|
||||
# a paper said to use a fixed masking ratio for training
|
||||
p = 0.8
|
||||
# store timestep information
|
||||
if not self.masking_ratio_fixed:
|
||||
# cosine scheduled timestep => masking ratio
|
||||
p = math.cos(timestep * math.pi * 0.5)
|
||||
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
|
||||
# store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
|
||||
if self.training:
|
||||
dropout_mask = _dropout_mask( resps_list[i], p )
|
||||
inputs[i].append( ("dropout_mask", dropout_mask ) )
|
||||
# insert the current output response
|
||||
if resps_list is not None and resps_list[i] is not None:
|
||||
inputs[i].append( ( "resp", resps_list[i] ) )
|
||||
|
||||
# store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
|
||||
if timestep is not None and self.training:
|
||||
"""
|
||||
# a paper said to use a fixed masking ratio for training
|
||||
p = 0.8
|
||||
"""
|
||||
# cosine scheduled timestep => masking ratio
|
||||
p = math.cos(timestep * math.pi * 0.5)
|
||||
dropout_mask = _dropout_mask( resps_list[i], p )
|
||||
inputs[i].append( ("dropout_mask", dropout_mask ) )
|
||||
|
||||
inputs[i].append( ("classifier_level", classifier_level) )
|
||||
# Audio length prediction task
|
||||
|
@ -1047,7 +1059,7 @@ class Base(nn.Module):
|
|||
# yes this could be encoded better
|
||||
inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) )
|
||||
|
||||
inputs[i].append( ("classifier_level", "stt") )
|
||||
inputs[i].append( ("classifier_level", "len") )
|
||||
# Speech-to-Text prediction task
|
||||
# Sequence: <resp><sep><rvq lvl><sep><text>
|
||||
elif task_type == "stt":
|
||||
|
@ -1125,6 +1137,7 @@ class Base(nn.Module):
|
|||
for name, input in batch_input:
|
||||
# technically can provide a map for input_name => embedding, but some embedding requires additional processing
|
||||
embedding = None
|
||||
|
||||
# is already an embedding
|
||||
if name == "task":
|
||||
# noop
|
||||
|
@ -1165,7 +1178,6 @@ class Base(nn.Module):
|
|||
embedding = self.resps_emb(
|
||||
# if masked use masked token, else original token
|
||||
torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
|
||||
#offset = -1 if self.masking_separate_embeddings else 0, # pick last
|
||||
#quant_level = 0,
|
||||
name = classifier_level,
|
||||
)
|
||||
|
@ -1174,7 +1186,6 @@ class Base(nn.Module):
|
|||
embedding = self.resps_emb(
|
||||
# if masked use masked token, else original token
|
||||
input if input.dim() == 1 else input[:, 0],
|
||||
#offset = -1 if self.masking_separate_embeddings else 0, # pick last
|
||||
#quant_level = 0,
|
||||
name = classifier_level,
|
||||
)
|
||||
|
@ -1212,7 +1223,8 @@ class Base(nn.Module):
|
|||
|
||||
embedding = self.resps_emb(
|
||||
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
|
||||
offset = 0 if classifier_level.startswith("AR:") else 1,
|
||||
#offset = 0 if classifier_level.startswith("AR:") else 1,
|
||||
name = classifier_level,
|
||||
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
|
||||
)
|
||||
|
||||
|
@ -1277,29 +1289,33 @@ class Base(nn.Module):
|
|||
# there's a better way
|
||||
if not self.unified_position_ids:
|
||||
x_list = []
|
||||
non_tokens = ["task", "dropout_mask", "classifier_level"]
|
||||
last_input = ["resp", "len"]
|
||||
|
||||
def get_input_token_length( name, input ):
|
||||
def get_input_token_length( name, input, task ):
|
||||
# task token
|
||||
if isinstance(input, str):
|
||||
return 1
|
||||
|
||||
# list of tokens
|
||||
if not isinstance(input, torch.Tensor):
|
||||
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
|
||||
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] )
|
||||
|
||||
# interleaved model
|
||||
if self.interleave and name == "resp":
|
||||
return input.shape[0] * input.shape[1]
|
||||
|
||||
# ending input will not have a separator later
|
||||
return input.shape[0] + (0 if name in last_input else 1)
|
||||
return input.shape[0]
|
||||
|
||||
for batch_index, batch_input in enumerate(inputs):
|
||||
# pre-iterate
|
||||
task = "tts"
|
||||
for name, input in batch_input:
|
||||
if name == "task":
|
||||
task = input
|
||||
|
||||
batch = torch.cat( [
|
||||
torch.tensor([*range(get_input_token_length(name, input))], device=device, dtype=torch.int32)
|
||||
for name, input in batch_input if name not in non_tokens
|
||||
torch.tensor([*range(get_input_token_length(name, input, task) + (1 if name != task_outputs.get(task, name) else 0))], device=device, dtype=torch.int32)
|
||||
for name, input in batch_input if name not in non_tokened_names
|
||||
] )
|
||||
|
||||
delta = ids[batch_index].shape[0] - batch.shape[0]
|
||||
|
@ -1319,8 +1335,8 @@ class Base(nn.Module):
|
|||
|
||||
quant_levels: list[int] | None = None,
|
||||
):
|
||||
loss = dict(ce = dict())
|
||||
stats = dict(acc = dict())
|
||||
loss = {}
|
||||
stats = {}
|
||||
|
||||
device = logits[0].device
|
||||
batch_size = len(logits)
|
||||
|
@ -1328,6 +1344,12 @@ class Base(nn.Module):
|
|||
|
||||
# handles tasks where the prompt has task tokens injected in the middle
|
||||
def prompt_input_to_token( input, quant_level ):
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
return torch.tensor( [ self.ignore_index ], device=device, dtype=torch.int16)
|
||||
|
||||
return torch.tensor( [ self.ignore_index ] * input.shape[0], device=device, dtype=torch.int16)
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
return torch.tensor( [ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16)
|
||||
|
||||
|
@ -1336,192 +1358,143 @@ class Base(nn.Module):
|
|||
return torch.full_like(input[..., 0], self.ignore_index)
|
||||
|
||||
return input if input.dim() == 1 else input[:, quant_level]
|
||||
|
||||
# old, "naive" way, no loss factoring
|
||||
if not self.config.loss_factors:
|
||||
target_list = []
|
||||
task_list = []
|
||||
is_causal = []
|
||||
|
||||
for batch_index, batch in enumerate(inputs):
|
||||
quant_level = quant_levels[batch_index]
|
||||
target = []
|
||||
task_type = "tts"
|
||||
|
||||
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
|
||||
dropout_mask = None
|
||||
for name, input in batch:
|
||||
if name == "dropout_mask":
|
||||
dropout_mask = input
|
||||
|
||||
for name, input in batch:
|
||||
if name == "task":
|
||||
task_type = input
|
||||
task_list.append( input )
|
||||
if task_type in special_tasks:
|
||||
causal = True
|
||||
elif name == "prom":
|
||||
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
||||
target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) )
|
||||
elif name == "resp":
|
||||
# mask found, apply it
|
||||
if dropout_mask is not None:
|
||||
# if mask use original token, else ignore
|
||||
causal = False
|
||||
target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
|
||||
elif self.interleave:
|
||||
target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
|
||||
elif task_type in summed_embeddings_task:
|
||||
target.append( torch.full_like(input[..., 0], self.ignore_index) )
|
||||
else:
|
||||
target.append( input if input.dim() == 1 else input[:, quant_level] )
|
||||
elif name == "timestep":
|
||||
target.append( torch.tensor([self.ignore_index], device=input.device) )
|
||||
elif name in ["text", "quant_level", "lang", "tone", "len"]:
|
||||
target.append( input )
|
||||
|
||||
is_causal.append( causal )
|
||||
target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
|
||||
|
||||
batch_size = len(target_list)
|
||||
# modify only causal sequences so it can properly behave like a transformer
|
||||
for i in range(batch_size):
|
||||
quant_level = quant_levels[i]
|
||||
task_name = task_list[i]
|
||||
causal = is_causal[i]
|
||||
|
||||
if causal:
|
||||
l = self.causal_size
|
||||
logits[i] = logits[i][..., :-l, :] # shift the target so that token n...
|
||||
target_list[i] = target_list[i][..., l:] # predicts token n + 1
|
||||
|
||||
for batch_index, target in enumerate( target_list ):
|
||||
logit = logits[batch_index]
|
||||
|
||||
max_classes = logit.shape[-1]
|
||||
max_token = torch.max( target ).item()
|
||||
|
||||
if max_token > max_classes:
|
||||
task = self.get_input(inputs, "task", at=batch_index)
|
||||
print( batch_index, task, target, max_token, max_classes, inputs[batch_index] )
|
||||
|
||||
# see comments for the split-loss calc cross_entropy call
|
||||
if False:
|
||||
target = torch.cat( target_list )
|
||||
inputs = torch.cat( logits )
|
||||
loss = dict(
|
||||
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
|
||||
)
|
||||
stats = self.metrics( inputs, targets, classifier_levels ) if self.metrics is not None else dict(
|
||||
acc = self.accuracy_metric( inputs, target ),
|
||||
# precision = self.precision_metric( inputs, target ),
|
||||
)
|
||||
else:
|
||||
# nll being natural log likelihood :)))) (I don't know why this completely escaped me originally with thinking it meant something else)
|
||||
loss = dict(
|
||||
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
|
||||
)
|
||||
stats = self.metrics( logits, target_list, self.classifiers.indices( classifier_levels ) ) if self.metrics is not None else dict(
|
||||
acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
|
||||
)
|
||||
|
||||
return LossStats(loss, stats)
|
||||
|
||||
"""
|
||||
# considerations:
|
||||
# * split losses does not maintain the entire sequence
|
||||
# * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
|
||||
# + the other way at least should keep it intact this way
|
||||
# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
|
||||
# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
|
||||
"""
|
||||
|
||||
# to-do: use NAR-len training and better causal-awareness
|
||||
info = {}
|
||||
batch_size = len( inputs )
|
||||
|
||||
for i, batch in enumerate( inputs ):
|
||||
it = 0
|
||||
quant_level = quant_levels[i]
|
||||
task_name = None
|
||||
|
||||
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
|
||||
|
||||
for batch_index, batch in enumerate(inputs):
|
||||
quant_level = quant_levels[batch_index]
|
||||
target = []
|
||||
causal = True
|
||||
task_type = "tts"
|
||||
dropout_mask = None
|
||||
for name, input in batch:
|
||||
if name == "dropout_mask":
|
||||
dropout_mask = input
|
||||
classifier_level = None
|
||||
|
||||
for name, input in batch:
|
||||
# meta-input, no corresponding token at the moment
|
||||
if name == "task":
|
||||
task_name = input
|
||||
if task_type in special_tasks:
|
||||
causal = True
|
||||
task_type = input
|
||||
elif name == "dropout_mask":
|
||||
dropout_mask = input
|
||||
elif name == "classifier_level":
|
||||
classifier_level = input
|
||||
|
||||
# autoregressive, causal
|
||||
if classifier_level.startswith("AR:"):
|
||||
causal = True
|
||||
# nonautoregressive, parallel
|
||||
elif classifier_level.startswith("NAR:"):
|
||||
causal = False
|
||||
|
||||
it = 0
|
||||
for name, input in batch:
|
||||
token = None
|
||||
ignored = False
|
||||
|
||||
# non-tokened tasks
|
||||
if name in non_tokened_names:
|
||||
continue
|
||||
# do not use resp as-is
|
||||
if name == "resp":
|
||||
# prom can either be a tensor itself or a list of tensors and strings
|
||||
if name == "prom":
|
||||
# expand to list if not a list
|
||||
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
||||
# iterate over the list to inject their tokens
|
||||
token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
|
||||
elif name == "resp":
|
||||
# mask found, apply it
|
||||
if dropout_mask is not None:
|
||||
# if mask use original token, else ignore
|
||||
causal = False
|
||||
target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
|
||||
token = torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index )
|
||||
# flatten
|
||||
elif self.interleave:
|
||||
input = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
|
||||
elif task_type in summed_embeddings_task:
|
||||
input = torch.full_like(input[..., 0], self.ignore_index)
|
||||
token = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
|
||||
# use resps as-is
|
||||
else:
|
||||
input = input if input.dim() == 1 else input[:, quant_level]
|
||||
# select prom level
|
||||
elif name == "prom":
|
||||
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
||||
input = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] )
|
||||
|
||||
seq_len = input.shape[0]
|
||||
|
||||
logit = logits[i][it:it+seq_len]
|
||||
it += seq_len + 1 # +1 to incorporate the separator
|
||||
|
||||
# for the AR, shift sequence so that it predicts the next token
|
||||
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
|
||||
if causal and seq_len > 1:
|
||||
l = self.causal_size
|
||||
logit = logit[..., :-l, :]
|
||||
input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
|
||||
|
||||
if name not in info:
|
||||
info[name] = {
|
||||
"targets": [],
|
||||
"logits": [],
|
||||
}
|
||||
|
||||
# modeling_llama.py has some comment about requiring .contiguous() but I feel it's a spook since that incurs a memory allocation
|
||||
info[name]["targets"].append( input.long() )
|
||||
info[name]["logits"].append( logit )
|
||||
|
||||
for name, batch in info.items():
|
||||
loss_factor = self.loss_factor(name)
|
||||
|
||||
if name not in ["text", "prom", "resp", "len"]:
|
||||
continue
|
||||
|
||||
if loss_factor == 0.0:
|
||||
continue
|
||||
|
||||
# "faster" if cross_entropy has speedups for processing an entire batch, but torch.cat allocates new tensors
|
||||
# to-do: set this to a var
|
||||
if False:
|
||||
targets = torch.cat( batch["targets"] ).long()
|
||||
inputs = torch.cat( batch["logits"] )
|
||||
loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
|
||||
stats["acc"][name] = self.accuracy_metric( inputs, targets )
|
||||
# probably consumes less memory due to not having to allocate memory
|
||||
# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
|
||||
else:
|
||||
loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
|
||||
if self.metrics is not None:
|
||||
metrics = self.metrics( batch["logits"], batch["targets"], self.classifiers.indices( classifier_levels ) )
|
||||
stats["acc"][name] = metrics["acc"]
|
||||
token = input if input.dim() == 1 else input[:, quant_level]
|
||||
# not a special input, inject as-is
|
||||
else:
|
||||
stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
|
||||
token = input
|
||||
|
||||
if not isinstance(token, torch.Tensor):
|
||||
continue
|
||||
|
||||
if token.is_floating_point():
|
||||
ignored = True
|
||||
|
||||
# grab range of our logits for later
|
||||
seq_len = token.shape[0]
|
||||
start, end = it, it+seq_len
|
||||
it += seq_len + 1 # +1 to incorporate the separator
|
||||
|
||||
# deduce if a name for a task is an input or output
|
||||
if self.ignore_inputs_for_loss and name != task_outputs.get(task_type, name):
|
||||
ignored = True
|
||||
|
||||
if ignored:
|
||||
# pruned
|
||||
if self.config.loss_factors:
|
||||
continue
|
||||
# fill with ignored out tensor
|
||||
token = torch.tensor( [ self.ignore_index ] * input.shape[0], device=device, dtype=torch.int16)
|
||||
|
||||
# perform loss calculation on the individual piece
|
||||
if self.config.loss_factors:
|
||||
loss_factor = self.loss_factor(name)
|
||||
|
||||
if loss_factor == 0.0:
|
||||
continue
|
||||
|
||||
logit = logits[batch_index][start:end]
|
||||
|
||||
if causal and seq_len > 1:
|
||||
l = self.causal_size
|
||||
logit = logit[..., :-l, :]
|
||||
token = token[..., l:] # shift sequence to the right by one (or causal chunk size)
|
||||
|
||||
if f'{name}.nll' not in loss:
|
||||
loss[f'{name}.nll'] = []
|
||||
|
||||
if f'{name}.acc' not in stats:
|
||||
stats[f'{name}.acc'] = []
|
||||
|
||||
nll = F.cross_entropy( logit, token.long(), ignore_index=self.ignore_index ) * loss_factor
|
||||
if self.metrics is not None:
|
||||
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
|
||||
else:
|
||||
metrics = self.accuracy_metric( logit, token )
|
||||
|
||||
loss[f'{name}.nll'].append( nll )
|
||||
stats[f'{name}.acc'].append( metrics )
|
||||
# add to list
|
||||
else:
|
||||
target.append( token )
|
||||
|
||||
# perofrm loss calculation on the entire sequence
|
||||
if not self.config.loss_factors:
|
||||
target = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
|
||||
logit = logits[batch_index]
|
||||
|
||||
# shift if causal
|
||||
if causal:
|
||||
l = self.causal_size
|
||||
logit = logit[..., :-l, :] # shift the target so that token n...
|
||||
target = target[..., l:] # ...predicts token n + 1
|
||||
|
||||
nll = F.cross_entropy( logit, target, ignore_index=self.ignore_index )
|
||||
|
||||
if self.metrics is not None:
|
||||
metrics = self.metrics.calc_accuracy( [ logit ], [ target ], self.classifiers.indices([ classifier_level ]) )
|
||||
else:
|
||||
metrics = self.accuracy_metric( logit, target )
|
||||
|
||||
if 'nll' not in loss:
|
||||
loss['nll'] = []
|
||||
|
||||
if 'acc' not in stats:
|
||||
stats['acc'] = []
|
||||
|
||||
loss["nll"].append( nll )
|
||||
stats["acc"].append( metrics )
|
||||
|
||||
# average
|
||||
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
|
||||
stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
|
||||
|
||||
return LossStats(loss, stats)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user