Fix distributed bug
This commit is contained in:
parent
f87e10ffef
commit
79e5692388
|
@ -82,7 +82,7 @@ class DeterministicSampler:
|
|||
else:
|
||||
self.world_size = 1
|
||||
# The sampling range gets spread out across multiple distributed entities.
|
||||
rnge = th.arange(0, sampling_range, step=self.world_size).float() / sampling_range
|
||||
rnge = th.arange(self.rank, sampling_range, step=self.world_size).float() / sampling_range
|
||||
self.indices = (rnge * self.timesteps).long()
|
||||
|
||||
def sample(self, batch_size, device):
|
||||
|
|
Loading…
Reference in New Issue
Block a user