diff --git a/vall_e/config.py b/vall_e/config.py
index 320adbb..fe92beb 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -442,6 +442,11 @@ if cfg.dataset.use_hdf5:
 		cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
 	except Exception as e:
 		print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
+		cfg.dataset.use_hdf5 = False
+
+if not cfg.dataset.use_hdf5:
+	cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
+	cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
 
 if __name__ == "__main__":
 	print(cfg)
\ No newline at end of file
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
old mode 100644
new mode 100755
index 9a559b6..923b565
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -46,7 +46,7 @@ _logger = logging.getLogger(__name__)
 # A very naive engine implementation using barebones PyTorch
 class Engine():
 	def __init__(self, *args, **kwargs):
-		self.module = kwargs['model']
+		self.module = kwargs['model'].to(cfg.device)
 		self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
 		self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
 
@@ -83,16 +83,28 @@ class Engine():
 		return dispatch_attribute(self.module, *args, **kwargs)
 
 	def save_checkpoint(self, save_dir, tag ):
+		save_path = save_dir / tag / "state.pth"
+		save_path.parent.mkdir(parents=True, exist_ok=True)
 		torch.save({
 			"global_step": self.global_step,
 			"micro_step": self.micro_step,
 			"module": self.module.state_dict(),
 			"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
 			"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
-		}, save_dir / tag / "state.pth")
+		}, save_path)
 
-	def load_checkpoint(self, load_dir, tag, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
-		state = torch.load(load_dir / tag / "state.pth")
+	def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
+		if tag is None:
+			tag_path = load_dir / "latest"
+			if not tag_path.exists():
+				return
+			tag = open(tag_path).read()
+
+		load_path = load_dir / tag / "state.pth"
+		if not load_path.exists():
+			return
+
+		state = torch.load(load_path)
 		self.global_step = state['global_step']
 		self.micro_step = state['micro_step']
 		self.module.load_state_dict(state['module'])
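
For reference, a minimal sketch of the checkpoint round-trip these changes enable. The ckpt_dir, tag, model, and optimizer names below are illustrative stand-ins, not part of the patch, and the "latest" file is written by hand here only to exercise the new load_checkpoint(tag=None) path.

    from pathlib import Path
    import torch

    # Stand-in module/optimizer for illustration only.
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    ckpt_dir = Path("ckpt/ar")        # hypothetical checkpoint root
    tag = "global_step_100"           # hypothetical tag name

    # What the patched save_checkpoint now does: create the tag directory
    # before torch.save, so saving into a fresh directory no longer fails.
    save_path = ckpt_dir / tag / "state.pth"
    save_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({
        "global_step": 100,
        "micro_step": 100,
        "module": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": None,
    }, save_path)

    # A DeepSpeed-style "latest" file lets load_checkpoint resolve the tag
    # itself when called with tag=None; if it is missing, the patched loader
    # simply returns instead of raising.
    (ckpt_dir / "latest").write_text(tag)

    resolved = (ckpt_dir / "latest").read_text()
    state = torch.load(ckpt_dir / resolved / "state.pth")
    model.load_state_dict(state["module"])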