made exporter make more sense

mrq 2023-08-13 22:56:28 -05:00
parent d7deaf6def
commit 13571380be
5 changed files with 39 additions and 38 deletions

View File

@@ -101,10 +101,10 @@ Training a VALL-E model is very, very meticulous. I've fiddled with a lot of
Both trained models *can* be exported, but this is only required when loading them for inferencing on systems without DeepSpeed (e.g. Windows systems). To export the models, run:
```
python -m vall_e.export ./models/ yaml=./config/custom.yml
python -m vall_e.export yaml=./config/custom.yml
```
This will export the latest checkpoint.
This will export the latest checkpoints.
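For orientation, a minimal sketch of where the exporter now writes, assuming `cfg.ckpt_dir` resolves to `./training/custom/ckpt` (the real value comes from the loaded YAML):
```
from pathlib import Path

# assumed value of cfg.ckpt_dir; the actual path is derived from the loaded YAML
ckpt_dir = Path("./training/custom/ckpt")
for name in ("ar", "nar"):                # one exported file per trained model
    print(ckpt_dir / name / "fp32.pth")   # destination used by `python -m vall_e.export`
```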
### Synthesis

View File

@@ -17,6 +17,8 @@ from pathlib import Path
from omegaconf import OmegaConf
from .utils.distributed import world_size

@dataclass()
class _Config:
    cfg_path: str | None = None
@@ -401,7 +403,7 @@ class BitsAndBytes:
@dataclass()
class Config(_Config):
    device: str = "cuda"
    distributed: bool = False
    #distributed: bool = False

    dataset: Dataset = field(default_factory=lambda: Dataset)
    models: Models = field(default_factory=lambda: Models)
@@ -415,6 +417,10 @@ class Config(_Config):
    def sample_rate(self):
        return 24_000

    @property
    def distributed(self):
        return world_size() > 1

    @cached_property
    def get_spkr(self):
        return eval(self.dataset.speaker_name_getter)
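In effect, `distributed` is no longer a field you set in the YAML; it is derived from the launcher's environment. A minimal standalone sketch of the same check, assuming the usual `WORLD_SIZE` convention set by multi-process launchers such as torchrun or the DeepSpeed runner:
```
import os

def world_size() -> int:
    # mirrors utils.distributed.world_size(): WORLD_SIZE defaults to 1 for single-process runs
    return int(os.getenv("WORLD_SIZE", 1))

def distributed() -> bool:
    # True only when a launcher spawned more than one process
    return world_size() > 1

print(distributed())  # False for a plain `python -m ...` invocation
```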
@@ -447,20 +453,21 @@ try:
    cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
    cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)

    # cached_property stopped working...
    if cfg.dataset.use_hdf5:
        try:
            cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a')
        except Exception as e:
            print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
            cfg.dataset.use_hdf5 = False

    if not cfg.dataset.use_hdf5:
        cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
        cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
except Exception as e:
    pass

# cached_property stopped working...
if cfg.dataset.use_hdf5:
    try:
        cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a')
    except Exception as e:
        print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
        cfg.dataset.use_hdf5 = False

if not cfg.dataset.use_hdf5:
    cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
    cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]

if __name__ == "__main__":
    print(cfg)

View File

@@ -4,31 +4,22 @@ import torch
from .data import get_phone_symmap
from .train import load_engines

def load_models():
    models = {}
    engines = load_engines()
    for name in engines:
        model = engines[name].module.cpu()
        model.phone_symmap = get_phone_symmap()
        models[name] = model
    return models

from .config import cfg

def main():
    parser = argparse.ArgumentParser("Save trained model to path.")
    parser.add_argument("path")
    #parser.add_argument("--yaml", type=Path, default=None)
    args = parser.parse_args()

    models = load_models()
    for name in models:
        model = models[name]
        outpath = f'{args.path}/{name}.pt'
    engines = load_engines()
    for name, engine in engines.items():
        outpath = cfg.ckpt_dir / name / "fp32.pth"
        torch.save({
            'module': model.state_dict()
            "global_step": engine.global_step,
            "micro_step": engine.micro_step,
            'module': engine.module.to('cpu', dtype=torch.float32).state_dict(),
            #'optimizer': engine.optimizer.state_dict(),
            'symmap': get_phone_symmap(),
        }, outpath)
        print(f"Exported {name} to {outpath}")

View File

@@ -19,6 +19,9 @@ class TTS():
        self.input_sample_rate = 24000
        self.output_sample_rate = 24000

        if config:
            cfg.load_yaml( config )

        if ar_ckpt and nar_ckpt:
            self.ar_ckpt = ar_ckpt
@@ -33,7 +36,7 @@ class TTS():
            self.nar = model.to(self.device, dtype=torch.float32)
            self.nar.load_state_dict(torch.load(self.nar_ckpt)['module'])
        else:
            self.load_models( config )
            self.load_models()

        self.symmap = get_phone_symmap()
        self.ar.eval()
@@ -41,10 +44,7 @@
        self.loading = False

    def load_models( self, config_path ):
        if config_path:
            cfg.load_yaml( config_path )
    def load_models( self ):
        print("Loading models...")
        models = load_models()
        print("Loaded models")

View File

@@ -48,6 +48,9 @@ def local_rank():
def global_rank():
    return int(os.getenv("RANK", 0))

def world_size():
    return int(os.getenv("WORLD_SIZE", 1))

def is_local_leader():
    return local_rank() == 0