From 48490757da4b949724bcf29435e432a73006d5a3 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 10 Nov 2024 20:37:50 -0600
Subject: [PATCH] fixes

---
 vall_e/emb/similar.py      |  2 +-
 vall_e/engines/__init__.py |  2 +-
 vall_e/models/ar_nar.py    | 31 +++++++++++++++++--------------
 vall_e/train.py            |  5 +++--
 4 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py
index f4f3ebb..7b41468 100644
--- a/vall_e/emb/similar.py
+++ b/vall_e/emb/similar.py
@@ -282,7 +282,7 @@ def main():
 		dtype=args.dtype,
 		amp=args.amp,
 
-		verbose=False,
+		verbose=True,
 	)
 
 	if not similarities:
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 0a041e4..5a51496 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -188,7 +188,7 @@ def load_engines(training=True, **model_kwargs):
 
 		# resize modules if I'm doing experiments and can't be assed to manually trim things
 		if cfg.trainer.resize_modules:
-			uses_stop_token = 1 if model.causal_size > 0 else 0
+			uses_stop_token = 1 if ("ar" in model.capabilities or "len" in model.capabilities) else 0
 			keys = [
 				("text_emb.weight", model.config.text_tokens ),
 				("tasks_emb.weight", model.config.tasks ),
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 1e53b1e..a9dd77f 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -273,6 +273,7 @@ class AR_NAR(Base):
 			return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
 
 		_super = super()
+		# to-do: allow for batch processing (it should probably work batched anyways)
 		def demask_sampling( batch_index, seq_len ):
 			# overrides
 			max_steps = 10
@@ -321,17 +322,17 @@
 			# setup inputs
 			inputs = _super.inputs(
-				text_list=text_list,
-				proms_list=proms_list,
+				text_list=[ text_list[batch_index] ] if text_list else None,
+				proms_list=[ proms_list[batch_index] ] if proms_list else None,
 				resps_list=[ input_ids ],
-				lang_list=lang_list,
-				tone_list=tone_list,
+				lang_list=[ lang_list[batch_index] ] if lang_list else None,
+				tone_list=[ tone_list[batch_index] ] if tone_list else None,
 				time_list=[ timestep ],
-				quant_levels=quant_levels,
+				quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
 			)
 
 			output = _super.forward(
 				inputs=inputs,
-				quant_levels=quant_levels,
+				quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
 				#layer_skip_variables=sampling_layer_skip_variables,
 			)
 
@@ -342,24 +343,24 @@
 				text_list=[ null_text ],
 				proms_list=[ null_prom ],
 				resps_list=[ input_ids ],
-				lang_list=lang_list,
-				tone_list=tone_list,
+				lang_list=[ lang_list[batch_index] ] if lang_list else None,
+				tone_list=[ tone_list[batch_index] ] if tone_list else None,
 				time_list=[ timestep ],
-				quant_levels=quant_levels,
+				quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
 			)
 
 			null_output = _super.forward(
 				inputs=null_inputs,
-				quant_levels=quant_levels,
+				quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
 				#layer_skip_variables=sampling_layer_skip_variables,
 			)
-			for logit, null_logits in zip(output.logits, null_output.logits):
-				logit[-seq_len:] = logit[-seq_len:] + ( logit[-seq_len:] - null_logits[-seq_len:] ) * cfg_strength
+			for logit, null_logit in zip(output.logits, null_output.logits):
+				logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
 
 			# sample with sampler settings
 			filtered_sampled = _super.sample(
 				logits=logits,
 				prev_list=prev_list,
-				quant_levels=quant_levels,
+				quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
 				temperature=temperature,
 				min_temperature=sampling_min_temperature,
@@ -375,7 +376,7 @@
 			unfiltered_sampled = _super.sample(
 				logits=logits,
 				prev_list=prev_list,
-				quant_levels=quant_levels,
+				quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
 				temperature=0.0,
 			)
 			# update previous list of tokens
@@ -732,8 +733,10 @@ class AR_NAR(Base):
 
 			filename = "metrics"
 			if sampling_entropix:
 				filename += f'[entropix]'
+			"""
 			if sampling_layer_skip_exit_layer >= 0:
 				filename += f'[{sampling_layer_skip_exit_layer+1}]'
+			"""
 
 			plot_sample_metrics( metrics, filename=f'{filename}.png' )
diff --git a/vall_e/train.py b/vall_e/train.py
index 6446f5f..69b9307 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -148,13 +148,14 @@ def run_eval(engines, eval_name, dl, args=None):
 		elif "len" in engine.hyper_config.capabilities:
 			kwargs = base_kwargs | cfg.evaluation.ar_kwargs
 			max_steps = kwargs.pop("max_steps", 500)
-			len_list = engine( max_steps=5, **kwargs )
-			len_list = [ min( l, max_steps ) for l in len_list ]
 
 			if True:
 				len_list = [ resp.shape[0] for resp in batch["resps"] ]
 				kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
 				kwargs["denoise_start"] = 0.5
+			else:
+				len_list = engine( max_steps=5, **kwargs )
+				len_list = [ min( l, max_steps ) for l in len_list ]
 
 			kwargs = base_kwargs | cfg.evaluation.nar_kwargs
 			resps_list = engine( **kwargs, len_list=len_list )
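
For reference: the ar_nar.py hunk above corrects the classifier-free guidance blend to anchor at the null (unconditioned) logits rather than at the conditioned ones. A minimal standalone sketch of that blend, assuming plain tensors; cfg_blend and the shapes are illustrative, not names from this repo:

	import torch

	def cfg_blend( cond_logits: torch.Tensor, null_logits: torch.Tensor, cfg_strength: float ) -> torch.Tensor:
		# cfg_strength = 0 returns the unconditioned logits, 1 returns the
		# conditioned logits, and values > 1 extrapolate past the conditioned
		# logits, away from the unconditioned prediction.
		return null_logits + ( cond_logits - null_logits ) * cfg_strength

	cond = torch.randn( 4, 1024 )  # logits with the real text/prompt conditioning
	null = torch.randn( 4, 1024 )  # logits with null text/prompt inputs
	guided = cfg_blend( cond, null, cfg_strength=3.0 )

The old code computed cond + (cond - null) * cfg_strength, which double-counts the conditioned logits; the new form is the standard CFG formulation.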