another classic commit so i can copy it to another machine to gut out things and use the trainer bits for a side project that I should really get around to working on sooner than later

This commit is contained in:
mrq 2023-08-04 14:21:30 -05:00
parent 0a524f1d59
commit 012f54b7f1
2 changed files with 2 additions and 3 deletions

View File

@ -66,7 +66,6 @@ class Engine(DeepSpeedEngine):
def set_lr(self, lr):
try:
if hasattr(self.optimizer, 'param_groups'):
print(self.optimizer.param_groups)
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
else:

View File

@ -100,9 +100,9 @@ def run_eval(engines, eval_name, dl):
if AR is not None and NAR is not None:
name = "+".join(names)
resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.temperature)
resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.temperature)
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
process( name, batch, resps_list )
else: