diff --git a/vall_e/data.py b/vall_e/data.py index 0b2ab4a..f151bcb 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -47,7 +47,6 @@ def fold_inputs( audio_tokens = 1024, audio_rvq_levels = cfg.model.max_levels, quant_levels = None, - experimental = False ): def _create_mask(l, device): seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t) diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py index e712184..ac499c0 100644 --- a/vall_e/models/experimental.py +++ b/vall_e/models/experimental.py @@ -172,6 +172,9 @@ class Model(LlmArchClass): if "do_sample" in kwargs: kwargs.pop("do_sample") + if "min_length" in kwargs: + kwargs.pop("min_length") + return super().generate(*args, **kwargs) def forward( @@ -359,7 +362,7 @@ def example_usage(): @torch.inference_mode() def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ): engine.eval() - target_length = 0 + batch_size = len(text_list) resp_list = None if cfg.model.interleave: input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list) @@ -368,12 +371,14 @@ def example_usage(): unfolded = unfold_outputs( output ) resp_list = unfolded["resp_list"] else: - resp_list = [ [] for _ in range(len(text_list)) ] + resp_list = [ [] for _ in range(batch_size) ] for l in range(cfg.model.max_levels): - quant_levels = [ [ l ] for _ in range(len(text_list)) ] + quant_levels = [ l for _ in range(batch_size) ] - input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels, experimental=True) - min_length = len(input_ids[0]) + 1 + input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels) + min_length = 1 + for batch in input_ids: + min_length = max( min_length, batch.shape[0] + 1 ) output = model.generate( input_ids=input_ids, diff --git a/vall_e/train.py b/vall_e/train.py index 0087933..64a5a01 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -135,7 +135,7 @@ def run_eval(engines, eval_name, dl): input_ids, attention_mask = fold_inputs(text_list=batch["text"], prom_list=batch["proms"], resp_list=resps_list, quant_levels=quant_levels, experimental=True) min_length = 1 for batch in input_ids: - min_length = max( min_length, batch.shape[0] ) + min_length = max( min_length, batch.shape[0] + 1 ) output = model.generate( input_ids=input_ids,