fixes
This commit is contained in:
parent
58fb0a84db
commit
b072f9b96b
|
@ -158,10 +158,10 @@ class AR_NAR(Base):
|
||||||
index = i
|
index = i
|
||||||
return int(index)
|
return int(index)
|
||||||
|
|
||||||
quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for _ in range(batch_size) ]
|
quant_levels = [ 0 if task_list[i] == "len" else generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
|
||||||
else:
|
else:
|
||||||
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||||
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1]) for _ in range(batch_size) ]
|
quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
|
||||||
|
|
||||||
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
||||||
|
|
||||||
|
@ -290,6 +290,7 @@ class AR_NAR(Base):
|
||||||
task_list = [ "len" if "len" in self.capabilities else "tts" for _ in range(batch_size) ]
|
task_list = [ "len" if "len" in self.capabilities else "tts" for _ in range(batch_size) ]
|
||||||
|
|
||||||
if "len" in self.capabilities:
|
if "len" in self.capabilities:
|
||||||
|
sequence_list = [ torch.Tensor([0]).to(device=device,dtype=torch.int16) for _ in range(batch_size) ]
|
||||||
for n in trange(10, desc="AR"):
|
for n in trange(10, desc="AR"):
|
||||||
len_list = sequence_list
|
len_list = sequence_list
|
||||||
|
|
||||||
|
@ -309,6 +310,10 @@ class AR_NAR(Base):
|
||||||
)
|
)
|
||||||
|
|
||||||
r = [ logit[-1:].argmax(dim=1) for logit in logits ]
|
r = [ logit[-1:].argmax(dim=1) for logit in logits ]
|
||||||
|
# sanitize
|
||||||
|
for i, token in enumerate(r):
|
||||||
|
if token > 10:
|
||||||
|
r[i] = 0
|
||||||
|
|
||||||
# append tokens
|
# append tokens
|
||||||
for i, ri in enumerate(r):
|
for i, ri in enumerate(r):
|
||||||
|
|
|
@ -661,7 +661,7 @@ class Base(nn.Module):
|
||||||
if self.rvq_l_emb is not None:
|
if self.rvq_l_emb is not None:
|
||||||
# override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference)
|
# override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference)
|
||||||
quant_levels[i] = 0
|
quant_levels[i] = 0
|
||||||
# inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) )
|
inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) )
|
||||||
if proms_list is not None:
|
if proms_list is not None:
|
||||||
inputs[i].append( ( "prom", proms_list[i] ) )
|
inputs[i].append( ( "prom", proms_list[i] ) )
|
||||||
|
|
||||||
|
@ -709,7 +709,7 @@ class Base(nn.Module):
|
||||||
embedding = self.resps_emb( torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token), offset = 0 )
|
embedding = self.resps_emb( torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token), offset = 0 )
|
||||||
else:
|
else:
|
||||||
# get RVQ level 0, or up to targetted RVQ level inference
|
# get RVQ level 0, or up to targetted RVQ level inference
|
||||||
embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], offset = 0 if quant_level == 0 else 1 )
|
embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], offset = 0 if quant_level == 0 or "len" in self.capabilities else 1 )
|
||||||
elif name == "len" and self.len_emb is not None:
|
elif name == "len" and self.len_emb is not None:
|
||||||
embedding = self.len_emb( input )
|
embedding = self.len_emb( input )
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -167,6 +167,10 @@ def run_eval(engines, eval_name, dl):
|
||||||
|
|
||||||
for i, resp in enumerate( resps_list ):
|
for i, resp in enumerate( resps_list ):
|
||||||
resps_list[i] = torch.stack( resp ).t()
|
resps_list[i] = torch.stack( resp ).t()
|
||||||
|
else:
|
||||||
|
if "len" in engine.hyper_config.capabilities:
|
||||||
|
len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=steps )
|
||||||
|
resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list )
|
||||||
else:
|
else:
|
||||||
if "ar" in engine.hyper_config.capabilities:
|
if "ar" in engine.hyper_config.capabilities:
|
||||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user