Merge remote-tracking branch 'origin/gan_lab' into gan_lab
This commit is contained in:
commit
3397c83447
|
@ -134,7 +134,7 @@ class FeatureLoss(ConfigurableLoss):
|
|||
self.netF = define_F(which_model=opt['which_model_F'],
|
||||
load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
|
||||
if not env['opt']['dist']:
|
||||
self.netF = torch.nn.parallel.DataParallel(self.netF)
|
||||
self.netF = torch.nn.parallel.DataParallel(self.netF, device_ids=env['opt']['gpu_ids'])
|
||||
|
||||
def forward(self, _, state):
|
||||
with autocast(enabled=self.env['opt']['fp16']):
|
||||
|
|
|
@ -221,7 +221,7 @@ class Trainer:
|
|||
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
||||
util.mkdir(img_dir)
|
||||
|
||||
self.model.feed_data(val_data)
|
||||
self.model.feed_data(val_data, self.current_step)
|
||||
self.model.test()
|
||||
|
||||
visuals = self.model.get_current_visuals()
|
||||
|
|
Loading…
Reference in New Issue
Block a user