diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py index 7637d864..faa68723 100644 --- a/codes/data/image_folder_dataset.py +++ b/codes/data/image_folder_dataset.py @@ -52,9 +52,7 @@ class ImageFolderDataset: imgs = torch.load(cache_path) else: print("Building image folder cache, this can take some time for large datasets..") - imgs = [] - for ext in supported_types: - imgs.extend(glob.glob(os.path.join(path, "*." + ext))) + imgs = util.get_image_paths('img', path)[0] torch.save(imgs, cache_path) for w in range(weight): self.image_paths.extend(imgs) @@ -67,6 +65,7 @@ class ImageFolderDataset: def resize_hq(self, imgs_hq): # Enforce size constraints h, w, _ = imgs_hq[0].shape + if self.target_hq_size is not None and self.target_hq_size != h: hqs_adjusted = [] for hq in imgs_hq: @@ -114,6 +113,11 @@ class ImageFolderDataset: if not self.disable_flip and random.random() < .5: hq = hq[:, ::-1, :] + # We must convert the image into a square. + h, w, _ = hq.shape + dim = min(h, w) + hq = hq[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :] + if self.labeler: assert hq.shape[0] == hq.shape[1] # This just has not been accomodated yet. dim = hq.shape[0]