mrq 2024-06-06 21:57:11 -05:00
parent f9f309281a
commit 4ade2b60ee
5 changed files with 5 additions and 8 deletions


@@ -429,7 +429,7 @@ class Dataset(_Dataset):
 		self.training = training
 		self.dataset_type = "training" if self.training else "validation"
 		self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
-		self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "group"
+		self.sampler_type = cfg.dataset.sample_type # if self.dataset_type == "training" else "group"
 
 		# to-do: do not do validation if there's nothing in the validation
 		# this just makes it be happy
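For reference, a tiny standalone sketch of the line changed above; the cfg object and the "path" value here are stand-ins, not the project's real config. Both splits now take their sampler type from cfg.dataset.sample_type, where the validation split used to be forced onto the "group" sampler.

from types import SimpleNamespace

# stand-in for the project's config object; "path" is only an example value
cfg = SimpleNamespace(dataset=SimpleNamespace(sample_type="path"))

def pick_sampler_type(dataset_type: str) -> str:
	# old behaviour: cfg.dataset.sample_type if dataset_type == "training" else "group"
	# new behaviour: always honour the configured sample_type
	return cfg.dataset.sample_type

print(pick_sampler_type("training"), pick_sampler_type("validation"))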


@@ -86,7 +86,7 @@ class Engine():
 		self._frozen_params.clear()
 
 	@property
-	def training(self):
+	def _training(self):
 		if not hasattr(self, "hyper_config"):
 			return True
 		return self.hyper_config.training
@@ -321,7 +321,7 @@ class Engines(dict[str, Engine]):
 		cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
 
 		for name, engine in self.items():
-			if not engine.training:
+			if not engine._training:
 				continue
 
 			save_dir = cfg.ckpt_dir / name
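The two hunks above go together: the property is renamed to _training and its call site in the checkpoint loop follows suit. Below is a minimal runnable sketch of that pattern; Engines is reduced to a plain dict subclass, and HyperConfig plus save_checkpoints are hypothetical stand-ins for the real classes and method names.

class Engine:
	def __init__(self, hyper_config=None):
		# the config may be attached after construction, so the property tolerates its absence
		if hyper_config is not None:
			self.hyper_config = hyper_config

	@property
	def _training(self):
		# mirrors the hunk above: default to True until a hyper_config is attached
		if not hasattr(self, "hyper_config"):
			return True
		return self.hyper_config.training

class HyperConfig:
	def __init__(self, training):
		self.training = training

class Engines(dict):
	def save_checkpoints(self):
		# engines whose config says they are not training are skipped, as in the hunk above
		for name, engine in self.items():
			if not engine._training:
				continue
			print(f"would checkpoint {name!r}")

engines = Engines(frozen=Engine(HyperConfig(training=False)), active=Engine(HyperConfig(training=True)))
engines.save_checkpoints()  # prints only for 'active'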


@@ -77,7 +77,7 @@ class Engine(DeepSpeedEngine):
 		self._frozen_params.clear()
 
 	@property
-	def training(self):
+	def _training(self):
 		return self.hyper_config.training
 
 	@property


@@ -57,7 +57,7 @@ def train_feeder(engine, batch):
 	else:
 		engine(
 			text_list=batch["text"],
-			proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
+			proms_list=batch["proms"],
 			resps_list=batch["resps"],
 			lang_list=batch["lang"],
 		)
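A small sketch of what the dropped slice used to do, with made-up shapes: each prompt is treated as a [timesteps, levels] tensor of codes, and prom_levels stands in for the engine._cfg.prom_levels value the old line read. The new code simply forwards the full prompt list.

import torch

prom_levels = 2  # illustrative target level count, not the project's configured value
batch = {"proms": [torch.randint(0, 1024, (75, 8)) for _ in range(4)]}

# old behaviour: keep only the first `prom_levels` quantizer levels of each prompt
trimmed = [prom[:, :prom_levels] for prom in batch["proms"]]
# new behaviour: forward the prompts untouched
full = batch["proms"]

print(trimmed[0].shape, full[0].shape)  # torch.Size([75, 2]) torch.Size([75, 8])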


@@ -157,9 +157,6 @@ def train(
 	# Training loop
 	for batch in _make_infinite_epochs(train_dl):
-		if not engine.training:
-			continue
-
 		if engines.global_step >= cfg.trainer.iterations:
 			break
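With the per-batch engine.training guard removed, the loop above boils down to pulling batches from an endless epoch generator until the configured iteration count is reached. A toy runnable version, with the data loader, step counter, and config all stubbed out:

# stand-ins for the trainer's real objects (DataLoader, engines, cfg)
train_dl = [f"batch_{i}" for i in range(3)]
max_iterations = 7   # plays the role of cfg.trainer.iterations
global_step = 0      # plays the role of engines.global_step

def _make_infinite_epochs(dl):
	# loop over the "data loader" forever, as the trainer's helper of the same name does
	while True:
		yield from dl

for batch in _make_infinite_epochs(train_dl):
	if global_step >= max_iterations:
		break
	global_step += 1  # the real counter advances when the engines step
	print(global_step, batch)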