diff --git a/codes/data/full_image_dataset.py b/codes/data/full_image_dataset.py index e4f297ab..2ba9b1e2 100644 --- a/codes/data/full_image_dataset.py +++ b/codes/data/full_image_dataset.py @@ -219,7 +219,7 @@ class FullImageDataset(data.Dataset): img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile(img_full) else: img_GT, gt_fullsize_ref = img_full, img_full - gt_mask = np.ones(img_full.shape[:2]) + gt_mask = np.ones(img_full.shape[:2], dtype=gt_fullsize_ref.dtype) gt_center = torch.tensor([img_full.shape[0] // 2, img_full.shape[1] // 2], dtype=torch.long) orig_gt_dim = gt_fullsize_ref.shape[:2]