made exporter make more sense

commit 13571380be (parent d7deaf6def)
@@ -101,10 +101,10 @@ Training a VALL-E model is very, very meticulous. I've fiddled with a lot of
 Both trained models *can* be exported, but this is only required when loading them on systems without DeepSpeed for inferencing (Windows systems). To export the models, run:
 
 ```
-python -m vall_e.export ./models/ yaml=./config/custom.yml
+python -m vall_e.export yaml=./config/custom.yml
 ```
 
-This will export the latest checkpoint.
+This will export the latest checkpoints.
 
 ### Synthesis
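As a quick sanity check after exporting, the file each model is written to can be opened with plain PyTorch; per the rewritten exporter further down in this commit, it lands under the configured checkpoint directory as `fp32.pth` and carries step counters, weights, and the phoneme symbol map. A minimal sketch, where the path is a placeholder for wherever your `ckpt_dir` points:

```python
import torch

# Placeholder path; the exporter writes to cfg.ckpt_dir / <model name> / "fp32.pth".
state = torch.load("./ckpt/ar/fp32.pth", map_location="cpu")

print(state["global_step"], state["micro_step"])  # training progress at export time
print(len(state["symmap"]))                       # phoneme-to-id map saved with the weights
weights = state["module"]                         # the fp32 state_dict itself
```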
@@ -17,6 +17,8 @@ from pathlib import Path
 
 from omegaconf import OmegaConf
 
+from .utils.distributed import world_size
+
 @dataclass()
 class _Config:
     cfg_path: str | None = None
@@ -401,7 +403,7 @@ class BitsAndBytes:
 @dataclass()
 class Config(_Config):
     device: str = "cuda"
-    distributed: bool = False
+    #distributed: bool = False
 
     dataset: Dataset = field(default_factory=lambda: Dataset)
     models: Models = field(default_factory=lambda: Models)
@@ -415,6 +417,10 @@ class Config(_Config):
     def sample_rate(self):
         return 24_000
 
+    @property
+    def distributed(self):
+        return world_size() > 1
+
     @cached_property
     def get_spkr(self):
         return eval(self.dataset.speaker_name_getter)
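For context, the two config.py hunks above replace the stored `distributed: bool` field with a read-only property derived from the launcher's environment, so the flag can no longer drift out of sync with how the process was started. A self-contained sketch of the same pattern, assuming the torchrun/DeepSpeed convention of a `WORLD_SIZE` environment variable:

```python
import os
from dataclasses import dataclass

def world_size() -> int:
    # Multi-process launchers (torchrun, deepspeed) set WORLD_SIZE per process.
    return int(os.getenv("WORLD_SIZE", 1))

@dataclass
class Config:
    device: str = "cuda"
    #distributed: bool = False   # the old stored flag, now superseded

    @property
    def distributed(self) -> bool:
        # Computed on access: True only under a multi-process launch.
        return world_size() > 1

cfg = Config()
print(cfg.distributed)  # False under plain `python`, True when WORLD_SIZE > 1
```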
@@ -447,20 +453,21 @@ try:
     cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
 
     cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
 
-    # cached_property stopped working...
-    if cfg.dataset.use_hdf5:
-        try:
-            cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a')
-        except Exception as e:
-            print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
-            cfg.dataset.use_hdf5 = False
-
-    if not cfg.dataset.use_hdf5:
-        cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
-        cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
 except Exception as e:
     pass
 
+# cached_property stopped working...
+if cfg.dataset.use_hdf5:
+    try:
+        cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a')
+    except Exception as e:
+        print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
+        cfg.dataset.use_hdf5 = False
+
+if not cfg.dataset.use_hdf5:
+    cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
+    cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
+
 if __name__ == "__main__":
     print(cfg)
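The point of this hunk is subtle: the HDF5 setup previously sat inside the broad `try/except Exception: pass` that guards the optional YAML overlay, so any failure while opening the dataset file was silently swallowed along with it. Moving the block below the `except` means only the inner, narrow handler sees HDF5 errors and can fall back visibly. A small self-contained illustration, with made-up helper names:

```python
def load_overrides():
    raise FileNotFoundError("no config overrides")  # stand-in: legitimately optional

def open_dataset():
    raise OSError("HDF5 file is locked")            # stand-in: should be reported

# Before: one blanket guard hides both failures.
try:
    load_overrides()
    dataset = open_dataset()
except Exception:
    dataset = None                                  # no hint of *what* failed

# After: only the optional step is blanket-guarded; dataset errors surface.
try:
    load_overrides()
except Exception:
    pass

try:
    dataset = open_dataset()
except Exception as e:
    print("Error while opening dataset:", e)        # visible, explicit fallback
    dataset = None
```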
@@ -4,31 +4,22 @@ import torch
 
 from .data import get_phone_symmap
 from .train import load_engines
-
-def load_models():
-    models = {}
-    engines = load_engines()
-    for name in engines:
-        model = engines[name].module.cpu()
-
-        model.phone_symmap = get_phone_symmap()
-
-        models[name] = model
-
-    return models
+from .config import cfg
 
 def main():
     parser = argparse.ArgumentParser("Save trained model to path.")
-    parser.add_argument("path")
+    #parser.add_argument("--yaml", type=Path, default=None)
     args = parser.parse_args()
 
-    models = load_models()
-    for name in models:
-        model = models[name]
-
-        outpath = f'{args.path}/{name}.pt'
+    engines = load_engines()
+    for name, engine in engines.items():
+        outpath = cfg.ckpt_dir / name / "fp32.pth"
         torch.save({
-            'module': model.state_dict()
+            "global_step": engine.global_step,
+            "micro_step": engine.micro_step,
+            'module': engine.module.to('cpu', dtype=torch.float32).state_dict(),
+            #'optimizer': engine.optimizer.state_dict(),
+            'symmap': get_phone_symmap(),
        }, outpath)
         print(f"Exported {name} to {outpath}")
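One detail worth calling out in the new exporter: `engine.module.to('cpu', dtype=torch.float32)` upcasts the potentially half-precision DeepSpeed-managed weights before saving, which is what makes the checkpoint loadable on machines without DeepSpeed or fp16 support. A stand-alone sketch of that conversion, using a toy module in place of the real engine:

```python
import torch
import torch.nn as nn

# Toy stand-in for engine.module; imagine it was trained in fp16.
model = nn.Linear(8, 8).to(dtype=torch.float16)

# Upcast to fp32 on CPU, then snapshot the weights, mirroring the exporter.
state = model.to('cpu', dtype=torch.float32).state_dict()
assert all(t.dtype == torch.float32 for t in state.values())

torch.save({'module': state}, "/tmp/fp32.pth")  # same 'module' key the loader expects
```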
@@ -19,6 +19,9 @@ class TTS():
 
         self.input_sample_rate = 24000
         self.output_sample_rate = 24000
 
+        if config:
+            cfg.load_yaml( config )
+
         if ar_ckpt and nar_ckpt:
             self.ar_ckpt = ar_ckpt
@@ -33,7 +36,7 @@ class TTS():
             self.nar = model.to(self.device, dtype=torch.float32)
             self.nar.load_state_dict(torch.load(self.nar_ckpt)['module'])
         else:
-            self.load_models( config )
+            self.load_models()
 
         self.symmap = get_phone_symmap()
         self.ar.eval()
@@ -41,10 +44,7 @@ class TTS():
 
         self.loading = False
 
-    def load_models( self, config_path ):
-        if config_path:
-            cfg.load_yaml( config_path )
-
+    def load_models( self ):
         print("Loading models...")
         models = load_models()
         print("Loaded models")
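Taken together, the tts.py hunks move YAML loading from `load_models()` into the constructor, so the config is applied exactly once regardless of which branch runs. A hedged usage sketch; the argument names are read off the diff, and the real signature may carry more parameters:

```python
from vall_e.tts import TTS  # assumed module path, per `python -m vall_e.*`

# Config-driven: __init__ now calls cfg.load_yaml(config), then load_models().
tts = TTS( config="./config/custom.yml" )

# Or point straight at exported checkpoints and skip load_models() entirely.
tts = TTS( ar_ckpt="./ckpt/ar/fp32.pth", nar_ckpt="./ckpt/nar/fp32.pth" )
```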
@@ -48,6 +48,9 @@ def local_rank():
 def global_rank():
     return int(os.getenv("RANK", 0))
 
+def world_size():
+    return int(os.getenv("WORLD_SIZE", 1))
+
 
 def is_local_leader():
     return local_rank() == 0