diff --git a/vall_e/config.py b/vall_e/config.py
index 7816bd4..180260f 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -274,6 +274,9 @@ class ModelExperimentalSettings:
 	# this should allow for "faster" training as each sample is trained entirely, but slower backwards (and possibly less stable training, maybe)
 	monolithic_audio_encoder: bool = False # combines the prom/resp embeddings into one unit
 	# this usually sounds bad, as the model can "extract" features from the prom separate from the ones in the resp
+	predict_causally: bool = False # predicts the next token even for non-causal/NAR tasks; in theory this should also bolster the model, as:
+	# * NAR-demask would semi-doubly train for AR
+	# * the model also wouldn't need to learn when to predict the token in place

 	# these technically should be as hyperparameters
 	# performs token dropout to compensate for errors
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index dead515..eb36b7a 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -171,7 +171,7 @@ class AR_NAR(Base):
 					resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1

 			# only apply stop token for RVQ level 0
-			if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7):
+			if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None) or (self.predict_causally):
 				# append stop tokens for AR
 				if task not in text_task:
 					resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index c680a3d..5088e61 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -581,9 +581,11 @@ class Base(nn.Module):
 		ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False

 		resp_parallel_training = self.config.experimental.resp_parallel_training if self.config is not None else True
+		predict_causally = self.config.experimental.predict_causally if self.config is not None else False
 		monolithic_audio_encoder = self.config.experimental.monolithic_audio_encoder if self.config is not None else False

 		self.resp_parallel_training = resp_parallel_training
+		self.predict_causally = predict_causally

 		n_tasks = self.config.tasks if self.config is not None else 8
 		n_langs = self.config.langs if self.config is not None else 2
@@ -1503,7 +1505,7 @@ class Base(nn.Module):
 			return None, None

 		# shift if causal
-		if causal or self.version >= 7:
+		if causal or self.predict_causally:
 			l = self.causal_size
 			logit = logit[..., :-l, :] # shift the target so that token n...
 			sequence = sequence[..., l:] # ...predicts token n + 1
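
For clarity, below is a minimal, self-contained sketch of the logit/target shift that the last hunk in `base.py` now gates on `predict_causally`: the final `causal_size` logits are dropped and the target sequence is advanced by the same amount, so the prediction at position n is scored against token n + 1. The tensor shapes, variable names, and the `cross_entropy` call here are illustrative assumptions, not the repo's actual loss path.

```python
# Illustrative sketch (not repo code): the causal shift enabled by `predict_causally`
# for non-causal/NAR tasks, mirroring logit[..., :-l, :] / sequence[..., l:] above.
import torch
import torch.nn.functional as F

causal_size = 1                            # stand-in for self.causal_size
logit = torch.randn(10, 1024)              # [seq_len, vocab] toy logits
sequence = torch.randint(0, 1024, (10,))   # [seq_len] toy target tokens

logit = logit[..., :-causal_size, :]       # drop the last `causal_size` predictions...
sequence = sequence[..., causal_size:]     # ...so position n is scored against token n + 1

loss = F.cross_entropy(logit, sequence)    # assumed loss; the repo's actual reduction may differ
```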