Fix an issue where GPU0 was always being used in non-DDP training

Frankly, I don't understand how this has ever worked. WTF.
James Betker 2020-11-12 15:43:01 -07:00
parent 2d3449d7a5
commit db9e9e28a0
3 changed files with 6 additions and 4 deletions


@@ -108,14 +108,14 @@ class ExtensibleTrainer(BaseModel):
                                                device_ids=[torch.cuda.current_device()],
                                                find_unused_parameters=False)
             else:
-                dnet = DataParallel(anet)
+                dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
             if self.is_train:
                 dnet.train()
             else:
                 dnet.eval()
             dnets.append(dnet)
         if not opt['dist']:
-            self.netF = DataParallel(self.netF)
+            self.netF = DataParallel(self.netF, device_ids=opt['gpu_ids'])
         # Backpush the wrapped networks into the network dicts..
         self.networks = {}
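
For context (this sketch is not part of the commit): torch.nn.DataParallel constructed without device_ids defaults to every visible CUDA device and keeps its primary replica and gathered outputs on cuda:0, which is why GPU0 was always used no matter what gpu_ids the config asked for. A minimal sketch of the before/after behaviour, assuming a hypothetical opt dict with gpu_ids set to [1] and a machine with at least two GPUs:

    import torch
    from torch.nn import DataParallel

    # Hypothetical stand-in for the trainer's option dict; the user asked for GPU 1.
    opt = {'gpu_ids': [1]}

    # DataParallel expects the module's parameters to live on its first device.
    net = torch.nn.Linear(8, 8).to(f"cuda:{opt['gpu_ids'][0]}")

    # Before the fix: device_ids defaults to all visible GPUs with cuda:0 as the
    # primary device, ignoring the configured gpu_ids.
    # wrapped = DataParallel(net)

    # After the fix: the wrapper is restricted to the configured devices, so the
    # primary device becomes opt['gpu_ids'][0] instead of GPU0.
    wrapped = DataParallel(net, device_ids=opt['gpu_ids'])
    out = wrapped(torch.randn(4, 8, device=f"cuda:{opt['gpu_ids'][0]}"))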


@@ -284,8 +284,9 @@ if __name__ == '__main__':
     if args.launcher == 'none':  # disabled distributed training
         opt['dist'] = False
         trainer.rank = -1
+        if len(opt['gpu_ids']) == 1:
+            torch.cuda.set_device(opt['gpu_ids'][0])
         print('Disabled distributed training.')
     else:
         opt['dist'] = True
         init_dist('nccl')


@@ -284,8 +284,9 @@ if __name__ == '__main__':
     if args.launcher == 'none':  # disabled distributed training
         opt['dist'] = False
         trainer.rank = -1
+        if len(opt['gpu_ids']) == 1:
+            torch.cuda.set_device(opt['gpu_ids'][0])
         print('Disabled distributed training.')
     else:
         opt['dist'] = True
         init_dist('nccl')
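
Similarly, not part of the commit itself: the new torch.cuda.set_device call is what keeps the non-distributed, single-GPU path off GPU0, since bare .cuda() calls and torch.device('cuda') otherwise resolve to device 0. A minimal sketch under the same hypothetical opt dict:

    import torch

    # Hypothetical non-distributed, single-GPU config (launcher == 'none').
    opt = {'dist': False, 'gpu_ids': [1]}

    if not opt['dist'] and len(opt['gpu_ids']) == 1:
        # Without this, tensors created via a bare .cuda() land on cuda:0 even
        # though the config asked for a different card.
        torch.cuda.set_device(opt['gpu_ids'][0])

    x = torch.zeros(4).cuda()
    print(x.device)  # expected: cuda:1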