classify
This commit is contained in:
parent
92fe8b4dd9
commit
466b9fbcaa
|
@ -51,9 +51,10 @@ if __name__ == "__main__":
|
|||
s.losses = {}
|
||||
|
||||
output_key = opt['eval']['classifier_logits_key']
|
||||
output_base_dir = opt['eval']['output_dir']
|
||||
labels = opt['eval']['output_labels']
|
||||
path_key = opt['eval']['path_key']
|
||||
output_base_dir = util.opt_get(opt, ['eval', 'output_dir'], None)
|
||||
output_file = open('classify_into_folders.tsv', 'a')
|
||||
|
||||
step = 0
|
||||
for test_set_name, test_loader in test_loaders:
|
||||
|
@ -68,9 +69,14 @@ if __name__ == "__main__":
|
|||
|
||||
lbls = torch.nn.functional.softmax(model.eval_state[output_key][0].cpu(), dim=-1)
|
||||
for k, lbl in enumerate(lbls):
|
||||
lbl = torch.argmax(lbl, dim=0)
|
||||
lbl = labels[torch.argmax(lbl, dim=0)]
|
||||
src_path = data[path_key][k]
|
||||
dest = os.path.join(output_base_dir, labels[lbl])
|
||||
output_file.write(f'{src_path}\t{lbl}')
|
||||
if output_base_dir is not None:
|
||||
dest = os.path.join(output_base_dir, lbl)
|
||||
os.makedirs(dest, exist_ok=True)
|
||||
shutil.copy(str(src_path), os.path.join(dest, f'{step}_{os.path.basename(str(src_path))}'))
|
||||
step += 1
|
||||
output_file.flush()
|
||||
output_file.close()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user