Revert recent changes to extr

This commit is contained in:
James Betker 2021-10-30 20:48:06 -06:00
parent 83cccef9d8
commit b404a3b747

View File

@ -21,7 +21,7 @@ logger = logging.getLogger('base')
class ExtensibleTrainer(BaseModel):
def __init__(self, opt):
def __init__(self, opt, cached_networks={}):
super(ExtensibleTrainer, self).__init__(opt)
if opt['dist']:
self.rank = torch.distributed.get_rank()
@ -54,6 +54,10 @@ class ExtensibleTrainer(BaseModel):
if 'trainable' not in net.keys():
net['trainable'] = True
if name in cached_networks.keys():
new_net = cached_networks[name]
else:
new_net = None
if net['type'] == 'generator':
if new_net is None:
new_net = networks.create_model(opt, net, self.netsG).to(self.device)