Fix another issue with gpu ids getting thrown all over the place
parent c47925ae34
commit 0b96811611
@@ -134,7 +134,7 @@ class FeatureLoss(ConfigurableLoss):
         self.netF = define_F(which_model=opt['which_model_F'],
                              load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
         if not env['opt']['dist']:
-            self.netF = torch.nn.parallel.DataParallel(self.netF)
+            self.netF = torch.nn.parallel.DataParallel(self.netF, device_ids=env['opt']['gpu_ids'])
 
     def forward(self, _, state):
         with autocast(enabled=self.env['opt']['fp16']):
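The change pins DataParallel to the GPUs listed in the options dict instead of letting it default to every visible device. A minimal, self-contained sketch of the same pattern is below; the `opt` dict here is a hypothetical stand-in for the training options object used in the diff, and a multi-GPU CUDA machine is assumed.

import torch
import torch.nn as nn

# Hypothetical options dict mirroring the fields used in the diff.
opt = {'gpu_ids': [0, 1], 'dist': False}

# Parameters must live on the first device in device_ids.
model = nn.Linear(16, 16).to(f"cuda:{opt['gpu_ids'][0]}")

if not opt['dist']:
    # Without device_ids, DataParallel replicates onto all visible GPUs,
    # which can land work on devices the config never asked for.
    # Passing opt['gpu_ids'] keeps replication on the configured GPUs only.
    model = torch.nn.parallel.DataParallel(model, device_ids=opt['gpu_ids'])

# Inputs are scattered across opt['gpu_ids'] and outputs gathered on gpu_ids[0].
out = model(torch.randn(8, 16))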