another experimental flag

This commit is contained in:
mrq 2025-02-24 13:43:26 -06:00
parent f593ee98fc
commit f93fbf0d99
3 changed files with 7 additions and 2 deletions

View File

@ -274,6 +274,9 @@ class ModelExperimentalSettings:
# this should allow for "faster" training as each sample is trained entirely, but slower backwards (and possibly less stable training, maybe)
monolithic_audio_encoder: bool = False # combines the prom/resp embeddings into one unit
# this usually sounds bad, as the model can "extract" features from the prom separate from the ones in the resp
+predict_causally: bool = False # predicts the next token even for the non-causal/NAR tasks, in theory this should also bolster the model, as
+# * NAR-demask would semi-doubly train for AR
+# * the model wouldn't also need to learn when to predict the token in place
# these technically should be as hyperparameters
# performs token dropout to compensate for errors

View File

@ -171,7 +171,7 @@ class AR_NAR(Base):
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
# only apply stop token for RVQ level 0
-if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7):
+if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None) or (self.predict_causally):
# append stop tokens for AR
if task not in text_task:
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])

View File

@ -581,9 +581,11 @@ class Base(nn.Module):
ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
resp_parallel_training = self.config.experimental.resp_parallel_training if self.config is not None else True
+predict_causally = self.config.experimental.predict_causally if self.config is not None else False
monolithic_audio_encoder = self.config.experimental.monolithic_audio_encoder if self.config is not None else False
self.resp_parallel_training = resp_parallel_training
+self.predict_causally = self.predict_causally
(NOTE(review): this added line assigns `self.predict_causally` to itself, so the attribute is never set from the config; the local `predict_causally` read at the line above was almost certainly intended — compare the parallel `self.resp_parallel_training = resp_parallel_training` line.)
n_tasks = self.config.tasks if self.config is not None else 8
n_langs = self.config.langs if self.config is not None else 2
@ -1503,7 +1505,7 @@ class Base(nn.Module):
return None, None
# shift if causal
-if causal or self.version >= 7:
+if causal or self.predict_causally:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
sequence = sequence[..., l:] # ...predicts token n + 1