fixes
parent ed3aeaf3a1
commit 6d5bd0156a
@@ -47,7 +47,6 @@ def fold_inputs(
 	audio_tokens = 1024,
 	audio_rvq_levels = cfg.model.max_levels,
 	quant_levels = None,
-	experimental = False
 ):
 	def _create_mask(l, device):
 		seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
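For reference, the `_create_mask` helper shown as context builds a padding mask from per-sequence lengths via broadcasting. A minimal standalone sketch of that pattern, assuming the usual lengths-to-mask comparison (the function name and example lengths here are illustrative, not taken from the repo):

import torch

def create_mask(lengths, device="cpu"):
	# positions 0..max_len-1 in one row, broadcast against per-sequence stops
	seq = torch.arange(max(lengths), device=device).unsqueeze(0)  # (1, t)
	stop = torch.tensor(lengths, device=device).unsqueeze(1)      # (b, 1)
	return seq < stop                                             # (b, t) bool

create_mask([3, 5])
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])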
@@ -172,6 +172,9 @@ class Model(LlmArchClass):
 		if "do_sample" in kwargs:
 			kwargs.pop("do_sample")
 
+		if "min_length" in kwargs:
+			kwargs.pop("min_length")
+
 		return super().generate(*args, **kwargs)
 
 	def forward(
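The hunk above extends the `generate` override to strip `min_length` (alongside the existing `do_sample`) from the kwargs before delegating to the parent class. A compact sketch of the same filtering pattern, assuming an HF-style wrapped `generate`; using `pop` with a default collapses the membership check:

def generate(self, *args, **kwargs):
	# drop sampling controls the wrapped generate() should not receive
	for key in ("do_sample", "min_length"):
		kwargs.pop(key, None)
	return super().generate(*args, **kwargs)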
@@ -359,7 +362,7 @@ def example_usage():
 	@torch.inference_mode()
 	def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):
 		engine.eval()
-		target_length = 0
+		batch_size = len(text_list)
 		resp_list = None
 		if cfg.model.interleave:
 			input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list)
@@ -368,12 +371,14 @@ def example_usage():
 			unfolded = unfold_outputs( output )
 			resp_list = unfolded["resp_list"]
 		else:
-			resp_list = [ [] for _ in range(len(text_list)) ]
+			resp_list = [ [] for _ in range(batch_size) ]
 			for l in range(cfg.model.max_levels):
-				quant_levels = [ [ l ] for _ in range(len(text_list)) ]
+				quant_levels = [ l for _ in range(batch_size) ]
 
-				input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels, experimental=True)
-				min_length = len(input_ids[0]) + 1
+				input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels)
+				min_length = 1
+				for batch in input_ids:
+					min_length = max( min_length, batch.shape[0] + 1 )
 
 				output = model.generate(
 					input_ids=input_ids,
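The replacement `min_length` logic scans the whole (ragged) batch instead of trusting `input_ids[0]`: an HF-style `min_length` counts prompt tokens too, so it must exceed the longest prompt for every sequence to be forced to emit at least one new token. A self-contained illustration of the computation (the dummy tensors are stand-ins, not repo data):

import torch

input_ids = [torch.zeros(12, dtype=torch.long), torch.zeros(9, dtype=torch.long)]

min_length = 1
for batch in input_ids:
	min_length = max(min_length, batch.shape[0] + 1)

assert min_length == 13  # longest prompt (12) plus one new token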
@@ -135,7 +135,7 @@ def run_eval(engines, eval_name, dl):
 		input_ids, attention_mask = fold_inputs(text_list=batch["text"], prom_list=batch["proms"], resp_list=resps_list, quant_levels=quant_levels, experimental=True)
 		min_length = 1
 		for batch in input_ids:
-			min_length = max( min_length, batch.shape[0] )
+			min_length = max( min_length, batch.shape[0] + 1 )
 
 		output = model.generate(
 			input_ids=input_ids,