diff --git a/README.md b/README.md index 3f02770..e6f5896 100755 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ This is a simple ResNet based image classifier for """specific images""", using ## Inferencing -To be implemented. +Simply invoke the inferencer with the following command: `python3 -m captcha "./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0` ## Caveats diff --git a/captcha/train.py b/captcha/train.py index e9ce99d..1a02779 100755 --- a/captcha/train.py +++ b/captcha/train.py @@ -57,7 +57,7 @@ def run_eval(engines, eval_name, dl): batch: dict = to_device(batch, cfg.device) # if we're training both models, provide output for both - res = model( image=batch['image'], text=batch['text'], temperature=cfg.evaluation.temperature ) + res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature ) for path, ref, hyp in zip(batch["path"], batch["text"], res): hyp = hyp.replace('', "").replace("", "") @@ -68,7 +68,7 @@ def run_eval(engines, eval_name, dl): image.save(hyp_path) losses = engine.gather_attribute("loss") - loss = torch.stack([*losses.values()]).sum() + loss = torch.stack([*losses.values()]).sum().item() stats['loss'].append(loss)