Fix another issue with gpu ids getting thrown all over the place

This commit is contained in:
James Betker 2020-11-13 20:05:52 -07:00
parent c47925ae34
commit 0b96811611

View File

@ -134,7 +134,7 @@ class FeatureLoss(ConfigurableLoss):
self.netF = define_F(which_model=opt['which_model_F'],
load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
if not env['opt']['dist']:
self.netF = torch.nn.parallel.DataParallel(self.netF)
self.netF = torch.nn.parallel.DataParallel(self.netF, device_ids=env['opt']['gpu_ids'])
def forward(self, _, state):
with autocast(enabled=self.env['opt']['fp16']):