forked from mrq/DL-Art-School
Fix distributed bug
This commit is contained in:
parent
f87e10ffef
commit
79e5692388
|
@ -82,7 +82,7 @@ class DeterministicSampler:
|
||||||
else:
|
else:
|
||||||
self.world_size = 1
|
self.world_size = 1
|
||||||
# The sampling range gets spread out across multiple distributed entities.
|
# The sampling range gets spread out across multiple distributed entities.
|
||||||
rnge = th.arange(0, sampling_range, step=self.world_size).float() / sampling_range
|
rnge = th.arange(self.rank, sampling_range, step=self.world_size).float() / sampling_range
|
||||||
self.indices = (rnge * self.timesteps).long()
|
self.indices = (rnge * self.timesteps).long()
|
||||||
|
|
||||||
def sample(self, batch_size, device):
|
def sample(self, batch_size, device):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user