This commit is contained in:
mrq 2024-06-08 16:01:34 -05:00
parent 58fb0a84db
commit b072f9b96b
3 changed files with 18 additions and 9 deletions

View File

@ -158,10 +158,10 @@ class AR_NAR(Base):
index = i
return int(index)
quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for _ in range(batch_size) ]
quant_levels = [ 0 if task_list[i] == "len" else generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
else:
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1]) for _ in range(batch_size) ]
quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
@ -290,6 +290,7 @@ class AR_NAR(Base):
task_list = [ "len" if "len" in self.capabilities else "tts" for _ in range(batch_size) ]
if "len" in self.capabilities:
sequence_list = [ torch.Tensor([0]).to(device=device,dtype=torch.int16) for _ in range(batch_size) ]
for n in trange(10, desc="AR"):
len_list = sequence_list
@ -309,6 +310,10 @@ class AR_NAR(Base):
)
r = [ logit[-1:].argmax(dim=1) for logit in logits ]
# sanitize
for i, token in enumerate(r):
if token > 10:
r[i] = 0
# append tokens
for i, ri in enumerate(r):

View File

@ -661,7 +661,7 @@ class Base(nn.Module):
if self.rvq_l_emb is not None:
# override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference)
quant_levels[i] = 0
# inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) )
inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) )
if proms_list is not None:
inputs[i].append( ( "prom", proms_list[i] ) )
@ -709,7 +709,7 @@ class Base(nn.Module):
embedding = self.resps_emb( torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token), offset = 0 )
else:
# get RVQ level 0, or up to targetted RVQ level inference
embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], offset = 0 if quant_level == 0 else 1 )
embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], offset = 0 if quant_level == 0 or "len" in self.capabilities else 1 )
elif name == "len" and self.len_emb is not None:
embedding = self.len_emb( input )
else:

View File

@ -167,6 +167,10 @@ def run_eval(engines, eval_name, dl):
for i, resp in enumerate( resps_list ):
resps_list[i] = torch.stack( resp ).t()
else:
if "len" in engine.hyper_config.capabilities:
len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=steps )
resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list )
else:
if "ar" in engine.hyper_config.capabilities:
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)