27 lines
880 B
Python
27 lines
880 B
Python
from glob import glob
|
|
|
|
import torch
|
|
import os
|
|
import shutil
|
|
|
|
if __name__ == '__main__':
|
|
index_map_file = 'F:\\4k6k\\datasets\\images\\imagenet_2017\\imagenet_index_to_train_folder_name_map.pth'
|
|
ground_truth = 'F:\\4k6k\\datasets\\images\\imagenet_2017\\validation_ground_truth.txt'
|
|
val_path = 'F:\\4k6k\\datasets\\images\\imagenet_2017\\val'
|
|
|
|
index_map = torch.load(index_map_file)
|
|
|
|
for folder in index_map.values():
|
|
os.makedirs(os.path.join(val_path, folder), exist_ok=True)
|
|
|
|
gtfile = open(ground_truth, 'r')
|
|
gtids = []
|
|
for line in gtfile:
|
|
gtids.append(int(line.strip()))
|
|
gtfile.close()
|
|
|
|
for i, img_file in enumerate(glob(os.path.join(val_path, "*.JPEG"))):
|
|
shutil.move(img_file, os.path.join(val_path, index_map[gtids[i]],
|
|
os.path.basename(img_file)))
|
|
print("Done!")
|