trim the input prompt to 3 seconds when training NAR tasks (marked as experimental; the paper mentions doing so, but I don't know how much this would harm the retention heads)
parent 893a610fad
commit 87db03dd93
@@ -293,6 +293,7 @@ class Dataset(_Dataset):
 		prom_length = 0
 		if cfg.experimental:
 			trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
+			#trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * 75))
 		else:
 			trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
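For reference, a self-contained sketch (not part of the commit) of the sampling logic in this hunk, assuming EnCodec's 75 quantized frames per second (the 75 * ... factors above); sample_trim_length and its prompt_duration default are illustrative stand-ins for the real cfg.dataset.prompt_duration:

# Sketch only: mirrors the hunk above under the assumption of 75 frames per second.
import random
import numpy as np  # only used by the commented-out Gaussian variant below

FRAMES_PER_SECOND = 75  # EnCodec frame rate assumed by the 75 * ... factors

def sample_trim_length(experimental: bool, prompt_duration: float = 3.0) -> int:
	if experimental:
		# uniform prompt length between 3 and 9 seconds, expressed in frames
		return random.randint(FRAMES_PER_SECOND * 3, FRAMES_PER_SECOND * 9)
		# commented-out alternative: Gaussian around 5 seconds, floored at 2 frames
		# return max(2, int(np.random.normal(loc=5, scale=1.25) * FRAMES_PER_SECOND))
	# default path: configured duration, jittered by up to one second either way
	return int(prompt_duration * FRAMES_PER_SECOND) + random.randint(-FRAMES_PER_SECOND, FRAMES_PER_SECOND)

# sample_trim_length(True) lands in [225, 675] frames, i.e. 3 to 9 seconds of prompt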
@@ -10,6 +10,8 @@ from einops import rearrange
 from torch import Tensor
 from tqdm import trange
 
+from ..emb.qnt import trim
+
 class AR_NAR(Base):
 	@property
 	def causal(self):
@@ -113,7 +115,11 @@ class AR_NAR(Base):
 
 			targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
 			resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # yes I can just do min(1, l)
-			quant_levels.to(device=device)
+
+			if cfg.experimental:
+				proms_list = [ r if l == 0 else trim(r, 75 * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds
+
+			#quant_levels.to(device=device)
 
 			return super().forward(
 				text_list=text_list,
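The new code path calls trim(r, 75 * 3) from ..emb.qnt, whose body is not part of this diff; a minimal sketch of what such a helper could look like, assuming each prompt r is a [frames, n_rvq_levels] tensor of codes and that a random crop position is acceptable (the real helper may crop differently):

# Sketch only: the actual ..emb.qnt.trim is not shown in this commit.
# Assumes codes are shaped [frames, n_rvq_levels] and crops to at most `target` frames.
import random
from torch import Tensor

def trim(codes: Tensor, target: int) -> Tensor:
	frames = codes.shape[0]
	if frames <= target:
		return codes  # already short enough, leave untouched
	start = random.randint(0, frames - target)  # random crop offset
	return codes[start:start + target, :]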
@@ -110,7 +110,7 @@ class NAR(Base):
 			prev_list = [o[..., : l + 1] for o, l in zip(resps_list, quant_levels)]
 			targ_list = [o[..., l + 1] for o, l in zip(resps_list, quant_levels)]
 
-			quant_levels = quant_levels.to(device=device)
+			#quant_levels = quant_levels.to(device=device)
 
 			logits = super().forward(
 				text_list=text_list,
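A side note on the two lines being commented out: torch.Tensor.to is out-of-place, so the un-assigned call in AR_NAR never moved quant_levels in the first place, and dropping the assignment in NAR only matters if quant_levels is later consumed by an op that requires it on device; plain indexing as in the list comprehensions works fine with a CPU tensor. A tiny illustration (not from the commit):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

quant_levels = torch.tensor([0, 1, 3])
quant_levels.to(device=device)                  # result discarded: quant_levels is unchanged
assert quant_levels.device.type == "cpu"        # still on the CPU
quant_levels = quant_levels.to(device=device)   # only the assigned form actually moves it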