From b404a3b7476789b692490ea25d1830137b661226 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 30 Oct 2021 20:48:06 -0600 Subject: [PATCH] Revert recent changes to extr --- codes/trainer/ExtensibleTrainer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index e2c01ea6..edabeab0 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -21,7 +21,7 @@ logger = logging.getLogger('base') class ExtensibleTrainer(BaseModel): - def __init__(self, opt): + def __init__(self, opt, cached_networks={}): super(ExtensibleTrainer, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() @@ -54,6 +54,10 @@ class ExtensibleTrainer(BaseModel): if 'trainable' not in net.keys(): net['trainable'] = True + if name in cached_networks.keys(): + new_net = cached_networks[name] + else: + new_net = None if net['type'] == 'generator': if new_net is None: new_net = networks.create_model(opt, net, self.netsG).to(self.device)