fixes
This commit is contained in:
parent
58fb0a84db
commit
b072f9b96b
|
@ -158,10 +158,10 @@ class AR_NAR(Base):
|
|||
index = i
|
||||
return int(index)
|
||||
|
||||
quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for _ in range(batch_size) ]
|
||||
quant_levels = [ 0 if task_list[i] == "len" else generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
|
||||
else:
|
||||
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1]) for _ in range(batch_size) ]
|
||||
quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
|
||||
|
||||
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
||||
|
||||
|
@ -290,6 +290,7 @@ class AR_NAR(Base):
|
|||
task_list = [ "len" if "len" in self.capabilities else "tts" for _ in range(batch_size) ]
|
||||
|
||||
if "len" in self.capabilities:
|
||||
sequence_list = [ torch.Tensor([0]).to(device=device,dtype=torch.int16) for _ in range(batch_size) ]
|
||||
for n in trange(10, desc="AR"):
|
||||
len_list = sequence_list
|
||||
|
||||
|
@ -309,6 +310,10 @@ class AR_NAR(Base):
|
|||
)
|
||||
|
||||
r = [ logit[-1:].argmax(dim=1) for logit in logits ]
|
||||
# sanitize
|
||||
for i, token in enumerate(r):
|
||||
if token > 10:
|
||||
r[i] = 0
|
||||
|
||||
# append tokens
|
||||
for i, ri in enumerate(r):
|
||||
|
|
|
@ -661,7 +661,7 @@ class Base(nn.Module):
|
|||
if self.rvq_l_emb is not None:
|
||||
# override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference)
|
||||
quant_levels[i] = 0
|
||||
# inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) )
|
||||
inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) )
|
||||
if proms_list is not None:
|
||||
inputs[i].append( ( "prom", proms_list[i] ) )
|
||||
|
||||
|
@ -709,7 +709,7 @@ class Base(nn.Module):
|
|||
embedding = self.resps_emb( torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token), offset = 0 )
|
||||
else:
|
||||
# get RVQ level 0, or up to targetted RVQ level inference
|
||||
embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], offset = 0 if quant_level == 0 else 1 )
|
||||
embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], offset = 0 if quant_level == 0 or "len" in self.capabilities else 1 )
|
||||
elif name == "len" and self.len_emb is not None:
|
||||
embedding = self.len_emb( input )
|
||||
else:
|
||||
|
|
|
@ -168,13 +168,17 @@ def run_eval(engines, eval_name, dl):
|
|||
for i, resp in enumerate( resps_list ):
|
||||
resps_list[i] = torch.stack( resp ).t()
|
||||
else:
|
||||
if "ar" in engine.hyper_config.capabilities:
|
||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||
if "len" in engine.hyper_config.capabilities:
|
||||
len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=steps )
|
||||
resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list )
|
||||
else:
|
||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||
if "ar" in engine.hyper_config.capabilities:
|
||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||
else:
|
||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||
|
||||
if "nar" in engine.hyper_config.capabilities:
|
||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
|
||||
if "nar" in engine.hyper_config.capabilities:
|
||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
|
||||
|
||||
process( name, batch, resps_list )
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user