diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 501d6f67..919a945a 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -137,8 +137,8 @@ class ExtensibleTrainer(BaseModel): # Some models need to make parametric adjustments per-step. Do that here. for net in self.networks.values(): - if hasattr(net, "update_for_step"): - net.update_for_step(step, os.path.join(self.opt['path']['models'], "..")) + if hasattr(net.module, "update_for_step"): + net.module.update_for_step(step, os.path.join(self.opt['path']['models'], "..")) # Iterate through the steps, performing them one at a time. state = self.dstate diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 58b02c86..377e0f61 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -101,7 +101,10 @@ def gather_2d(input, index): ind_nd = ind_nd.repeat((1, c)) ind_nd = ind_nd.unsqueeze(2) result = torch.gather(nodim, dim=2, index=ind_nd) - return result.squeeze() + result = result.squeeze() + if b == 1: + result = result.unsqueeze(0) + return result # Computes a linear latent by performing processing on the reference image and returning the filters of a single point, diff --git a/codes/train.py b/codes/train.py index 980d00c2..4c3e8761 100644 --- a/codes/train.py +++ b/codes/train.py @@ -215,7 +215,7 @@ def main(): logger.info(message) #### validation if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: - if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan'] and rank <= 0: # image restoration validation + if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation model.force_restore_swapout() val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size'] # does not support multi-GPU validation diff --git a/codes/train2.py b/codes/train2.py index 8dccf326..112cc83b 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_nt_spsr_switched.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_fullimgref.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0)