Restore swapout models just before a checkpoint

This commit is contained in:
James Betker 2020-05-16 07:45:19 -06:00
parent 877be4d88c
commit 635c53475f
2 changed files with 14 additions and 0 deletions

View File

@ -380,6 +380,19 @@ class SRGANModel(BaseModel):
# Just a note: this intentionally includes the swap model in the list of possibilities. # Just a note: this intentionally includes the swap model in the list of possibilities.
return previous_models[random.randint(0, len(previous_models)-1)] return previous_models[random.randint(0, len(previous_models)-1)]
# Called before verification/checkpoint to ensure we're using the real models and not a swapout variant.
def force_restore_swapout(self):
if self.swapout_D_duration > 0:
logger.info("Swapping back to current D model: %s" % (self.stashed_D,))
self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load'])
self.stashed_D = None
self.swapout_D_duration = 0
if self.swapout_G_duration > 0:
logger.info("Swapping back to current G model: %s" % (self.stashed_G,))
self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load'])
self.stashed_G = None
self.swapout_G_duration = 0
def swapout_D(self, step): def swapout_D(self, step):
if self.swapout_D_duration > 0: if self.swapout_D_duration > 0:
self.swapout_D_duration -= 1 self.swapout_D_duration -= 1

View File

@ -185,6 +185,7 @@ def main():
#### validation #### validation
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0: # image restoration validation if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0: # image restoration validation
model.force_restore_swapout()
# does not support multi-GPU validation # does not support multi-GPU validation
pbar = util.ProgressBar(len(val_loader)) pbar = util.ProgressBar(len(val_loader))
avg_psnr = 0. avg_psnr = 0.