diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index b97460a8..5634b59a 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -380,6 +380,19 @@ class SRGANModel(BaseModel): # Just a note: this intentionally includes the swap model in the list of possibilities. return previous_models[random.randint(0, len(previous_models)-1)] + # Called before verification/checkpoint to ensure we're using the real models and not a swapout variant. + def force_restore_swapout(self): + if self.swapout_D_duration > 0: + logger.info("Swapping back to current D model: %s" % (self.stashed_D,)) + self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load']) + self.stashed_D = None + self.swapout_D_duration = 0 + if self.swapout_G_duration > 0: + logger.info("Swapping back to current G model: %s" % (self.stashed_G,)) + self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load']) + self.stashed_G = None + self.swapout_G_duration = 0 + def swapout_D(self, step): if self.swapout_D_duration > 0: self.swapout_D_duration -= 1 diff --git a/codes/train.py b/codes/train.py index 8281db8c..f81a3484 100644 --- a/codes/train.py +++ b/codes/train.py @@ -185,6 +185,7 @@ def main(): #### validation if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0: # image restoration validation + model.force_restore_swapout() # does not support multi-GPU validation pbar = util.ProgressBar(len(val_loader)) avg_psnr = 0.