another classic commit so i can copy it to another machine to gut out things and use the trainer bits for a side project that I should really get around to working on sooner than later
This commit is contained in:
parent
0a524f1d59
commit
012f54b7f1
|
@ -66,7 +66,6 @@ class Engine(DeepSpeedEngine):
|
|||
def set_lr(self, lr):
|
||||
try:
|
||||
if hasattr(self.optimizer, 'param_groups'):
|
||||
print(self.optimizer.param_groups)
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
else:
|
||||
|
|
|
@ -100,9 +100,9 @@ def run_eval(engines, eval_name, dl):
|
|||
if AR is not None and NAR is not None:
|
||||
name = "+".join(names)
|
||||
|
||||
resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.temperature)
|
||||
resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
|
||||
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.temperature)
|
||||
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
|
||||
|
||||
process( name, batch, resps_list )
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue
Block a user