some fixes for the local framework

commit 5970f254e3 (parent 012f54b7f1)
Author: mrq
Date: 2023-08-05 02:17:30 +00:00

2 changed files with 21 additions and 4 deletions


@@ -442,6 +442,11 @@ if cfg.dataset.use_hdf5:
 		cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
 	except Exception as e:
 		print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
 		cfg.dataset.use_hdf5 = False
+
+if not cfg.dataset.use_hdf5:
+	cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
+	cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
+
 if __name__ == "__main__":
 	print(cfg)
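
For context: with this hunk, a failed HDF5 open no longer aborts startup; the config clears use_hdf5 and falls back to treating the training/validation entries as plain directory paths. A minimal standalone sketch of the same pattern, assuming nothing about the rest of the config module (the load_dataset_config helper and its arguments are hypothetical):

from pathlib import Path

import h5py

def load_dataset_config(cfg_path, hdf5_name, training, validation):
	# Hypothetical helper: try the packed HDF5 dataset first,
	# and fall back to plain directory paths if it cannot be opened.
	use_hdf5, hdf5 = True, None
	try:
		hdf5 = h5py.File(f'{cfg_path}/{hdf5_name}', 'a')
	except Exception as e:
		print("Error while opening HDF5 file:", f'{cfg_path}/{hdf5_name}', str(e))
		use_hdf5 = False
	if not use_hdf5:
		# Same fallback as the hunk above: coerce entries to Path objects.
		training = [ Path(dir) for dir in training ]
		validation = [ Path(dir) for dir in validation ]
	return use_hdf5, hdf5, training, validation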

vall_e/engines/base.py (Normal file → Executable file)

@@ -46,7 +46,7 @@ _logger = logging.getLogger(__name__)
 # A very naive engine implementation using barebones PyTorch
 class Engine():
 	def __init__(self, *args, **kwargs):
-		self.module = kwargs['model']
+		self.module = kwargs['model'].to(cfg.device)
 
 		self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
 		self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
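
The one-line change above makes the naive engine move its model onto the configured device as soon as it wraps it, instead of leaving it wherever the caller constructed it. A minimal sketch of the idiom, with a stand-in device argument in place of the repo's cfg.device:

import torch

class NaiveEngine:
	# Stand-in for the barebones Engine above, not the repo's actual class.
	def __init__(self, model, optimizer=None, lr_scheduler=None, device=None):
		if device is None:
			device = 'cuda' if torch.cuda.is_available() else 'cpu'
		# Move the weights once, up front, so every later step sees a device-resident model.
		self.module = model.to(device)
		self.optimizer = optimizer
		self.lr_scheduler = lr_scheduler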
@@ -83,16 +83,28 @@ class Engine():
 		return dispatch_attribute(self.module, *args, **kwargs)
 
 	def save_checkpoint(self, save_dir, tag ):
+		save_path = save_dir / tag / "state.pth"
+		save_path.parent.mkdir(parents=True, exist_ok=True)
+
 		torch.save({
 			"global_step": self.global_step,
 			"micro_step": self.micro_step,
 			"module": self.module.state_dict(),
 			"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
 			"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
-		}, save_dir / tag / "state.pth")
+		}, save_path)
 
-	def load_checkpoint(self, load_dir, tag, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
-		state = torch.load(load_dir / tag / "state.pth")
+	def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
+		if tag is None:
+			tag_path = load_dir / "latest"
+			if not tag_path.exists():
+				return
+			tag = open(tag_path).read()
+
+		load_path = load_dir / tag / "state.pth"
+		if not load_path.exists():
+			return
+		state = torch.load(load_path)
 		self.global_step = state['global_step']
 		self.micro_step = state['micro_step']
 		self.module.load_state_dict(state['module'])
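
Net effect of the hunk above: save_checkpoint now creates the checkpoint directory before writing (a bare torch.save would fail on a fresh run), and load_checkpoint can resolve its tag from a DeepSpeed-style latest file, returning silently when no checkpoint exists instead of raising. A hedged usage sketch, assuming engine is an instance of the Engine class above; the checkpoint root is hypothetical, and writing the latest file is assumed to happen elsewhere, since that side is not shown in this diff:

from pathlib import Path

ckpt_dir = Path("ckpt/ar")  # hypothetical checkpoint root

# Save under an explicit tag; the tag directory is created as needed now.
engine.save_checkpoint(ckpt_dir, tag="100000")
# Assumption: some other code records the newest tag for later resumption.
(ckpt_dir / "latest").write_text("100000")

# Resume without naming a tag: load_checkpoint falls back to the `latest` file,
# and is a no-op if neither it nor state.pth exists.
engine.load_checkpoint(ckpt_dir)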