diff --git a/codes/models/diffusion/resample.py b/codes/models/diffusion/resample.py index 694aa3f0..8690bb1d 100644 --- a/codes/models/diffusion/resample.py +++ b/codes/models/diffusion/resample.py @@ -82,7 +82,7 @@ class DeterministicSampler: else: self.world_size = 1 # The sampling range gets spread out across multiple distributed entities. - rnge = th.arange(0, sampling_range, step=self.world_size).float() / sampling_range + rnge = th.arange(self.rank, sampling_range, step=self.world_size).float() / sampling_range self.indices = (rnge * self.timesteps).long() def sample(self, batch_size, device):