diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index e2c01ea6..edabeab0 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -21,7 +21,7 @@ logger = logging.getLogger('base') class ExtensibleTrainer(BaseModel): - def __init__(self, opt): + def __init__(self, opt, cached_networks={}): super(ExtensibleTrainer, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() @@ -54,6 +54,10 @@ class ExtensibleTrainer(BaseModel): if 'trainable' not in net.keys(): net['trainable'] = True + if name in cached_networks.keys(): + new_net = cached_networks[name] + else: + new_net = None if net['type'] == 'generator': if new_net is None: new_net = networks.create_model(opt, net, self.netsG).to(self.device)