forked from mrq/DL-Art-School
Revert recent changes to extr
This commit is contained in:
parent
83cccef9d8
commit
b404a3b747
|
@ -21,7 +21,7 @@ logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
|
||||||
class ExtensibleTrainer(BaseModel):
|
class ExtensibleTrainer(BaseModel):
|
||||||
def __init__(self, opt):
|
def __init__(self, opt, cached_networks={}):
|
||||||
super(ExtensibleTrainer, self).__init__(opt)
|
super(ExtensibleTrainer, self).__init__(opt)
|
||||||
if opt['dist']:
|
if opt['dist']:
|
||||||
self.rank = torch.distributed.get_rank()
|
self.rank = torch.distributed.get_rank()
|
||||||
|
@ -54,6 +54,10 @@ class ExtensibleTrainer(BaseModel):
|
||||||
if 'trainable' not in net.keys():
|
if 'trainable' not in net.keys():
|
||||||
net['trainable'] = True
|
net['trainable'] = True
|
||||||
|
|
||||||
|
if name in cached_networks.keys():
|
||||||
|
new_net = cached_networks[name]
|
||||||
|
else:
|
||||||
|
new_net = None
|
||||||
if net['type'] == 'generator':
|
if net['type'] == 'generator':
|
||||||
if new_net is None:
|
if new_net is None:
|
||||||
new_net = networks.create_model(opt, net, self.netsG).to(self.device)
|
new_net = networks.create_model(opt, net, self.netsG).to(self.device)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user