Fix val behavior for ExtensibleTrainer
This commit is contained in:
parent
434ed70a9a
commit
f35b3ad28f
|
@ -137,8 +137,8 @@ class ExtensibleTrainer(BaseModel):
|
|||
|
||||
# Some models need to make parametric adjustments per-step. Do that here.
|
||||
for net in self.networks.values():
|
||||
if hasattr(net, "update_for_step"):
|
||||
net.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
|
||||
if hasattr(net.module, "update_for_step"):
|
||||
net.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
|
||||
|
||||
# Iterate through the steps, performing them one at a time.
|
||||
state = self.dstate
|
||||
|
|
|
@ -101,7 +101,10 @@ def gather_2d(input, index):
|
|||
ind_nd = ind_nd.repeat((1, c))
|
||||
ind_nd = ind_nd.unsqueeze(2)
|
||||
result = torch.gather(nodim, dim=2, index=ind_nd)
|
||||
return result.squeeze()
|
||||
result = result.squeeze()
|
||||
if b == 1:
|
||||
result = result.unsqueeze(0)
|
||||
return result
|
||||
|
||||
|
||||
# Computes a linear latent by performing processing on the reference image and returning the filters of a single point,
|
||||
|
|
|
@ -215,7 +215,7 @@ def main():
|
|||
logger.info(message)
|
||||
#### validation
|
||||
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan'] and rank <= 0: # image restoration validation
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation
|
||||
model.force_restore_swapout()
|
||||
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
|
||||
# does not support multi-GPU validation
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_nt_spsr_switched.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_fullimgref.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user