commit 4ade2b60ee
parent f9f309281a

    ugh
@@ -429,7 +429,7 @@ class Dataset(_Dataset):
         self.training = training
         self.dataset_type = "training" if self.training else "validation"
         self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
-        self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "group"
+        self.sampler_type = cfg.dataset.sample_type # if self.dataset_type == "training" else "group"
 
         # to-do: do not do validation if there's nothing in the validation
         # this just makes it be happy
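Note on the hunk above: with the conditional commented out, the validation split no longer falls back to the "group" sampler; both splits now take whatever cfg.dataset.sample_type is set to. A minimal sketch of the before/after selection logic, using a stand-in config stub (the SimpleNamespace and the "path" value are illustrative, not the project's actual config):

from types import SimpleNamespace

# stand-in for cfg.dataset; only the one field used here is stubbed out
dataset_cfg = SimpleNamespace(sample_type="path")

def pick_sampler_type(dataset_type: str, old_behavior: bool) -> str:
    if old_behavior:
        # before this commit: validation was always forced onto the "group" sampler
        return dataset_cfg.sample_type if dataset_type == "training" else "group"
    # after this commit: both splits honor cfg.dataset.sample_type
    return dataset_cfg.sample_type

print(pick_sampler_type("validation", old_behavior=True))   # -> "group"
print(pick_sampler_type("validation", old_behavior=False))  # -> "path"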
@@ -86,7 +86,7 @@ class Engine():
         self._frozen_params.clear()
 
     @property
-    def training(self):
+    def _training(self):
         if not hasattr(self, "hyper_config"):
             return True
         return self.hyper_config.training
@@ -321,7 +321,7 @@ class Engines(dict[str, Engine]):
 
         cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
         for name, engine in self.items():
-            if not engine.training:
+            if not engine._training:
                 continue
 
             save_dir = cfg.ckpt_dir / name
@@ -77,7 +77,7 @@ class Engine(DeepSpeedEngine):
         self._frozen_params.clear()
 
     @property
-    def training(self):
+    def _training(self):
         return self.hyper_config.training
 
     @property
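A hedged guess at why the property is renamed in both engine backends: DeepSpeedEngine is ultimately a torch.nn.Module, and nn.Module already owns a `training` flag that train()/eval() toggle and that __init__ assigns to. A getter-only property with the same name shadows that attribute and breaks construction. A minimal sketch of the collision, independent of this codebase:

import torch.nn as nn

class BadEngine(nn.Module):
    # nn.Module.__init__ assigns `self.training = True`, but a getter-only
    # property is a data descriptor, so that assignment raises AttributeError
    @property
    def training(self):
        return True

try:
    BadEngine()
except AttributeError as exc:
    print("collision:", exc)

class GoodEngine(nn.Module):
    # renamed: the built-in nn.Module.training flag stays untouched
    @property
    def _training(self):
        return True

engine = GoodEngine()
engine.eval()                               # toggles the real flag as usual
print(engine.training, engine._training)    # False True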
@@ -57,7 +57,7 @@ def train_feeder(engine, batch):
     else:
         engine(
             text_list=batch["text"],
-            proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
+            proms_list=batch["proms"],
             resps_list=batch["resps"],
             lang_list=batch["lang"],
         )
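For reference, the removed list comprehension trimmed each prompt's code tensor down to the first engine._cfg.prom_levels quantizer levels before feeding it to the model; the new line forwards the batch's prompts untouched. A small sketch of what the old slicing did, with shapes inferred from the expression prom[:, :n] (the concrete sizes below are made up):

import torch

# each prompt appears to be a [timesteps, quantizer_levels] tensor of codes,
# judging by the removed slice `prom[:, :engine._cfg.prom_levels]`
prom = torch.randint(0, 1024, (120, 8))   # hypothetical 8-level prompt
prom_levels = 2                           # stand-in for engine._cfg.prom_levels

trimmed = prom[:, :prom_levels]           # old behavior: keep only the first N levels
print(prom.shape, trimmed.shape)          # torch.Size([120, 8]) torch.Size([120, 2])
# after this commit, any level reduction has to happen upstream of train_feeder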
@@ -157,9 +157,6 @@ def train(
 
     # Training loop
     for batch in _make_infinite_epochs(train_dl):
-        if not engine.training:
-            continue
-
         if engines.global_step >= cfg.trainer.iterations:
             break
 
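The removed guard skipped the loop body whenever engine.training was false, so with it gone the loop terminates only on the iteration budget. _make_infinite_epochs is the project's own helper; a common shape for such a generator (purely an illustrative stand-in, not the repo's implementation) is:

from itertools import count

def make_infinite_epochs(dl):
    # illustrative stand-in for `_make_infinite_epochs`: restart the
    # DataLoader each time it is exhausted and never stop on its own
    for _epoch in count():
        yield from dl

# the training loop above then relies solely on the iteration cap:
#   for batch in make_infinite_epochs(train_dl):
#       if engines.global_step >= cfg.trainer.iterations:
#           break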