An amazing commit (:
This commit is contained in:
parent
bcdbc71b58
commit
62cba62bbb
|
@ -16,7 +16,7 @@ This is a simple ResNet based image classifier for """specific images""", using
|
||||||
|
|
||||||
## Inferencing
|
## Inferencing
|
||||||
|
|
||||||
To be implemented.
|
Simply invoke the inferencer with the following command: `python3 -m captcha "./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
|
||||||
|
|
||||||
## Caveats
|
## Caveats
|
||||||
|
|
||||||
|
|
|
@ -57,7 +57,7 @@ def run_eval(engines, eval_name, dl):
|
||||||
batch: dict = to_device(batch, cfg.device)
|
batch: dict = to_device(batch, cfg.device)
|
||||||
|
|
||||||
# if we're training both models, provide output for both
|
# if we're training both models, provide output for both
|
||||||
res = model( image=batch['image'], text=batch['text'], temperature=cfg.evaluation.temperature )
|
res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
|
||||||
|
|
||||||
for path, ref, hyp in zip(batch["path"], batch["text"], res):
|
for path, ref, hyp in zip(batch["path"], batch["text"], res):
|
||||||
hyp = hyp.replace('<s>', "").replace("</s>", "")
|
hyp = hyp.replace('<s>', "").replace("</s>", "")
|
||||||
|
@ -68,7 +68,7 @@ def run_eval(engines, eval_name, dl):
|
||||||
image.save(hyp_path)
|
image.save(hyp_path)
|
||||||
|
|
||||||
losses = engine.gather_attribute("loss")
|
losses = engine.gather_attribute("loss")
|
||||||
loss = torch.stack([*losses.values()]).sum()
|
loss = torch.stack([*losses.values()]).sum().item()
|
||||||
|
|
||||||
stats['loss'].append(loss)
|
stats['loss'].append(loss)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user