forked from mrq/DL-Art-School
Merge remote-tracking branch 'origin/gan_lab' into gan_lab
This commit is contained in:
commit
3397c83447
|
@ -134,7 +134,7 @@ class FeatureLoss(ConfigurableLoss):
|
||||||
self.netF = define_F(which_model=opt['which_model_F'],
|
self.netF = define_F(which_model=opt['which_model_F'],
|
||||||
load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
|
load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
|
||||||
if not env['opt']['dist']:
|
if not env['opt']['dist']:
|
||||||
self.netF = torch.nn.parallel.DataParallel(self.netF)
|
self.netF = torch.nn.parallel.DataParallel(self.netF, device_ids=env['opt']['gpu_ids'])
|
||||||
|
|
||||||
def forward(self, _, state):
|
def forward(self, _, state):
|
||||||
with autocast(enabled=self.env['opt']['fp16']):
|
with autocast(enabled=self.env['opt']['fp16']):
|
||||||
|
|
|
@ -221,7 +221,7 @@ class Trainer:
|
||||||
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
||||||
util.mkdir(img_dir)
|
util.mkdir(img_dir)
|
||||||
|
|
||||||
self.model.feed_data(val_data)
|
self.model.feed_data(val_data, self.current_step)
|
||||||
self.model.test()
|
self.model.test()
|
||||||
|
|
||||||
visuals = self.model.get_current_visuals()
|
visuals = self.model.get_current_visuals()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user