diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 2b696834..c882bd7f 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -134,7 +134,7 @@ class FeatureLoss(ConfigurableLoss): self.netF = define_F(which_model=opt['which_model_F'], load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device']) if not env['opt']['dist']: - self.netF = torch.nn.parallel.DataParallel(self.netF) + self.netF = torch.nn.parallel.DataParallel(self.netF, device_ids=env['opt']['gpu_ids']) def forward(self, _, state): with autocast(enabled=self.env['opt']['fp16']):