diff --git a/README.md b/README.md
index 6ccec8d..f4123c6 100755
--- a/README.md
+++ b/README.md
@@ -101,10 +101,10 @@ Training a VALL-E model is very, very meticulous. I've fiddled with a lot of
 """
 
 Both trained models *can* be exported, but is only required if loading them on systems without DeepSpeed for inferencing (Windows systems). To export the models, run:
 
 ```
-python -m vall_e.export ./models/ yaml=./config/custom.yml
+python -m vall_e.export yaml=./config/custom.yml
 ```
-This will export the latest checkpoint.
+This will export the latest checkpoints.
 
 ### Synthesis
diff --git a/vall_e/config.py b/vall_e/config.py
index 442b5f3..ff34c50 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -17,6 +17,8 @@ from pathlib import Path
 
 from omegaconf import OmegaConf
 
+from .utils.distributed import world_size
+
 @dataclass()
 class _Config:
 	cfg_path: str | None = None
@@ -401,7 +403,7 @@ class BitsAndBytes:
 @dataclass()
 class Config(_Config):
 	device: str = "cuda"
-	distributed: bool = False
+	#distributed: bool = False
 
 	dataset: Dataset = field(default_factory=lambda: Dataset)
 	models: Models = field(default_factory=lambda: Models)
@@ -415,6 +417,10 @@ class Config(_Config):
 	def sample_rate(self):
 		return 24_000
 
+	@property
+	def distributed(self):
+		return world_size() > 1
+
 	@cached_property
 	def get_spkr(self):
 		return eval(self.dataset.speaker_name_getter)
@@ -447,20 +453,21 @@ try:
 	cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
 
 	cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
+
+	# cached_property stopped working...
+	if cfg.dataset.use_hdf5:
+		try:
+			cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a')
+		except Exception as e:
+			print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
+			cfg.dataset.use_hdf5 = False
+
+	if not cfg.dataset.use_hdf5:
+		cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
+		cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
 except Exception as e:
 	pass
 
-# cached_property stopped working...
-if cfg.dataset.use_hdf5:
-	try:
-		cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a')
-	except Exception as e:
-		print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
-		cfg.dataset.use_hdf5 = False
-
-if not cfg.dataset.use_hdf5:
-	cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
-	cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
 
 if __name__ == "__main__":
 	print(cfg)
diff --git a/vall_e/export.py b/vall_e/export.py
index 093f143..8f51871 100755
--- a/vall_e/export.py
+++ b/vall_e/export.py
@@ -4,31 +4,22 @@ import torch
 
 from .data import get_phone_symmap
 from .train import load_engines
-
-def load_models():
-	models = {}
-	engines = load_engines()
-	for name in engines:
-		model = engines[name].module.cpu()
-
-		model.phone_symmap = get_phone_symmap()
-
-		models[name] = model
-
-	return models
+from .config import cfg
 
 def main():
 	parser = argparse.ArgumentParser("Save trained model to path.")
-	parser.add_argument("path")
+	#parser.add_argument("--yaml", type=Path, default=None)
 	args = parser.parse_args()
 
-	models = load_models()
-	for name in models:
-		model = models[name]
-
-		outpath = f'{args.path}/{name}.pt'
+	engines = load_engines()
+	for name, engine in engines.items():
+		outpath = cfg.ckpt_dir / name / "fp32.pth"
 		torch.save({
-			'module': model.state_dict()
+			"global_step": engine.global_step,
+			"micro_step": engine.micro_step,
+			'module': engine.module.to('cpu', dtype=torch.float32).state_dict(),
+			#'optimizer': engine.optimizer.state_dict(),
+			'symmap': get_phone_symmap(),
 		}, outpath)
 		print(f"Exported {name} to {outpath}")
 
diff --git a/vall_e/inference.py b/vall_e/inference.py
index e6c8193..9983934 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -19,6 +19,9 @@ class TTS():
 
 		self.input_sample_rate = 24000
 		self.output_sample_rate = 24000
+
+		if config:
+			cfg.load_yaml( config )
 
 		if ar_ckpt and nar_ckpt:
 			self.ar_ckpt = ar_ckpt
@@ -33,7 +36,7 @@ class TTS():
 			self.nar = model.to(self.device, dtype=torch.float32)
 			self.nar.load_state_dict(torch.load(self.nar_ckpt)['module'])
 		else:
-			self.load_models( config )
+			self.load_models()
 
 		self.symmap = get_phone_symmap()
 		self.ar.eval()
@@ -41,10 +44,7 @@ class TTS():
 
 		self.loading = False
 
-	def load_models( self, config_path ):
-		if config_path:
-			cfg.load_yaml( config_path )
-
+	def load_models( self ):
 		print("Loading models...")
 		models = load_models()
 		print("Loaded models")
diff --git a/vall_e/utils/distributed.py b/vall_e/utils/distributed.py
index e80b0dd..e43c7e5 100755
--- a/vall_e/utils/distributed.py
+++ b/vall_e/utils/distributed.py
@@ -48,6 +48,9 @@ def local_rank():
 
 def global_rank():
 	return int(os.getenv("RANK", 0))
 
+def world_size():
+	return int(os.getenv("WORLD_SIZE", 1))
+
 def is_local_leader():
 	return local_rank() == 0
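
For reference, a minimal sketch of consuming a checkpoint written by the reworked export loop above. The path below is hypothetical; the real location is `cfg.ckpt_dir / <engine name> / "fp32.pth"`, and the dict keys match what `torch.save` writes in the diff.

```
# Minimal sketch, not part of the patch: inspect a checkpoint written by
# vall_e.export. "./training/ckpt/ar/fp32.pth" is a hypothetical path; the
# real one is cfg.ckpt_dir / <engine name> / "fp32.pth".
import torch

state = torch.load("./training/ckpt/ar/fp32.pth", map_location="cpu")

print("global_step:", state["global_step"])    # engine step counters saved alongside the weights
print("micro_step:", state["micro_step"])
print("phone symbols:", len(state["symmap"]))  # phoneme symbol map from get_phone_symmap()
weights = state["module"]                      # fp32 state_dict, ready for model.load_state_dict(weights)
```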
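
Likewise, a small sketch of the behavior the new `world_size()` helper gives `cfg.distributed`: launchers such as `torchrun` and `deepspeed` export `WORLD_SIZE` for each process, so a plain single-process run stays non-distributed while multi-GPU runs automatically open the HDF5 dataset read-only.

```
# Minimal sketch: cfg.distributed is no longer a stored flag; it mirrors the
# launcher environment (WORLD_SIZE defaults to 1 for plain `python` runs).
import os

def world_size() -> int:
	return int(os.getenv("WORLD_SIZE", 1))

print(world_size() > 1)          # False in a single-process run
os.environ["WORLD_SIZE"] = "2"   # what a two-process torchrun/deepspeed launch sets
print(world_size() > 1)          # True -> HDF5 opened with mode 'r' instead of 'a'
```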