Fix distributed bug

This commit is contained in:
James Betker 2022-03-04 11:58:53 -07:00
parent f87e10ffef
commit 79e5692388

View File

@ -82,7 +82,7 @@ class DeterministicSampler:
else:
self.world_size = 1
# The sampling range gets spread out across multiple distributed entities.
rnge = th.arange(0, sampling_range, step=self.world_size).float() / sampling_range
rnge = th.arange(self.rank, sampling_range, step=self.world_size).float() / sampling_range
self.indices = (rnge * self.timesteps).long()
def sample(self, batch_size, device):