some fixes for the local framework
This commit is contained in:
parent
012f54b7f1
commit
5970f254e3
|
@ -442,6 +442,11 @@ if cfg.dataset.use_hdf5:
|
||||||
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
|
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
|
print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
|
||||||
|
cfg.dataset.use_hdf5 = False
|
||||||
|
|
||||||
|
if not cfg.dataset.use_hdf5:
|
||||||
|
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
|
||||||
|
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print(cfg)
|
print(cfg)
|
20
vall_e/engines/base.py
Normal file → Executable file
20
vall_e/engines/base.py
Normal file → Executable file
|
@ -46,7 +46,7 @@ _logger = logging.getLogger(__name__)
|
||||||
# A very naive engine implementation using barebones PyTorch
|
# A very naive engine implementation using barebones PyTorch
|
||||||
class Engine():
|
class Engine():
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.module = kwargs['model']
|
self.module = kwargs['model'].to(cfg.device)
|
||||||
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
|
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
|
||||||
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
|
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
|
||||||
|
|
||||||
|
@ -83,16 +83,28 @@ class Engine():
|
||||||
return dispatch_attribute(self.module, *args, **kwargs)
|
return dispatch_attribute(self.module, *args, **kwargs)
|
||||||
|
|
||||||
def save_checkpoint(self, save_dir, tag ):
|
def save_checkpoint(self, save_dir, tag ):
|
||||||
|
save_path = save_dir / tag / "state.pth"
|
||||||
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
torch.save({
|
torch.save({
|
||||||
"global_step": self.global_step,
|
"global_step": self.global_step,
|
||||||
"micro_step": self.micro_step,
|
"micro_step": self.micro_step,
|
||||||
"module": self.module.state_dict(),
|
"module": self.module.state_dict(),
|
||||||
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
|
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
|
||||||
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
|
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
|
||||||
}, save_dir / tag / "state.pth")
|
}, save_path)
|
||||||
|
|
||||||
def load_checkpoint(self, load_dir, tag, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
|
def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
|
||||||
state = torch.load(load_dir / tag / "state.pth")
|
if tag is None:
|
||||||
|
tag_path = load_dir / "latest"
|
||||||
|
if not tag_path.exists():
|
||||||
|
return
|
||||||
|
tag = open(tag_path).read()
|
||||||
|
|
||||||
|
load_path = load_dir / tag / "state.pth"
|
||||||
|
if not load_path.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
state = torch.load(load_path)
|
||||||
self.global_step = state['global_step']
|
self.global_step = state['global_step']
|
||||||
self.micro_step = state['micro_step']
|
self.micro_step = state['micro_step']
|
||||||
self.module.load_state_dict(state['module'])
|
self.module.load_state_dict(state['module'])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user