Use areal interpolate for multiscale_dataset

This commit is contained in:
James Betker 2020-10-19 15:30:25 -06:00
parent a63bf2ea2f
commit 111450f4e7

View File

@ -78,7 +78,7 @@ class MultiScaleDataset(data.Dataset):
patches_hq = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq]
patches_hq = torch.stack(patches_hq, dim=0)
patches_hq_corrupted = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq_corrupted]
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='bilinear').squeeze() for p in patches_hq_corrupted]
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='area').squeeze() for p in patches_hq_corrupted]
patches_lq = torch.stack(patches_lq, dim=0)
d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path}