sanity cleanup, backup config yaml for each log file
Commit 132a02c48b (parent 8d92dac829)
@@ -23,31 +23,35 @@ from transformers import PreTrainedTokenizerFast

 @dataclass()
 class BaseConfig:
-	cfg_path: str | None = None
+	yaml_path: str | None = None

 	@property
-	def relpath(self):
+	def cfg_path(self):
+		return Path(self.yaml_path.parent) if self.yaml_path is not None else None
+
+	@property
+	def rel_path(self):
 		return Path(self.cfg_path)

 	@property
 	def cache_dir(self):
-		return self.relpath / ".cache"
+		return self.rel_path / ".cache"

 	@property
 	def data_dir(self):
-		return self.relpath / "data"
+		return self.rel_path / "data"

 	@property
 	def metadata_dir(self):
-		return self.relpath / "metadata"
+		return self.rel_path / "metadata"

 	@property
 	def ckpt_dir(self):
-		return self.relpath / "ckpt"
+		return self.rel_path / "ckpt"

 	@property
 	def log_dir(self):
-		return self.relpath / "logs" / str(self.start_time)
+		return self.rel_path / "logs" / str(self.start_time)

 	@cached_property
 	def start_time(self):
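In effect, yaml_path now stores the raw path of the loaded YAML, cfg_path becomes the directory containing it, and rel_path is the base that every derived directory hangs off of. A minimal standalone sketch of that chain (simplified from the class above; the example path and the reduced log_dir are illustrative only, the real config appends a per-run start_time):

from dataclasses import dataclass
from pathlib import Path

@dataclass()
class BaseConfig:
	yaml_path: Path | None = None  # path of the loaded config YAML, if any

	@property
	def cfg_path(self):
		# directory that contains the YAML (None when no YAML was loaded)
		return Path(self.yaml_path.parent) if self.yaml_path is not None else None

	@property
	def rel_path(self):
		# base directory that cache_dir, data_dir, ckpt_dir, log_dir, etc. resolve against
		return Path(self.cfg_path)

	@property
	def log_dir(self):
		# simplified: the real property appends str(self.start_time) under "logs"
		return self.rel_path / "logs"

cfg = BaseConfig(yaml_path=Path("./training/config.yaml"))
print(cfg.cfg_path)  # training
print(cfg.rel_path)  # training
print(cfg.log_dir)   # training/logs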
@@ -98,9 +102,9 @@ class BaseConfig:

 		state = {}
 		if args.yaml:
-			cfg_path = args.yaml
-			state = yaml.safe_load(open(cfg_path, "r", encoding="utf-8"))
-			state.setdefault("cfg_path", cfg_path.parent)
+			yaml_path = args.yaml
+			state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
+			state.setdefault("yaml_path", yaml_path)

 		return cls(**state)

@@ -376,10 +380,10 @@ class DeepSpeed:
 			autotune_params['enabled'] = True

 			if "results_dir" not in autotune_params:
-				autotune_params['results_dir'] = str( cfg.relpath / "autotune" / "results" )
+				autotune_params['results_dir'] = str( cfg.rel_path / "autotune" / "results" )

 			if "exps_dir" not in autotune_params:
-				autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
+				autotune_params['exps_dir'] = str( cfg.rel_path / "autotune" / "exps_" )

 		# DeepSpeed fp16 is incompatible with its AMP
 		if cfg.trainer.weight_dtype.lower() == "float16":
@@ -653,7 +657,7 @@ class Config(BaseConfig):

 	@cached_property
 	def diskcache(self):
-		if self.cfg_path is not None and self.dataset.cache:
+		if self.yaml_path is not None and self.dataset.cache:
 			return diskcache.Cache(self.cache_dir).memoize
 		return lambda: lambda x: x

@@ -669,9 +673,9 @@ class Config(BaseConfig):
 		if self.distributed:
 			self.dataset.hdf5_flag = "r"
 		try:
-			self.hdf5 = h5py.File(f'{self.relpath}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
+			self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
 		except Exception as e:
-			print("Error while opening HDF5 file:", f'{self.relpath}/{self.dataset.hdf5_name}', str(e))
+			print("Error while opening HDF5 file:", f'{self.rel_path}/{self.dataset.hdf5_name}', str(e))
 			self.dataset.use_hdf5 = False

 	def format( self ):
@@ -790,7 +794,7 @@ except Exception as e:

 try:
 	from transformers import PreTrainedTokenizerFast
-	cfg.tokenizer = (cfg.relpath if cfg.cfg_path is not None else Path("./data/")) / cfg.tokenizer
+	cfg.tokenizer = (cfg.rel_path if cfg.yaml_path is not None else Path("./data/")) / cfg.tokenizer
 	cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(cfg.tokenizer))
 except Exception as e:
 	cfg.tokenizer = NaiveTokenizer()
@@ -965,7 +965,7 @@ def create_datasets():
 	train_dataset = Dataset( training=True )
 	val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )

-	train_state_path = cfg.relpath / f"sampler.rank{global_rank()}.pt"
+	train_state_path = cfg.rel_path / f"sampler.rank{global_rank()}.pt"
 	if train_state_path.exists():
 		train_dataset.load_state_dict( train_state_path )

@@ -1286,7 +1286,7 @@ if __name__ == "__main__":
 			for i in range(len(v)):
 				print(f'{k}[{i}]:', v[i])

-		#train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
+		#train_dl.dataset.save_state_dict(cfg.rel_path / "train_dataset.pt")

 	elif args.action == "tasks":
 		index = 0
@@ -186,7 +186,7 @@ def load_engines(training=True):

 		# copy embeddings if requested
 		if cfg.model._embeddings is not None:
-			embeddings_path = cfg.relpath / cfg.model._embeddings
+			embeddings_path = cfg.rel_path / cfg.model._embeddings

 			if embeddings_path.exists():
 				embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
@@ -432,9 +432,9 @@ def example_usage():
 	model = AR_NAR(**kwargs).to(device)
 	steps = 200

-	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
-	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
-	learning_rate = cfg.hyperparameters.learning_rate if cfg.cfg_path is not None else None
+	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
+	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
+	learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None

 	if cfg.optimizations.dadaptation:
 		# do not combine the two
@@ -257,9 +257,9 @@ def example_usage():
 	elif cfg.model.interleave:
 		steps = 250

-	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
-	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
-	learning_rate = cfg.hyperparameters.learning_rate if cfg.cfg_path is not None else None
+	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
+	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
+	learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None

 	if cfg.optimizations.dadaptation:
 		# do not combine the two
@@ -357,9 +357,9 @@ def example_usage():
 	model = NAR(**kwargs).to(device)
 	steps = 200

-	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
-	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
-	learning_rate = cfg.hyperparameters.learning_rate if cfg.cfg_path is not None else None
+	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
+	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
+	learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None

 	if cfg.optimizations.dadaptation:
 		# do not combine the two
@@ -106,7 +106,7 @@ if __name__ == "__main__":
 	parser.add_argument("--group-level", default=1)
 	args = parser.parse_args()

-	path = cfg.relpath / "logs"
+	path = cfg.rel_path / "logs"
 	paths = path.rglob(f"./*/{args.filename}")

 	args.models = [ model for model in cfg.model.get() if model.training and (args.model == "*" or model.name in args.model) ]
@@ -116,5 +116,5 @@ if __name__ == "__main__":

 	plot(paths, args)

-	out_path = cfg.relpath / "metrics.png"
+	out_path = cfg.rel_path / "metrics.png"
 	plt.savefig(out_path, bbox_inches="tight")
@@ -14,6 +14,7 @@ import random
 import torch
 import torch.nn.functional as F
 import traceback
+import shutil

 from collections import defaultdict

@@ -201,7 +202,11 @@ def train():
 	parser.add_argument("--eval", action="store_true", default=None)
 	args, unknown = parser.parse_known_args()

+	# create log folder
 	setup_logging(cfg.log_dir)
+	# copy config yaml to backup
+	if cfg.yaml_path is not None:
+		shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )

 	train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

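This is the change the commit title refers to: every training run copies the YAML it was launched with into that run's log directory, so the logs can always be matched to the exact config that produced them. A hedged standalone sketch of the behavior (the backup_config helper and the paths are illustrative, not part of the repository):

import shutil
from pathlib import Path

def backup_config(yaml_path: Path | None, log_dir: Path) -> None:
	# mirrors the added lines above: only back up when a YAML actually drove the run
	log_dir.mkdir(parents=True, exist_ok=True)
	if yaml_path is not None:
		shutil.copy(yaml_path, log_dir / "config.yaml")

# illustrative usage: afterwards, <rel_path>/logs/<start_time>/config.yaml records the run's config
cfg_yaml = Path("./training/config.yaml")
run_log_dir = Path("./training/logs/1700000000")
if cfg_yaml.exists():
	backup_config(cfg_yaml, run_log_dir)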
@@ -218,7 +218,7 @@ def train(
 			print("Failed to set LR rate to:", rate, str(e))

 		if "export" in command:
-			train_dl.dataset.save_state_dict(cfg.relpath / f"sampler.rank{global_rank()}.pt")
+			train_dl.dataset.save_state_dict(cfg.rel_path / f"sampler.rank{global_rank()}.pt")
 			engines.save_checkpoint()
 			last_save_step = engines.global_step

@@ -241,7 +241,7 @@ def train(

 		if engines.global_step != last_save_step:
			if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
-				train_dl.dataset.save_state_dict(cfg.relpath / f"sampler.rank{global_rank()}.pt")
+				train_dl.dataset.save_state_dict(cfg.rel_path / f"sampler.rank{global_rank()}.pt")
				engines.save_checkpoint()
				last_save_step = engines.global_step