setting up for allowing training for a partial amount of the speechx tasks (do NOT try this at home yet without a proper model, as performance is predecated on having a solid base vall-e model for the tasks
This commit is contained in:
parent
ae9d38aa31
commit
8f42c578c9
|
@ -20,6 +20,7 @@ dataset:
|
||||||
prompt_duration: 3.0
|
prompt_duration: 3.0
|
||||||
|
|
||||||
sample_type: speaker
|
sample_type: speaker
|
||||||
|
tasks_list: ["tts"] # do NOT change this until you're ready to train for SpeechX tasks # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"]
|
||||||
|
|
||||||
models:
|
models:
|
||||||
_models:
|
_models:
|
||||||
|
@ -102,6 +103,7 @@ trainer:
|
||||||
|
|
||||||
inference:
|
inference:
|
||||||
use_vocos: True
|
use_vocos: True
|
||||||
|
normalize: False # do NOT change this unless you know exactly what you are doing.
|
||||||
|
|
||||||
bitsandbytes:
|
bitsandbytes:
|
||||||
enabled: false
|
enabled: false
|
||||||
|
|
|
@ -133,6 +133,8 @@ class Dataset:
|
||||||
|
|
||||||
sample_type: str = "path" # path | speaker
|
sample_type: str = "path" # path | speaker
|
||||||
|
|
||||||
|
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class Model:
|
class Model:
|
||||||
name: str = ""
|
name: str = ""
|
||||||
|
@ -475,6 +477,7 @@ try:
|
||||||
if not cfg.dataset.use_hdf5:
|
if not cfg.dataset.use_hdf5:
|
||||||
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
|
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
|
||||||
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
|
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
|
||||||
|
|
||||||
cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
|
cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -204,7 +204,7 @@ class Dataset(_Dataset):
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def tasks(self):
|
def tasks(self):
|
||||||
return ["tts"] # "ns", "sr", "tse", "cse", "nse"
|
return cfg.dataset.tasks_list # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"] # , "cse", "nse"
|
||||||
|
|
||||||
def _get_phone_symmap(self):
|
def _get_phone_symmap(self):
|
||||||
return get_phone_symmap()
|
return get_phone_symmap()
|
||||||
|
@ -216,16 +216,22 @@ class Dataset(_Dataset):
|
||||||
return get_task_symmap()
|
return get_task_symmap()
|
||||||
|
|
||||||
def get_task_token( self, token ):
|
def get_task_token( self, token ):
|
||||||
|
if not hasattr(self, "task_symmap"):
|
||||||
|
self.task_symmap = self._get_task_symmap()
|
||||||
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(cfg.models.prom_levels) ]]).to(dtype=torch.int16)
|
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(cfg.models.prom_levels) ]]).to(dtype=torch.int16)
|
||||||
|
|
||||||
def sample_noise(self):
|
def sample_noise(self):
|
||||||
paths = []
|
paths = []
|
||||||
print(cfg.dataset.noise)
|
|
||||||
for data_dir in cfg.dataset.noise:
|
for data_dir in cfg.dataset.noise:
|
||||||
paths.extend(data_dir.rglob("*.qnt.pt"))
|
paths.extend(data_dir.rglob("*.qnt.pt"))
|
||||||
|
|
||||||
path = random.choice(paths)
|
path = random.choice(paths)
|
||||||
return _load_quants(path)
|
|
||||||
|
if False and cfg.dataset.use_hdf5:
|
||||||
|
key = f'/noise/{_get_hdf5_path(path)}'
|
||||||
|
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
|
||||||
|
else:
|
||||||
|
qnt = _load_quants(path)
|
||||||
|
return qnt
|
||||||
|
|
||||||
def sample_speakers(self, ignore=[]):
|
def sample_speakers(self, ignore=[]):
|
||||||
choices = set(self.spkrs) - set(ignore)
|
choices = set(self.spkrs) - set(ignore)
|
||||||
|
@ -305,7 +311,7 @@ class Dataset(_Dataset):
|
||||||
# extend the noise to fill the target audio
|
# extend the noise to fill the target audio
|
||||||
noise = repeat_extend_audio(noise, resps.shape[0])
|
noise = repeat_extend_audio(noise, resps.shape[0])
|
||||||
# create the input prompt by merging the target audio with the noise
|
# create the input prompt by merging the target audio with the noise
|
||||||
proms = merge_audio(resps, noise, scale=[1, 0.125])
|
proms = merge_audio(resps, noise, scale=[1, 0.125], device="cpu")
|
||||||
# set the target to just be the noise if <sr>
|
# set the target to just be the noise if <sr>
|
||||||
if task == "sr":
|
if task == "sr":
|
||||||
resps = noise
|
resps = noise
|
||||||
|
@ -325,10 +331,10 @@ class Dataset(_Dataset):
|
||||||
|
|
||||||
smallest_size = min(resps.shape[0], other_proms.shape[0])
|
smallest_size = min(resps.shape[0], other_proms.shape[0])
|
||||||
if other_proms.shape[0] == smallest_size:
|
if other_proms.shape[0] == smallest_size:
|
||||||
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)] )
|
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
|
||||||
noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
|
noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
|
||||||
else:
|
else:
|
||||||
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)] )
|
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
|
||||||
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
|
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
|
||||||
|
|
||||||
# stitch together the promps
|
# stitch together the promps
|
||||||
|
|
Loading…
Reference in New Issue
Block a user