Fix an issue where GPU0 was always being used in non-ddp
Frankly, I don't understand how this has ever worked. WTF.
This commit is contained in:
parent 2d3449d7a5
commit db9e9e28a0
@@ -108,14 +108,14 @@ class ExtensibleTrainer(BaseModel):
                                                device_ids=[torch.cuda.current_device()],
                                                find_unused_parameters=False)
             else:
-                dnet = DataParallel(anet)
+                dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
             if self.is_train:
                 dnet.train()
             else:
                 dnet.eval()
             dnets.append(dnet)
         if not opt['dist']:
-            self.netF = DataParallel(self.netF)
+            self.netF = DataParallel(self.netF, device_ids=opt['gpu_ids'])
 
         # Backpush the wrapped networks into the network dicts..
         self.networks = {}
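This hunk explains the bug in the title: when DataParallel is constructed without device_ids, it defaults to every visible GPU and gathers outputs on cuda:0, so a run configured for another GPU still ended up on GPU 0. A minimal sketch of the difference, using a hypothetical net and a gpu_ids list standing in for anet and opt['gpu_ids'] from the diff:

import torch
from torch.nn import DataParallel

# Hypothetical stand-ins for anet and opt['gpu_ids'] in the diff.
net = torch.nn.Linear(16, 16)
gpu_ids = [1]  # e.g. the run is configured for GPU 1 only

# Without device_ids, DataParallel defaults to all visible GPUs and
# gathers outputs on cuda:0 -- the behavior this commit removes:
#   dnet = DataParallel(net)

# Restricting device_ids keeps replicas and outputs on the configured GPUs.
net = net.cuda(gpu_ids[0])  # parameters must live on device_ids[0]
dnet = DataParallel(net, device_ids=gpu_ids)

x = torch.randn(8, 16).cuda(gpu_ids[0])
y = dnet(x)  # computed on cuda:1 and returned there, not on cuda:0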
@@ -284,8 +284,10 @@ if __name__ == '__main__':
     if args.launcher == 'none':  # disabled distributed training
         opt['dist'] = False
         trainer.rank = -1
+        if len(opt['gpu_ids']) == 1:
+            torch.cuda.set_device(opt['gpu_ids'][0])
         print('Disabled distributed training.')
 
     else:
         opt['dist'] = True
         init_dist('nccl')
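The added torch.cuda.set_device call covers the non-distributed, single-GPU path: it changes which device bare .cuda() calls and torch.cuda.current_device() resolve to for the whole process. A minimal sketch, again with a hypothetical gpu_ids list in place of opt['gpu_ids'] (running it as written assumes at least two GPUs):

import torch

gpu_ids = [1]  # hypothetical config: train on GPU 1, not GPU 0

if len(gpu_ids) == 1:
    # Make the configured GPU the process-wide default so that
    # tensor.cuda() and torch.cuda.current_device() point at it.
    torch.cuda.set_device(gpu_ids[0])

t = torch.randn(4).cuda()  # allocated on cuda:1 rather than cuda:0
assert t.device.index == gpu_ids[0]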