diff --git a/vall_e/config.py b/vall_e/config.py index fcb3a4d..1f5b4e7 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -118,6 +118,7 @@ class Dataset: hdf5_name: str = "data.h5" use_hdf5: bool = False + hdf5_flag: str = "a" validate: bool = True workers: int = 8 cache: bool = True @@ -467,7 +468,7 @@ try: # cached_property stopped working... if cfg.dataset.use_hdf5: try: - cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a') # to-do, have an easy to set flag that determines if training or creating the dataset + cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', cfg.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset except Exception as e: print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e)) cfg.dataset.use_hdf5 = False diff --git a/vall_e/data.py b/vall_e/data.py index 0c72ea0..3d7d30f 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -10,7 +10,6 @@ import random import torch from .config import cfg -from .utils.sampler import Sampler from collections import defaultdict from functools import cache, cached_property @@ -168,12 +167,6 @@ class Dataset(_Dataset): self.durations[spkr_id] = duration else: self.durations[spkr_id] += duration - - if training and not cfg.distributed and self.sample_type == "path": - self.sampler = Sampler(self.paths, [cfg.get_spkr]) - else: - self.sampler = None - def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]): ret = defaultdict(list) for path in self.paths: @@ -267,10 +260,7 @@ class Dataset(_Dataset): spkr_id = self.spkr_symmap[spkr_name] path = random.choice([*set(self.paths_by_spkr_name[spkr_name])]) else: - if self.training and self.sampler is not None: - path = self.sampler.sample() - else: - path = self.paths[index] + path = self.paths[index] spkr_name = cfg.get_spkr(path) spkr_id = self.spkr_symmap[spkr_name] @@ -299,12 +289,18 @@ class Dataset(_Dataset): resps = self.sample_noise() resps = extend_audio(resps, proms.shape[0]) # something to prepend a sr token to the beginning of proms - elif task == "tse: + elif task == "tse": proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps other_speaker = self.sample_speaker(ignore=[spkr_name]) other_proms = self.sample_prompts(other_speaker, ignore="") proms = merge_audio(proms, other_proms) - # something to prepend a ns token to the beginning of proms + # something to prepend a tse token to the beginning of proms + """ + """ + # speech editing would require higher quality transcription data (phoneme level/word level) unfortunately + # as I need to get a good clean point to trim into + elif task == "cse": + elif task == "nse": """ @@ -595,9 +591,16 @@ if __name__ == "__main__": create_dataset_hdf5() train_dl, subtrain_dl, val_dl = create_train_val_dataloader() - print("Training DL:", next(iter(train_dl))) - print("Training DL:", next(iter(train_dl))) - print("Evaluation DL:", next(iter(subtrain_dl))) - print("Evaluation DL:", next(iter(subtrain_dl))) - print("Validation DL:", next(iter(val_dl))) - print("Validation DL:", next(iter(val_dl))) + + samples = { + "training": [ next(iter(train_dl)), next(iter(train_dl)) ], + "evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ], + "validation": [ next(iter(val_dl)), next(iter(val_dl)) ], + } + + for k, v in samples.items(): + for i in range(len(v)): + del v[i]['proms'] + del v[i]['resps'] + print(f'{k}:', v) + diff --git a/vall_e/inference.py b/vall_e/inference.py index 7903da7..62371f6 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -35,6 +35,16 @@ class TTS(): cfg.load_yaml( config ) cfg.format() + + """ + if cfg.trainer.load_state_dict: + for model in cfg.models.get(): + path = cfg.ckpt_dir / model.full_name / "fp32.pth" + if model.name.startswith("ar"): + ar_ckpt = path + if model.name.startswith("nar"): + nar_ckpt = path + """ if ar_ckpt and nar_ckpt: self.ar_ckpt = ar_ckpt @@ -44,10 +54,16 @@ class TTS(): for name, model in models.items(): if name.startswith("ar"): self.ar = model.to(self.device, dtype=torch.float32) - self.ar.load_state_dict(torch.load(self.ar_ckpt)['module']) + state = torch.load(self.ar_ckpt) + if "module" in state: + state = state['module'] + self.ar.load_state_dict(state) elif name.startswith("nar"): self.nar = model.to(self.device, dtype=torch.float32) - self.nar.load_state_dict(torch.load(self.nar_ckpt)['module']) + state = torch.load(self.nar_ckpt) + if "module" in state: + state = state['module'] + self.nar.load_state_dict(state) else: self.load_models() diff --git a/vall_e/utils/sampler.py b/vall_e/utils/sampler.py deleted file mode 100755 index 5db9606..0000000 --- a/vall_e/utils/sampler.py +++ /dev/null @@ -1,48 +0,0 @@ -""" -A sampler that balances data by key_fns. - -MIT License - -Copyright (c) 2023 Zhe Niu - -niuzhe.nz@outlook.com -""" - -import random - - -class Sampler: - def __init__(self, l, key_fns): - self.tree = self._build(l, key_fns) - - def _build(self, l, key_fns) -> dict[dict, list]: - if not key_fns: - return l - - tree = {} - - key_fn, *key_fns = key_fns - - for x in l: - k = key_fn(x) - - if k in tree: - tree[k].append(x) - else: - tree[k] = [x] - - for k in tree: - tree[k] = self._build(tree[k], key_fns) - - return tree - - def _sample(self, tree: dict | list): - if isinstance(tree, list): - ret = random.choice(tree) - else: - key = random.choice([*tree.keys()]) - ret = self._sample(tree[key]) - return ret - - def sample(self): - return self._sample(self.tree) \ No newline at end of file diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index 366b12f..ca7af73 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -80,7 +80,10 @@ def load_engines(): if cfg.trainer.load_state_dict: load_path = cfg.ckpt_dir / name / "fp32.pth" - model.load_state_dict(torch.load(load_path)['module']) + state = torch.load(load_path) + if "module" in state: + state = state["module"] + model.load_state_dict(state) engines[name] = Engine( model=model,