An amazing commit (:

This commit is contained in:
mrq 2023-08-05 03:44:38 +00:00
parent bcdbc71b58
commit 62cba62bbb
2 changed files with 3 additions and 3 deletions

View File

@ -16,7 +16,7 @@ This is a simple ResNet based image classifier for """specific images""", using
## Inferencing ## Inferencing
To be implemented. Simply invoke the inferencer with the following command: `python3 -m captcha "./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
## Caveats ## Caveats

View File

@ -57,7 +57,7 @@ def run_eval(engines, eval_name, dl):
batch: dict = to_device(batch, cfg.device) batch: dict = to_device(batch, cfg.device)
# if we're training both models, provide output for both # if we're training both models, provide output for both
res = model( image=batch['image'], text=batch['text'], temperature=cfg.evaluation.temperature ) res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
for path, ref, hyp in zip(batch["path"], batch["text"], res): for path, ref, hyp in zip(batch["path"], batch["text"], res):
hyp = hyp.replace('<s>', "").replace("</s>", "") hyp = hyp.replace('<s>', "").replace("</s>", "")
@ -68,7 +68,7 @@ def run_eval(engines, eval_name, dl):
image.save(hyp_path) image.save(hyp_path)
losses = engine.gather_attribute("loss") losses = engine.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum() loss = torch.stack([*losses.values()]).sum().item()
stats['loss'].append(loss) stats['loss'].append(loss)