classify
This commit is contained in:
parent
92fe8b4dd9
commit
466b9fbcaa
|
@ -51,9 +51,10 @@ if __name__ == "__main__":
|
||||||
s.losses = {}
|
s.losses = {}
|
||||||
|
|
||||||
output_key = opt['eval']['classifier_logits_key']
|
output_key = opt['eval']['classifier_logits_key']
|
||||||
output_base_dir = opt['eval']['output_dir']
|
|
||||||
labels = opt['eval']['output_labels']
|
labels = opt['eval']['output_labels']
|
||||||
path_key = opt['eval']['path_key']
|
path_key = opt['eval']['path_key']
|
||||||
|
output_base_dir = util.opt_get(opt, ['eval', 'output_dir'], None)
|
||||||
|
output_file = open('classify_into_folders.tsv', 'a')
|
||||||
|
|
||||||
step = 0
|
step = 0
|
||||||
for test_set_name, test_loader in test_loaders:
|
for test_set_name, test_loader in test_loaders:
|
||||||
|
@ -68,9 +69,14 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
lbls = torch.nn.functional.softmax(model.eval_state[output_key][0].cpu(), dim=-1)
|
lbls = torch.nn.functional.softmax(model.eval_state[output_key][0].cpu(), dim=-1)
|
||||||
for k, lbl in enumerate(lbls):
|
for k, lbl in enumerate(lbls):
|
||||||
lbl = torch.argmax(lbl, dim=0)
|
lbl = labels[torch.argmax(lbl, dim=0)]
|
||||||
src_path = data[path_key][k]
|
src_path = data[path_key][k]
|
||||||
dest = os.path.join(output_base_dir, labels[lbl])
|
output_file.write(f'{src_path}\t{lbl}')
|
||||||
os.makedirs(dest, exist_ok=True)
|
if output_base_dir is not None:
|
||||||
shutil.copy(str(src_path), os.path.join(dest, f'{step}_{os.path.basename(str(src_path))}'))
|
dest = os.path.join(output_base_dir, lbl)
|
||||||
step += 1
|
os.makedirs(dest, exist_ok=True)
|
||||||
|
shutil.copy(str(src_path), os.path.join(dest, f'{step}_{os.path.basename(str(src_path))}'))
|
||||||
|
step += 1
|
||||||
|
output_file.flush()
|
||||||
|
output_file.close()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user