forked from mrq/DL-Art-School
Revert recent changes to extr
This commit is contained in:
parent
83cccef9d8
commit
b404a3b747
|
@ -21,7 +21,7 @@ logger = logging.getLogger('base')
|
|||
|
||||
|
||||
class ExtensibleTrainer(BaseModel):
|
||||
def __init__(self, opt):
|
||||
def __init__(self, opt, cached_networks={}):
|
||||
super(ExtensibleTrainer, self).__init__(opt)
|
||||
if opt['dist']:
|
||||
self.rank = torch.distributed.get_rank()
|
||||
|
@ -54,6 +54,10 @@ class ExtensibleTrainer(BaseModel):
|
|||
if 'trainable' not in net.keys():
|
||||
net['trainable'] = True
|
||||
|
||||
if name in cached_networks.keys():
|
||||
new_net = cached_networks[name]
|
||||
else:
|
||||
new_net = None
|
||||
if net['type'] == 'generator':
|
||||
if new_net is None:
|
||||
new_net = networks.create_model(opt, net, self.netsG).to(self.device)
|
||||
|
|
Loading…
Reference in New Issue
Block a user