An amazing commit (:
This commit is contained in:
parent
bcdbc71b58
commit
62cba62bbb
|
@ -16,7 +16,7 @@ This is a simple ResNet based image classifier for """specific images""", using
|
|||
|
||||
## Inferencing
|
||||
|
||||
To be implemented.
|
||||
Simply invoke the inferencer with the following command: `python3 -m captcha "./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
|
||||
|
||||
## Caveats
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ def run_eval(engines, eval_name, dl):
|
|||
batch: dict = to_device(batch, cfg.device)
|
||||
|
||||
# if we're training both models, provide output for both
|
||||
res = model( image=batch['image'], text=batch['text'], temperature=cfg.evaluation.temperature )
|
||||
res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
|
||||
|
||||
for path, ref, hyp in zip(batch["path"], batch["text"], res):
|
||||
hyp = hyp.replace('<s>', "").replace("</s>", "")
|
||||
|
@ -68,7 +68,7 @@ def run_eval(engines, eval_name, dl):
|
|||
image.save(hyp_path)
|
||||
|
||||
losses = engine.gather_attribute("loss")
|
||||
loss = torch.stack([*losses.values()]).sum()
|
||||
loss = torch.stack([*losses.values()]).sum().item()
|
||||
|
||||
stats['loss'].append(loss)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user