This commit is contained in:
mrq 2024-06-04 18:50:48 -05:00
parent ed3aeaf3a1
commit 6d5bd0156a
3 changed files with 11 additions and 7 deletions

View File

@ -47,7 +47,6 @@ def fold_inputs(
audio_tokens = 1024,
audio_rvq_levels = cfg.model.max_levels,
quant_levels = None,
experimental = False
):
def _create_mask(l, device):
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)

View File

@ -172,6 +172,9 @@ class Model(LlmArchClass):
if "do_sample" in kwargs:
kwargs.pop("do_sample")
if "min_length" in kwargs:
kwargs.pop("min_length")
return super().generate(*args, **kwargs)
def forward(
@ -359,7 +362,7 @@ def example_usage():
@torch.inference_mode()
def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):
engine.eval()
target_length = 0
batch_size = len(text_list)
resp_list = None
if cfg.model.interleave:
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list)
@ -368,12 +371,14 @@ def example_usage():
unfolded = unfold_outputs( output )
resp_list = unfolded["resp_list"]
else:
resp_list = [ [] for _ in range(len(text_list)) ]
resp_list = [ [] for _ in range(batch_size) ]
for l in range(cfg.model.max_levels):
quant_levels = [ [ l ] for _ in range(len(text_list)) ]
quant_levels = [ l for _ in range(batch_size) ]
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels, experimental=True)
min_length = len(input_ids[0]) + 1
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels)
min_length = 1
for batch in input_ids:
min_length = max( min_length, batch.shape[0] + 1 )
output = model.generate(
input_ids=input_ids,

View File

@ -135,7 +135,7 @@ def run_eval(engines, eval_name, dl):
input_ids, attention_mask = fold_inputs(text_list=batch["text"], prom_list=batch["proms"], resp_list=resps_list, quant_levels=quant_levels, experimental=True)
min_length = 1
for batch in input_ids:
min_length = max( min_length, batch.shape[0] )
min_length = max( min_length, batch.shape[0] + 1 )
output = model.generate(
input_ids=input_ids,