some fixes for the local framework

mrq 2023-08-05 02:17:30 +00:00
parent 012f54b7f1
commit 5970f254e3
2 changed files with 21 additions and 4 deletions


@@ -442,6 +442,11 @@ if cfg.dataset.use_hdf5:
         cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
     except Exception as e:
         print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
+        cfg.dataset.use_hdf5 = False
+
+if not cfg.dataset.use_hdf5:
+    cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
+    cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
 
 if __name__ == "__main__":
     print(cfg)

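The added lines give the config loader a graceful fallback: if the HDF5 dataset store cannot be opened, use_hdf5 is switched off and the plain directory lists are wrapped as pathlib.Path objects so the non-HDF5 loader can consume them. Below is a minimal, self-contained sketch of that fallback; SimpleNamespace and the resolve_dataset helper are illustrative stand-ins, not part of the repo.

from pathlib import Path
from types import SimpleNamespace

def resolve_dataset(dataset, cfg_path):
    if dataset.use_hdf5:
        try:
            import h5py
            dataset.hdf5 = h5py.File(f'{cfg_path}/{dataset.hdf5_name}', 'a')
        except Exception as e:
            # Any failure to open/create the HDF5 store (missing h5py, missing
            # directory, bad permissions, ...) disables HDF5 mode instead of
            # crashing the whole config load.
            print("Error while opening HDF5 file:", e)
            dataset.use_hdf5 = False

    if not dataset.use_hdf5:
        # Plain-directory mode: the config gives strings, presumably consumed
        # downstream as pathlib.Path objects, so convert them here.
        dataset.training = [Path(d) for d in dataset.training]
        dataset.validation = [Path(d) for d in dataset.validation]
    return dataset

# Nonexistent parent directory -> the open fails -> fallback kicks in.
dataset = resolve_dataset(SimpleNamespace(
    use_hdf5=True,
    hdf5_name="data.h5",
    training=["./data/train"],
    validation=["./data/val"],
), cfg_path="./no_such_dir")
assert not dataset.use_hdf5 and isinstance(dataset.training[0], Path)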
vall_e/engines/base.py (20 changed lines; Normal file → Executable file)

@@ -46,7 +46,7 @@ _logger = logging.getLogger(__name__)
 # A very naive engine implementation using barebones PyTorch
 class Engine():
     def __init__(self, *args, **kwargs):
-        self.module = kwargs['model']
+        self.module = kwargs['model'].to(cfg.device)
         self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
         self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
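The only change in this hunk moves the model onto cfg.device when the barebones (non-DeepSpeed) Engine is constructed, so its parameters live on the same device as the batches it will be fed. A hedged sketch of why that placement matters; LocalEngine and its constructor are illustrative, not the actual Engine API.

import torch

class LocalEngine:
    def __init__(self, model, device=None):
        # Resolve a device up front and move the module onto it. Without this,
        # parameters stay wherever the model was built (usually CPU) and the
        # first forward pass with GPU tensors fails with a device mismatch.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.module = model.to(self.device)

    def __call__(self, x):
        # Inputs are moved to the engine's device so callers don't have to care.
        return self.module(x.to(self.device))

engine = LocalEngine(torch.nn.Linear(4, 2))
out = engine(torch.randn(3, 4))   # works on CPU or GPU without extra .to() calls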
@@ -83,16 +83,28 @@ class Engine():
         return dispatch_attribute(self.module, *args, **kwargs)
 
     def save_checkpoint(self, save_dir, tag ):
+        save_path = save_dir / tag / "state.pth"
+        save_path.parent.mkdir(parents=True, exist_ok=True)
         torch.save({
             "global_step": self.global_step,
             "micro_step": self.micro_step,
             "module": self.module.state_dict(),
             "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
             "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
-        }, save_dir / tag / "state.pth")
+        }, save_path)
 
-    def load_checkpoint(self, load_dir, tag, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
-        state = torch.load(load_dir / tag / "state.pth")
+    def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
+        if tag is None:
+            tag_path = load_dir / "latest"
+            if not tag_path.exists():
+                return
+            tag = open(tag_path).read()
+
+        load_path = load_dir / tag / "state.pth"
+        if not load_path.exists():
+            return
+
+        state = torch.load(load_path)
         self.global_step = state['global_step']
         self.micro_step = state['micro_step']
         self.module.load_state_dict(state['module'])
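Taken together, save_checkpoint now creates the tag directory before writing (mkdir(parents=True, exist_ok=True)), and load_checkpoint can be called without a tag: it reads the tag name from a "latest" file inside the checkpoint directory (the same convention DeepSpeed uses) and returns quietly when no checkpoint exists yet, instead of raising. A standalone sketch of that resolution logic; the resolve_checkpoint name, the Optional return, and the .strip() are illustrative choices, not code from the commit.

from pathlib import Path
from typing import Optional
import torch

def resolve_checkpoint(load_dir: Path, tag: Optional[str] = None):
    """Return the loaded checkpoint dict, or None if there is nothing to load."""
    if tag is None:
        # A plain-text "latest" file names the most recent tag directory,
        # so callers can resume without knowing the exact step/tag.
        tag_path = load_dir / "latest"
        if not tag_path.exists():
            return None
        tag = tag_path.read_text().strip()

    load_path = load_dir / tag / "state.pth"
    if not load_path.exists():
        return None
    return torch.load(load_path)

state = resolve_checkpoint(Path("./ckpt"))
if state is not None:
    print("resuming from step", state["global_step"])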