setting up for allowing training for a partial amount of the speechx tasks (do NOT try this at home yet without a proper model, as performance is predecated on having a solid base vall-e model for the tasks

This commit is contained in:
mrq 2023-08-19 00:16:08 -05:00
parent ae9d38aa31
commit 8f42c578c9
3 changed files with 20 additions and 9 deletions

View File

@ -20,6 +20,7 @@ dataset:
prompt_duration: 3.0 prompt_duration: 3.0
sample_type: speaker sample_type: speaker
tasks_list: ["tts"] # do NOT change this until you're ready to train for SpeechX tasks # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"]
models: models:
_models: _models:
@ -102,6 +103,7 @@ trainer:
inference: inference:
use_vocos: True use_vocos: True
normalize: False # do NOT change this unless you know exactly what you are doing.
bitsandbytes: bitsandbytes:
enabled: false enabled: false

View File

@ -133,6 +133,8 @@ class Dataset:
sample_type: str = "path" # path | speaker sample_type: str = "path" # path | speaker
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
@dataclass() @dataclass()
class Model: class Model:
name: str = "" name: str = ""
@ -475,6 +477,7 @@ try:
if not cfg.dataset.use_hdf5: if not cfg.dataset.use_hdf5:
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ] cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ] cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ] cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
except Exception as e: except Exception as e:
pass pass

View File

@ -204,7 +204,7 @@ class Dataset(_Dataset):
@cached_property @cached_property
def tasks(self): def tasks(self):
return ["tts"] # "ns", "sr", "tse", "cse", "nse" return cfg.dataset.tasks_list # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"] # , "cse", "nse"
def _get_phone_symmap(self): def _get_phone_symmap(self):
return get_phone_symmap() return get_phone_symmap()
@ -216,16 +216,22 @@ class Dataset(_Dataset):
return get_task_symmap() return get_task_symmap()
def get_task_token( self, token ): def get_task_token( self, token ):
if not hasattr(self, "task_symmap"):
self.task_symmap = self._get_task_symmap()
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(cfg.models.prom_levels) ]]).to(dtype=torch.int16) return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(cfg.models.prom_levels) ]]).to(dtype=torch.int16)
def sample_noise(self): def sample_noise(self):
paths = [] paths = []
print(cfg.dataset.noise)
for data_dir in cfg.dataset.noise: for data_dir in cfg.dataset.noise:
paths.extend(data_dir.rglob("*.qnt.pt")) paths.extend(data_dir.rglob("*.qnt.pt"))
path = random.choice(paths) path = random.choice(paths)
return _load_quants(path)
if False and cfg.dataset.use_hdf5:
key = f'/noise/{_get_hdf5_path(path)}'
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
else:
qnt = _load_quants(path)
return qnt
def sample_speakers(self, ignore=[]): def sample_speakers(self, ignore=[]):
choices = set(self.spkrs) - set(ignore) choices = set(self.spkrs) - set(ignore)
@ -305,7 +311,7 @@ class Dataset(_Dataset):
# extend the noise to fill the target audio # extend the noise to fill the target audio
noise = repeat_extend_audio(noise, resps.shape[0]) noise = repeat_extend_audio(noise, resps.shape[0])
# create the input prompt by merging the target audio with the noise # create the input prompt by merging the target audio with the noise
proms = merge_audio(resps, noise, scale=[1, 0.125]) proms = merge_audio(resps, noise, scale=[1, 0.125], device="cpu")
# set the target to just be the noise if <sr> # set the target to just be the noise if <sr>
if task == "sr": if task == "sr":
resps = noise resps = noise
@ -325,10 +331,10 @@ class Dataset(_Dataset):
smallest_size = min(resps.shape[0], other_proms.shape[0]) smallest_size = min(resps.shape[0], other_proms.shape[0])
if other_proms.shape[0] == smallest_size: if other_proms.shape[0] == smallest_size:
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)] ) noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] ) noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
else: else:
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)] ) noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] ) noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
# stitch together the promps # stitch together the promps