diff --git a/vall_e/config.py b/vall_e/config.py
index 2e6204e..8fdd591 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -136,6 +136,9 @@ class Dataset:
random_utterance: float = 1.0
max_prompts: int = 3
prompt_duration: float = 3.0
+
+	# maximum number of utterances to stitch together into one target response
+	max_resps: int = 1
+	# chance of appending additional utterances to a sampled response (when max_resps > 1)
+	p_resp_append: float = 1.0
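+	# for example, in a dataset config (hypothetical, illustrative values):
+	#   dataset:
+	#     max_resps: 3
+	#     p_resp_append: 0.5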
sample_type: str = "path" # path | speaker
diff --git a/vall_e/data.py b/vall_e/data.py
index 08244c9..15833a6 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -319,6 +319,8 @@ class Dataset(_Dataset):
if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
break
+		# it might be better to decode => concat the waveforms with silence in between => re-encode,
+		# as you technically can't just append encodec code sequences together like this without issues at the seams
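+		# a rough, untested sketch of that approach; decode()/encode() here are hypothetical
+		# helpers mapping codes <=> waveforms, and sample_rate is assumed known:
+		#   wavs = [ decode( p ) for p in prom_list ]
+		#   silence = torch.zeros( wavs[0].shape[0], int( 0.5 * sample_rate ) ) # ~0.5s gaps
+		#   wav = torch.cat( [ x for w in wavs for x in ( w, silence ) ][:-1], dim=-1 )
+		#   prom = encode( wav, sample_rate )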
prom = torch.cat(prom_list)
if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]:
@@ -348,11 +350,41 @@ class Dataset(_Dataset):
else:
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
resps = _load_quants(path)
+
+		# append additional utterances to the response in an attempt to artificially increase lengths / offer new data
+ if cfg.experimental and cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
+ choices = [*(set(self.paths_by_spkr_name[spkr_name]) - {path})]
+
+ if len(choices) > 0:
+				for _ in range( cfg.dataset.max_resps - 1 ):
+					# stop early once this speaker has no unused utterances left to sample
+					if len( choices ) == 0:
+						break
+					sampled_path = random.choice(choices)
+					choices = [*(set(choices) - {sampled_path})]
+ if cfg.dataset.use_hdf5:
+						# key off the newly sampled path, not the original utterance's
+						key = _get_hdf5_path(sampled_path)
+ txt = cfg.hdf5[key]["text"][:]
+ qnt = cfg.hdf5[key]["audio"][:, :]
+
+ txt = np.array( _cleanup_phones( txt, targets=[ self.phone_symmap[" "] ] ) )
+
+ txt = torch.from_numpy(txt).to(self.text_dtype)
+ qnt = torch.from_numpy(qnt).to(torch.int16)
+ else:
+ txt = torch.tensor([*map(self.phone_symmap.get, _get_phones(sampled_path))]).to(self.text_dtype)
+ qnt = _load_quants(sampled_path)
+
+					# [original text] [space] [new text]
+					# drops the original text's stop token (</s>), inserts a space, and drops the new text's start token (<s>)
+					text = torch.concat([ text[:-1], torch.tensor([self.phone_symmap[" "]]).to(self.text_dtype), txt[1:] ])
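+					# e.g. (illustrative phoneme sequences, assuming <s>/</s> wrap each utterance):
+					#   <s> h ə l oʊ </s>  +  <s> w ɜ l d </s>  =>  <s> h ə l oʊ " " w ɜ l d </s>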
+
+					# as above, it might be better to decode => concat the waveforms with silence in between => re-encode
+					# (see the sketch in sample_prompts), as encodec code sequences can't simply be appended like this without issues
+ resps = torch.concat([ resps, qnt ])
task = "tts"
trim_length = int(cfg.dataset.prompt_duration * 75)
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
+
# Disabled until I swap over to a better method
"""
task = random.choice(self.tasks)