Fix trainable not applying to discriminators
This commit is contained in:
parent
21ae135f23
commit
b1238d29cb
|
@ -39,12 +39,12 @@ class ExtensibleTrainer(BaseModel):
|
||||||
self.netsD = {}
|
self.netsD = {}
|
||||||
self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
|
self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
|
||||||
for name, net in opt['networks'].items():
|
for name, net in opt['networks'].items():
|
||||||
|
# Trainable is a required parameter, but the default is simply true. Set it here.
|
||||||
|
if 'trainable' not in net.keys():
|
||||||
|
net['trainable'] = True
|
||||||
|
|
||||||
if net['type'] == 'generator':
|
if net['type'] == 'generator':
|
||||||
new_net = networks.define_G(net, None, opt['scale']).to(self.device)
|
new_net = networks.define_G(net, None, opt['scale']).to(self.device)
|
||||||
if 'trainable' not in net.keys():
|
|
||||||
net['trainable'] = True
|
|
||||||
if not net['trainable']:
|
|
||||||
new_net.eval()
|
|
||||||
self.netsG[name] = new_net
|
self.netsG[name] = new_net
|
||||||
elif net['type'] == 'discriminator':
|
elif net['type'] == 'discriminator':
|
||||||
new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device)
|
new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device)
|
||||||
|
@ -52,6 +52,9 @@ class ExtensibleTrainer(BaseModel):
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Can only handle generators and discriminators")
|
raise NotImplementedError("Can only handle generators and discriminators")
|
||||||
|
|
||||||
|
if not net['trainable']:
|
||||||
|
new_net.eval()
|
||||||
|
|
||||||
# Initialize the train/eval steps
|
# Initialize the train/eval steps
|
||||||
self.steps = []
|
self.steps = []
|
||||||
for step_name, step in opt['steps'].items():
|
for step_name, step in opt['steps'].items():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user