fixes

parent ed3aeaf3a1
commit 6d5bd0156a
@@ -47,7 +47,6 @@ def fold_inputs(
     audio_tokens = 1024,
     audio_rvq_levels = cfg.model.max_levels,
     quant_levels = None,
-    experimental = False
 ):
     def _create_mask(l, device):
         seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
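The removed `experimental` flag means fold_inputs no longer takes an experimental-mode switch; its call sites are updated in the hunks below. For context, the one visible body line suggests `_create_mask` builds a padding mask from per-sample lengths. A minimal sketch of that pattern, assuming the comparison step and return shape (the repository's actual helper may differ):

import torch

def create_mask(lengths, device):
    # Hypothetical reconstruction of _create_mask: mark real-token
    # positions, given one sequence length per batch element.
    seq = torch.arange(max(lengths), device=device).unsqueeze(0)  # (1, t)
    stop = torch.tensor(lengths, device=device).unsqueeze(1)      # (b, 1)
    return seq < stop  # (b, t): True where the position holds a real token

create_mask([3, 5], device="cpu")
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])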
@@ -172,6 +172,9 @@ class Model(LlmArchClass):
         if "do_sample" in kwargs:
             kwargs.pop("do_sample")
 
+        if "min_length" in kwargs:
+            kwargs.pop("min_length")
+
         return super().generate(*args, **kwargs)
 
     def forward(
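This hunk extends the existing `do_sample` filtering so `min_length` is likewise stripped before delegating to the parent `generate`, presumably because the wrapped architecture does not accept it. Restated generically (a sketch of the pattern, not the class's literal code; the pop-with-default form avoids the membership test):

def generate(self, *args, **kwargs):
    # Drop generation options the wrapped architecture does not accept;
    # kwargs.pop(key, None) is a no-op when the key is absent.
    for unsupported in ("do_sample", "min_length"):
        kwargs.pop(unsupported, None)
    return super().generate(*args, **kwargs)

The call sites below still compute a batch-wide `min_length` (and presumably pass it here); with this change it is silently dropped at this layer.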
@@ -359,7 +362,7 @@ def example_usage():
     @torch.inference_mode()
     def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):
         engine.eval()
-        target_length = 0
+        batch_size = len(text_list)
         resp_list = None
         if cfg.model.interleave:
             input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list)
@@ -368,12 +371,14 @@ def example_usage():
             unfolded = unfold_outputs( output )
             resp_list = unfolded["resp_list"]
         else:
-            resp_list = [ [] for _ in range(len(text_list)) ]
+            resp_list = [ [] for _ in range(batch_size) ]
             for l in range(cfg.model.max_levels):
-                quant_levels = [ [ l ] for _ in range(len(text_list)) ]
+                quant_levels = [ l for _ in range(batch_size) ]
 
-                input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels, experimental=True)
-                min_length = len(input_ids[0]) + 1
+                input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels)
+                min_length = 1
+                for batch in input_ids:
+                    min_length = max( min_length, batch.shape[0] + 1 )
 
                 output = model.generate(
                     input_ids=input_ids,
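Three fixes in this hunk: the list comprehensions now use the `batch_size` introduced in the previous hunk instead of re-deriving `len(text_list)`; `quant_levels` is flattened from one singleton list per sample (`[ [ l ] ... ]`) to one plain int per sample; and `min_length` is derived from the longest folded sequence in the whole batch rather than from the first sample alone (`len(input_ids[0]) + 1`). A condensed sketch of the resulting per-RVQ-level decode loop, with fold_inputs semantics assumed from its name and signature:

for l in range(cfg.model.max_levels):
    quant_levels = [l] * batch_size  # flat ints, no longer [[l], [l], ...]
    input_ids, attention_mask = fold_inputs(
        text_list=text_list, prom_list=prom_list,
        resp_list=resp_list, quant_levels=quant_levels,
    )
    min_length = 1
    for seq in input_ids:
        # +1 demands at least one generated token past the longest prompt
        min_length = max(min_length, seq.shape[0] + 1)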
@@ -135,7 +135,7 @@ def run_eval(engines, eval_name, dl):
                 input_ids, attention_mask = fold_inputs(text_list=batch["text"], prom_list=batch["proms"], resp_list=resps_list, quant_levels=quant_levels, experimental=True)
                 min_length = 1
                 for batch in input_ids:
-                    min_length = max( min_length, batch.shape[0] )
+                    min_length = max( min_length, batch.shape[0] + 1 )
 
                 output = model.generate(
                     input_ids=input_ids,
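The same off-by-one fix as above: requiring `batch.shape[0] + 1` rather than `batch.shape[0]` means evaluation, like sampling, demands at least one token beyond the folded prompt. Note that this call site still passes `experimental=True` even though the first hunk removes that parameter from the fold_inputs signature; if both refer to the same function, this unchanged context line would now raise a TypeError and likely wants the same cleanup.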