Fix an issue where GPU0 was always being used in non-ddp

Frankly, I don't understand how this has ever worked. WTF.
James Betker 2020-11-12 15:43:01 -07:00
parent 2d3449d7a5
commit db9e9e28a0
3 changed files with 6 additions and 4 deletions
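For context, the root cause is PyTorch's defaults rather than anything exotic: until a process calls torch.cuda.set_device (or wraps modules with explicit device_ids), GPU 0 is the current CUDA device for every .cuda() call and for DataParallel. A minimal sketch of that default behaviour (not code from this repo):

import torch

# With no explicit selection, CUDA work defaults to device 0, so a run that
# was configured for another GPU still ends up on cuda:0.
print(torch.cuda.current_device())   # -> 0
x = torch.zeros(4).cuda()            # allocated on the current device, i.e. cuda:0
print(x.device)                      # -> cuda:0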


@@ -108,14 +108,14 @@ class ExtensibleTrainer(BaseModel):
                                                 device_ids=[torch.cuda.current_device()],
                                                 find_unused_parameters=False)
             else:
-                dnet = DataParallel(anet)
+                dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
             if self.is_train:
                 dnet.train()
             else:
                 dnet.eval()
             dnets.append(dnet)
         if not opt['dist']:
-            self.netF = DataParallel(self.netF)
+            self.netF = DataParallel(self.netF, device_ids=opt['gpu_ids'])
         # Backpush the wrapped networks into the network dicts..
         self.networks = {}
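The change above passes the configured GPU list straight through to DataParallel. Left unset, device_ids defaults to every visible GPU with cuda:0 as the primary (output) device, which is why GPU 0 was always busy; with device_ids given, replication and the primary device follow opt['gpu_ids'] instead. A hedged sketch, with a stand-in module and an assumed gpu_ids value:

import torch
from torch.nn import DataParallel

gpu_ids = [1]                      # assumed example value of opt['gpu_ids']
net = torch.nn.Linear(16, 16)      # stand-in for one of the trainer's networks

# Old behaviour: no device_ids, so replicas go to every visible GPU and the
# parameters/outputs sit on cuda:0.
# dnet = DataParallel(net.cuda())

# New behaviour: only the configured GPUs are used, and the first entry in
# gpu_ids becomes the primary device.
dnet = DataParallel(net.cuda(gpu_ids[0]), device_ids=gpu_ids)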


@@ -284,8 +284,9 @@ if __name__ == '__main__':
     if args.launcher == 'none':  # disabled distributed training
         opt['dist'] = False
         trainer.rank = -1
+        if len(opt['gpu_ids']) == 1:
+            torch.cuda.set_device(opt['gpu_ids'][0])
         print('Disabled distributed training.')
     else:
         opt['dist'] = True
         init_dist('nccl')
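The added guard pins the process to the configured GPU when training runs without a launcher on a single device. torch.cuda.set_device changes which device plain .cuda() calls and torch.cuda.current_device() resolve to, which is what the distributed branch already relied on implicitly. A small sketch of the effect, assuming a config that asks for GPU 1:

import torch

gpu_ids = [1]                          # assumed example value of opt['gpu_ids']

print(torch.cuda.current_device())     # -> 0 before the guard runs

if len(gpu_ids) == 1:
    torch.cuda.set_device(gpu_ids[0])  # mirrors the change above

print(torch.cuda.current_device())     # -> 1
x = torch.empty(4).cuda()              # now lands on cuda:1, not cuda:0
print(x.device)                        # -> cuda:1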


@@ -284,8 +284,9 @@ if __name__ == '__main__':
     if args.launcher == 'none':  # disabled distributed training
         opt['dist'] = False
         trainer.rank = -1
+        if len(opt['gpu_ids']) == 1:
+            torch.cuda.set_device(opt['gpu_ids'][0])
         print('Disabled distributed training.')
     else:
         opt['dist'] = True
         init_dist('nccl')
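The same guard appears a second time, evidently in another training entry script (the commit touches three files). Putting both pieces together, a rough end-to-end sketch of the non-distributed path after this commit; the opt keys match the trainer's, while the tiny network and batch are stand-ins:

import torch
from torch.nn import DataParallel

opt = {'dist': False, 'gpu_ids': [1]}            # assumed example config

if not opt['dist'] and len(opt['gpu_ids']) == 1:
    torch.cuda.set_device(opt['gpu_ids'][0])     # default device is now cuda:1

net = torch.nn.Linear(16, 16).cuda()             # lands on cuda:1, not cuda:0
dnet = DataParallel(net, device_ids=opt['gpu_ids'])

out = dnet(torch.randn(8, 16).cuda())            # forward stays on cuda:1
print(out.device)                                # -> cuda:1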