might just be better to explicitly define prompt duration ranges, especially under a "train small contexts then increase it" training paradigm
This commit is contained in:
parent
bd0a36ba8d
commit
4d93a16ef7
|
@ -144,11 +144,13 @@ class Dataset:
|
||||||
|
|
||||||
phones_range: list[int] = field(default_factory=lambda: [4, 256])
|
phones_range: list[int] = field(default_factory=lambda: [4, 256])
|
||||||
duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0])
|
duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0])
|
||||||
|
prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0])
|
||||||
min_utterances: int = 2
|
min_utterances: int = 2
|
||||||
|
|
||||||
random_utterance: float = 1.0
|
random_utterance: float = 1.0
|
||||||
max_prompts: int = 3
|
max_prompts: int = 3
|
||||||
prompt_duration: float = 3.0
|
|
||||||
|
prompt_duration: float = 0.0 # legacy
|
||||||
|
|
||||||
max_resps: int = 1
|
max_resps: int = 1
|
||||||
p_resp_append: float = 1.0
|
p_resp_append: float = 1.0
|
||||||
|
@ -676,6 +678,9 @@ class Config(_Config):
|
||||||
if self.hyperparameters.scheduler == "":
|
if self.hyperparameters.scheduler == "":
|
||||||
self.hyperparameters.torch_scheduler = True
|
self.hyperparameters.torch_scheduler = True
|
||||||
|
|
||||||
|
if self.dataset.prompt_duration != 0:
|
||||||
|
self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration]
|
||||||
|
|
||||||
# Preserves the old behavior
|
# Preserves the old behavior
|
||||||
class NaiveTokenizer:
|
class NaiveTokenizer:
|
||||||
def get_vocab( self ):
|
def get_vocab( self ):
|
||||||
|
|
|
@ -368,12 +368,8 @@ class Dataset(_Dataset):
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# shuffle it up a bit
|
|
||||||
prom_length = 0
|
prom_length = 0
|
||||||
if cfg.experimental and False:
|
trim_length = random.randint(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second
|
||||||
trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * cfg.dataset.frames_per_second))
|
|
||||||
else:
|
|
||||||
trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second) + random.randint(-cfg.dataset.frames_per_second, cfg.dataset.frames_per_second)
|
|
||||||
|
|
||||||
for _ in range(cfg.dataset.max_prompts):
|
for _ in range(cfg.dataset.max_prompts):
|
||||||
path = random.choice(choices)
|
path = random.choice(choices)
|
||||||
|
@ -388,7 +384,7 @@ class Dataset(_Dataset):
|
||||||
else:
|
else:
|
||||||
qnt = _load_quants(path)
|
qnt = _load_quants(path)
|
||||||
|
|
||||||
if cfg.dataset.prompt_duration > 0 and trim_length < qnt.shape[0]:
|
if 0 < trim_length and trim_length < qnt.shape[0]:
|
||||||
qnt = trim( qnt, trim_length )
|
qnt = trim( qnt, trim_length )
|
||||||
|
|
||||||
prom_list.append(qnt)
|
prom_list.append(qnt)
|
||||||
|
@ -401,7 +397,7 @@ class Dataset(_Dataset):
|
||||||
# as you technically can't just append encodec sequences together like this without issues
|
# as you technically can't just append encodec sequences together like this without issues
|
||||||
prom = torch.cat(prom_list)
|
prom = torch.cat(prom_list)
|
||||||
|
|
||||||
if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]:
|
if 0 < trim_length and trim_length < prom.shape[0]:
|
||||||
prom = trim( prom, trim_length )
|
prom = trim( prom, trim_length )
|
||||||
|
|
||||||
return prom
|
return prom
|
||||||
|
@ -474,7 +470,7 @@ class Dataset(_Dataset):
|
||||||
resps = torch.concat([ resps, qnt ])
|
resps = torch.concat([ resps, qnt ])
|
||||||
|
|
||||||
task = "tts"
|
task = "tts"
|
||||||
trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second)
|
trim_length = random.randint(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second
|
||||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user