Merge remote-tracking branch 'origin/gan_lab' into gan_lab

This commit is contained in:
James Betker 2020-11-14 09:30:09 -07:00
commit 3397c83447
2 changed files with 2 additions and 2 deletions

View File

@ -134,7 +134,7 @@ class FeatureLoss(ConfigurableLoss):
self.netF = define_F(which_model=opt['which_model_F'],
load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
if not env['opt']['dist']:
self.netF = torch.nn.parallel.DataParallel(self.netF)
self.netF = torch.nn.parallel.DataParallel(self.netF, device_ids=env['opt']['gpu_ids'])
def forward(self, _, state):
with autocast(enabled=self.env['opt']['fp16']):

View File

@ -221,7 +221,7 @@ class Trainer:
img_dir = os.path.join(opt['path']['val_images'], img_name)
util.mkdir(img_dir)
self.model.feed_data(val_data)
self.model.feed_data(val_data, self.current_step)
self.model.test()
visuals = self.model.get_current_visuals()