un 'experimental' the better target sequence preparation
This commit is contained in:
parent
9a6040383e
commit
ed54f4ebec
|
@ -127,15 +127,18 @@ class AR_NAR(Base):
|
||||||
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
|
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
|
||||||
resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r[..., 0] is technically correct, but only r[:, 0] gets passed through the embedding
|
resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r[..., 0] is technically correct, but only r[:, 0] gets passed through the embedding
|
||||||
|
|
||||||
|
"""
|
||||||
if cfg.experimental:
|
if cfg.experimental:
|
||||||
proms_list = [ r if l == 0 else trim(r, 75 * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds
|
proms_list = [ r if l == 0 else trim(r, 75 * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds
|
||||||
# append stop tokens for AR
|
"""
|
||||||
for i in range(batch_size):
|
|
||||||
if quant_levels[i] > 0:
|
# append stop tokens for AR
|
||||||
continue
|
for i in range(batch_size):
|
||||||
|
if quant_levels[i] > 0:
|
||||||
|
continue
|
||||||
|
|
||||||
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([[self.stop_token] * n_levels]).to(device=device, dtype=torch.int16) ])
|
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([[self.stop_token] * n_levels]).to(device=device, dtype=torch.int16) ])
|
||||||
targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
|
targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
|
||||||
|
|
||||||
return super().forward(
|
return super().forward(
|
||||||
text_list=text_list,
|
text_list=text_list,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user