From 466b9fbcaa23c3e006bd2d2f2bf17946368323e1 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 29 Oct 2021 20:22:40 -0600 Subject: [PATCH] classify --- codes/scripts/classify_into_folders.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/codes/scripts/classify_into_folders.py b/codes/scripts/classify_into_folders.py index db637078..8f35c689 100644 --- a/codes/scripts/classify_into_folders.py +++ b/codes/scripts/classify_into_folders.py @@ -51,9 +51,10 @@ if __name__ == "__main__": s.losses = {} output_key = opt['eval']['classifier_logits_key'] - output_base_dir = opt['eval']['output_dir'] labels = opt['eval']['output_labels'] path_key = opt['eval']['path_key'] + output_base_dir = util.opt_get(opt, ['eval', 'output_dir'], None) + output_file = open('classify_into_folders.tsv', 'a') step = 0 for test_set_name, test_loader in test_loaders: @@ -68,9 +69,14 @@ if __name__ == "__main__": lbls = torch.nn.functional.softmax(model.eval_state[output_key][0].cpu(), dim=-1) for k, lbl in enumerate(lbls): - lbl = torch.argmax(lbl, dim=0) + lbl = labels[torch.argmax(lbl, dim=0)] src_path = data[path_key][k] - dest = os.path.join(output_base_dir, labels[lbl]) - os.makedirs(dest, exist_ok=True) - shutil.copy(str(src_path), os.path.join(dest, f'{step}_{os.path.basename(str(src_path))}')) - step += 1 + output_file.write(f'{src_path}\t{lbl}') + if output_base_dir is not None: + dest = os.path.join(output_base_dir, lbl) + os.makedirs(dest, exist_ok=True) + shutil.copy(str(src_path), os.path.join(dest, f'{step}_{os.path.basename(str(src_path))}')) + step += 1 + output_file.flush() + output_file.close() +