Fix val behavior for ExtensibleTrainer

This commit is contained in:
James Betker 2020-08-26 08:44:22 -06:00
parent 434ed70a9a
commit f35b3ad28f
4 changed files with 8 additions and 5 deletions

View File

@ -137,8 +137,8 @@ class ExtensibleTrainer(BaseModel):
# Some models need to make parametric adjustments per-step. Do that here. # Some models need to make parametric adjustments per-step. Do that here.
for net in self.networks.values(): for net in self.networks.values():
if hasattr(net, "update_for_step"): if hasattr(net.module, "update_for_step"):
net.update_for_step(step, os.path.join(self.opt['path']['models'], "..")) net.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
# Iterate through the steps, performing them one at a time. # Iterate through the steps, performing them one at a time.
state = self.dstate state = self.dstate

View File

@ -101,7 +101,10 @@ def gather_2d(input, index):
ind_nd = ind_nd.repeat((1, c)) ind_nd = ind_nd.repeat((1, c))
ind_nd = ind_nd.unsqueeze(2) ind_nd = ind_nd.unsqueeze(2)
result = torch.gather(nodim, dim=2, index=ind_nd) result = torch.gather(nodim, dim=2, index=ind_nd)
return result.squeeze() result = result.squeeze()
if b == 1:
result = result.unsqueeze(0)
return result
# Computes a linear latent by performing processing on the reference image and returning the filters of a single point, # Computes a linear latent by performing processing on the reference image and returning the filters of a single point,

View File

@ -215,7 +215,7 @@ def main():
logger.info(message) logger.info(message)
#### validation #### validation
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan'] and rank <= 0: # image restoration validation if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation
model.force_restore_swapout() model.force_restore_swapout()
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size'] val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
# does not support multi-GPU validation # does not support multi-GPU validation

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main(): def main():
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_nt_spsr_switched.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_fullimgref.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher') help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)