some fixes for the local framework
commit 5970f254e3
parent 012f54b7f1
@@ -442,6 +442,11 @@ if cfg.dataset.use_hdf5:
 		cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
+	except Exception as e:
+		print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
+		cfg.dataset.use_hdf5 = False
+
+if not cfg.dataset.use_hdf5:
 	cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
 	cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
 
 if __name__ == "__main__":
 	print(cfg)
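In the hunk above the configuration code gains a graceful fallback: if the HDF5 archive cannot be opened, use_hdf5 is switched off and the dataset lists are treated as plain paths. A standalone sketch of the same pattern, with placeholder paths ("data.h5", "./data/...") instead of the repo's cfg values:

from pathlib import Path
import h5py

use_hdf5 = True
try:
    # open (or create) the archive in append mode, as the config code does
    hdf5 = h5py.File("data.h5", "a")  # "data.h5" is a placeholder, not a repo path
except Exception as e:
    print("Error while opening HDF5 file:", "data.h5", str(e))
    use_hdf5 = False

if not use_hdf5:
    # fall back to treating the dataset entries as plain directory paths
    training = [Path(d) for d in ["./data/training"]]
    validation = [Path(d) for d in ["./data/validation"]]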
vall_e/engines/base.py (20 lines changed; Normal file → Executable file)
@@ -46,7 +46,7 @@ _logger = logging.getLogger(__name__)
 # A very naive engine implementation using barebones PyTorch
 class Engine():
 	def __init__(self, *args, **kwargs):
-		self.module = kwargs['model']
+		self.module = kwargs['model'].to(cfg.device)
 		self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
 		self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
 
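The constructor change above moves the wrapped model onto cfg.device as soon as the engine is built. A minimal sketch of that behaviour, assuming a plain torch module and a local device variable standing in for cfg.device:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"  # stand-in for cfg.device
model = torch.nn.Linear(4, 4)
model = model.to(device)  # equivalent of the changed line in __init__
print(next(model.parameters()).device)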
@@ -83,16 +83,28 @@ class Engine():
 		return dispatch_attribute(self.module, *args, **kwargs)
 
 	def save_checkpoint(self, save_dir, tag ):
+		save_path = save_dir / tag / "state.pth"
+		save_path.parent.mkdir(parents=True, exist_ok=True)
 		torch.save({
 			"global_step": self.global_step,
 			"micro_step": self.micro_step,
 			"module": self.module.state_dict(),
 			"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
 			"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
-		}, save_dir / tag / "state.pth")
+		}, save_path)
 
-	def load_checkpoint(self, load_dir, tag, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
-		state = torch.load(load_dir / tag / "state.pth")
+	def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
+		if tag is None:
+			tag_path = load_dir / "latest"
+			if not tag_path.exists():
+				return
+			tag = open(tag_path).read()
+
+		load_path = load_dir / tag / "state.pth"
+		if not load_path.exists():
+			return
+
+		state = torch.load(load_path)
 		self.global_step = state['global_step']
 		self.micro_step = state['micro_step']
 		self.module.load_state_dict(state['module'])
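In this hunk save_checkpoint now creates save_dir/<tag>/ before writing, and load_checkpoint resolves a missing tag from a "latest" file and returns quietly when no checkpoint exists yet. A standalone sketch of both behaviours; save_state and load_state are hypothetical helper names, not part of the repo, and the toy state dict stands in for the real module/optimizer states:

import torch
from pathlib import Path

def save_state(save_dir: Path, tag: str, state: dict):
    # create save_dir/<tag>/ before writing, as save_checkpoint now does
    save_path = save_dir / tag / "state.pth"
    save_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(state, save_path)

def load_state(load_dir: Path, tag=None):
    # mirror of the new lookup: fall back to the "latest" tag file and
    # bail out quietly when no checkpoint has been written yet
    if tag is None:
        tag_path = load_dir / "latest"
        if not tag_path.exists():
            return None
        tag = open(tag_path).read()
    load_path = load_dir / tag / "state.pth"
    if not load_path.exists():
        return None
    return torch.load(load_path)

ckpt = Path("./ckpt")
save_state(ckpt, "step_100", {"global_step": 100})
(ckpt / "latest").write_text("step_100")  # assumption: the caller records the newest tag
print(load_state(ckpt))                   # tag=None -> resolved via ckpt/"latest"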