another experimental flag

parent f593ee98fc
commit 8f5a3997bd
@@ -274,6 +274,9 @@ class ModelExperimentalSettings:
 	# this should allow for "faster" training, as each sample is trained in its entirety, but with a slower backwards pass (and possibly less stable training)
 	monolithic_audio_encoder: bool = False # combines the prom/resp embeddings into one unit
 	# this usually sounds bad, as the model can "extract" features from the prom separately from the ones in the resp
+	predict_causally: bool = False # predicts the next token even for the non-causal/NAR tasks; in theory this should also bolster the model, as:
+	# * NAR-demask would semi-doubly train for AR
+	# * the model wouldn't also need to learn when to predict the token in place
 
 	# these technically should be hyperparameters
 	# performs token dropout to compensate for errors
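For intuition, here is a minimal sketch of the two objectives the flag switches between (`logits`, `targets`, and the shapes are illustrative, not this repo's tensors): the usual NAR target has position n predict token n in place, while `predict_causally` scores position n against token n + 1, the same objective the AR task trains.

import torch
import torch.nn.functional as F

logits = torch.randn( 8, 1024 )            # [T, V], one parallel (NAR) forward
targets = torch.randint( 0, 1024, (8,) )   # [T]

# in-place objective: position n predicts token n (the usual NAR target)
nar_loss = F.cross_entropy( logits, targets )

# causal objective: position n predicts token n + 1 (the AR target);
# with predict_causally, the non-causal tasks train this instead
ar_loss = F.cross_entropy( logits[:-1], targets[1:] )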
@@ -171,7 +171,7 @@ class AR_NAR(Base):
 					resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
 
 			# only apply stop token for RVQ level 0
-			if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7):
+			if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None) or (self.predict_causally):
 				# append stop tokens for AR
 				if task not in text_task:
 					resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
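The widened condition now also gates version >= 7 models on `timesteps[i] is None` (so NAR-demask steps no longer get a stop token) while letting `predict_causally` force one. Mechanically, the append looks roughly like this (the shapes and the stop-token id are assumptions, not the repo's values):

import torch

resps = torch.randint( 1, 1022, (75, 8) )  # [T, n_resp_levels], one utterance
stop_token = 1023                          # assumed id; the real one comes from the codec setup
audio_stop_sequence = torch.full( (1, resps.shape[-1]), stop_token )
resps = torch.cat([ resps, audio_stop_sequence ])  # now [T+1, n_resp_levels]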
@@ -581,9 +581,11 @@ class Base(nn.Module):
 		ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
 
 		resp_parallel_training = self.config.experimental.resp_parallel_training if self.config is not None else True
+		predict_causally = self.config.experimental.predict_causally if self.config is not None else False
 		monolithic_audio_encoder = self.config.experimental.monolithic_audio_encoder if self.config is not None else False
 
 		self.resp_parallel_training = resp_parallel_training
+		self.predict_causally = predict_causally
 
 		n_tasks = self.config.tasks if self.config is not None else 8
 		n_langs = self.config.langs if self.config is not None else 2
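The new flag is plumbed through with the same read-with-fallback pattern as its neighbors; a self-contained sketch of that pattern (the dataclasses are stand-ins for the real config, not its actual layout):

from dataclasses import dataclass, field

@dataclass
class Experimental:
	predict_causally: bool = False  # same opt-in default as the diff

@dataclass
class Config:
	experimental: Experimental = field(default_factory=Experimental)

config = Config()
# the guard added to Base.__init__: tolerate running without a config
predict_causally = config.experimental.predict_causally if config is not None else False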
@@ -1503,7 +1505,7 @@ class Base(nn.Module):
 				return None, None
 
 			# shift if causal
-			if causal or self.version >= 7:
+			if causal or self.predict_causally:
 				l = self.causal_size
 				logit = logit[..., :-l, :] # shift the target so that token n...
 				sequence = sequence[..., l:] # ...predicts token n + 1
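This swaps the blanket version check for the flag, so the loss is only shifted when causal prediction is actually wanted. The shift itself is the standard next-token alignment; a small runnable sketch with made-up shapes (a `causal_size` of 1 is an assumption here):

import torch
import torch.nn.functional as F

l = 1                                       # assumed causal_size
logit = torch.randn( 10, 1024 )             # [T, V]
sequence = torch.randint( 0, 1024, (10,) )  # [T]

logit = logit[..., :-l, :]    # drop the last l predictions...
sequence = sequence[..., l:]  # ...so position n is scored against token n + 1
loss = F.cross_entropy( logit, sequence )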
@@ -176,6 +176,7 @@ def run_eval(engines, eval_name, dl, args=None):
 
 		batch_size = len(batch["text"])
 
+		"""
 		# to-do: eval for text tasks
 		has_stt = False
 		for i, task in enumerate( batch["task"] ):
@@ -192,6 +193,7 @@ def run_eval(engines, eval_name, dl, args=None):
 			for i, _ in enumerate(batch["text"]):
 				batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
 				batch["resps"][i] = None
+		"""
 
 		processed += batch_size
 		for name in engines:
@@ -205,31 +207,37 @@ def run_eval(engines, eval_name, dl, args=None):
 				training=False,
 			)
 
-			if engine.hyper_config.experimental.hf:
-				resps_list = engine( **base_kwargs )
-			elif "len" in engine.hyper_config.capabilities:
+			if self.version >= 7:
 				kwargs = base_kwargs | cfg.evaluation.kwargs
-				max_steps = kwargs.pop("max_steps", 500)
+				# sample for NAR demask
+				if random.random() < cfg.model.experimental.masking_train_p:
+					kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ]
+				# inference
+				resps_list = engine( **kwargs )
+			else:
+				if "len" in engine.hyper_config.capabilities:
+					kwargs = base_kwargs | cfg.evaluation.kwargs
+					max_steps = kwargs.pop("max_steps", 500)
 
-				if "denoise_start" in kwargs:
-					len_list = [ resp.shape[0] for resp in batch["resps"] ]
-					kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
-				else:
-					len_list = engine( max_steps=5, **kwargs )
-					len_list = [ min( l, max_steps ) for l in len_list ]
+					if "denoise_start" in kwargs:
+						len_list = [ resp.shape[0] for resp in batch["resps"] ]
+						kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
+					else:
+						len_list = engine( max_steps=5, **kwargs )
+						len_list = [ min( l, max_steps ) for l in len_list ]
 
-				kwargs = base_kwargs | cfg.evaluation.kwargs
-				resps_list = engine( **kwargs, len_list=len_list )
-			else:
-				if "ar" in engine.hyper_config.capabilities:
-					kwargs = base_kwargs | cfg.evaluation.kwargs
-					resps_list = engine( **kwargs )
-				else:
-					resps_list = [ resp[:, 0] for resp in batch["resps"] ]
+					kwargs = base_kwargs | cfg.evaluation.kwargs
+					resps_list = engine( **kwargs, len_list=len_list )
+				else:
+					if "ar" in engine.hyper_config.capabilities:
+						kwargs = base_kwargs | cfg.evaluation.kwargs
+						resps_list = engine( **kwargs )
+					else:
+						resps_list = [ resp[:, 0] for resp in batch["resps"] ]
 
-				if "nar" in engine.hyper_config.capabilities:
-					kwargs = base_kwargs | cfg.evaluation.kwargs
-					resps_list = engine( **kwargs, resps_list=resps_list )
+					if "nar" in engine.hyper_config.capabilities:
+						kwargs = base_kwargs | cfg.evaluation.kwargs
+						resps_list = engine( **kwargs, resps_list=resps_list )
 
 			process( name, batch, resps_list )
 
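The net effect: version >= 7 models get their own eval path that mirrors training's NAR-demask sampling, while the old `len`/`ar`/`nar` capability dispatch moves into the `else` branch and the `hf` path is dropped (note that `run_eval` has no `self`, so the `self.version` check presumably wants the model's `version` attribute instead). A trimmed, self-contained sketch of the new sampling (the `engine` stub and the probability value are illustrative, not the repo's API):

import random
import torch

masking_train_p = 0.8  # assumption; the real value is cfg.model.experimental.masking_train_p

def engine( len_list=None, **kwargs ):
	# stand-in for the engine call: demask against known lengths if given,
	# otherwise let the model drive the sequence itself
	return "nar-demask" if len_list is not None else "default"

resps = [ torch.randint( 0, 1024, (75, 8) ) for _ in range(4) ]  # fake batch of responses
kwargs = {}
if random.random() < masking_train_p:
	# NAR demask: hand the model ground-truth lengths so it infills in place
	kwargs["len_list"] = [ resp.shape[0] for resp in resps ]
print( engine( **kwargs ) )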