This commit is contained in:
James Betker 2021-10-29 20:22:40 -06:00
parent 92fe8b4dd9
commit 466b9fbcaa

View File

@ -51,9 +51,10 @@ if __name__ == "__main__":
s.losses = {}
output_key = opt['eval']['classifier_logits_key']
output_base_dir = opt['eval']['output_dir']
labels = opt['eval']['output_labels']
path_key = opt['eval']['path_key']
output_base_dir = util.opt_get(opt, ['eval', 'output_dir'], None)
output_file = open('classify_into_folders.tsv', 'a')
step = 0
for test_set_name, test_loader in test_loaders:
@ -68,9 +69,14 @@ if __name__ == "__main__":
lbls = torch.nn.functional.softmax(model.eval_state[output_key][0].cpu(), dim=-1)
for k, lbl in enumerate(lbls):
lbl = torch.argmax(lbl, dim=0)
lbl = labels[torch.argmax(lbl, dim=0)]
src_path = data[path_key][k]
dest = os.path.join(output_base_dir, labels[lbl])
os.makedirs(dest, exist_ok=True)
shutil.copy(str(src_path), os.path.join(dest, f'{step}_{os.path.basename(str(src_path))}'))
step += 1
output_file.write(f'{src_path}\t{lbl}')
if output_base_dir is not None:
dest = os.path.join(output_base_dir, lbl)
os.makedirs(dest, exist_ok=True)
shutil.copy(str(src_path), os.path.join(dest, f'{step}_{os.path.basename(str(src_path))}'))
step += 1
output_file.flush()
output_file.close()