From 87db03dd93c541f91fb310d78afc2848dd2ec177 Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 9 Oct 2023 22:03:58 -0500 Subject: [PATCH] trim the input prompt to 3 seconds when training NAR tasks (marked as experimental; the paper mentions doing so, but I don't know how much this would harm the retention heads) --- vall_e/data.py | 1 + vall_e/models/ar_nar.py | 8 +++++++- vall_e/models/nar.py | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/vall_e/data.py b/vall_e/data.py index 0772c13..5c1c6ee 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -293,6 +293,7 @@ class Dataset(_Dataset): prom_length = 0 if cfg.experimental: trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds] + #trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * 75)) else: trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75) diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index ccd5836..6731261 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -10,6 +10,8 @@ from einops import rearrange from torch import Tensor from tqdm import trange +from ..emb.qnt import trim + class AR_NAR(Base): @property def causal(self): @@ -113,7 +115,11 @@ class AR_NAR(Base): targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target) resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # yes I can just do min(1, l) - quant_levels.to(device=device) + + if cfg.experimental: + proms_list = [ r if l == 0 else trim(r, 75 * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds + + #quant_levels.to(device=device) return super().forward( text_list=text_list, diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index 46a984c..100bdfa 100755 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -110,7 +110,7 @@ class NAR(Base): prev_list = [o[..., : l + 1] for o, l in zip(resps_list, quant_levels)] targ_list = [o[..., l + 1] for o, l in zip(resps_list, quant_levels)] - quant_levels = quant_levels.to(device=device) + #quant_levels = quant_levels.to(device=device) logits = super().forward( text_list=text_list,