Restore swapout models just before a checkpoint
This commit is contained in:
parent
877be4d88c
commit
635c53475f
|
@ -380,6 +380,19 @@ class SRGANModel(BaseModel):
|
|||
# Just a note: this intentionally includes the swap model in the list of possibilities.
|
||||
return previous_models[random.randint(0, len(previous_models)-1)]
|
||||
|
||||
# Called before verification/checkpoint to ensure we're using the real models and not a swapout variant.
|
||||
def force_restore_swapout(self):
|
||||
if self.swapout_D_duration > 0:
|
||||
logger.info("Swapping back to current D model: %s" % (self.stashed_D,))
|
||||
self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load'])
|
||||
self.stashed_D = None
|
||||
self.swapout_D_duration = 0
|
||||
if self.swapout_G_duration > 0:
|
||||
logger.info("Swapping back to current G model: %s" % (self.stashed_G,))
|
||||
self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load'])
|
||||
self.stashed_G = None
|
||||
self.swapout_G_duration = 0
|
||||
|
||||
def swapout_D(self, step):
|
||||
if self.swapout_D_duration > 0:
|
||||
self.swapout_D_duration -= 1
|
||||
|
|
|
@ -185,6 +185,7 @@ def main():
|
|||
#### validation
|
||||
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0: # image restoration validation
|
||||
model.force_restore_swapout()
|
||||
# does not support multi-GPU validation
|
||||
pbar = util.ProgressBar(len(val_loader))
|
||||
avg_psnr = 0.
|
||||
|
|
Loading…
Reference in New Issue
Block a user