forked from mrq/DL-Art-School
Fix trainable not applying to discriminators
This commit is contained in:
parent
21ae135f23
commit
b1238d29cb
|
@ -39,12 +39,12 @@ class ExtensibleTrainer(BaseModel):
|
|||
self.netsD = {}
|
||||
self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
|
||||
for name, net in opt['networks'].items():
|
||||
# Trainable is a required parameter, but the default is simply true. Set it here.
|
||||
if 'trainable' not in net.keys():
|
||||
net['trainable'] = True
|
||||
|
||||
if net['type'] == 'generator':
|
||||
new_net = networks.define_G(net, None, opt['scale']).to(self.device)
|
||||
if 'trainable' not in net.keys():
|
||||
net['trainable'] = True
|
||||
if not net['trainable']:
|
||||
new_net.eval()
|
||||
self.netsG[name] = new_net
|
||||
elif net['type'] == 'discriminator':
|
||||
new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device)
|
||||
|
@ -52,6 +52,9 @@ class ExtensibleTrainer(BaseModel):
|
|||
else:
|
||||
raise NotImplementedError("Can only handle generators and discriminators")
|
||||
|
||||
if not net['trainable']:
|
||||
new_net.eval()
|
||||
|
||||
# Initialize the train/eval steps
|
||||
self.steps = []
|
||||
for step_name, step in opt['steps'].items():
|
||||
|
|
Loading…
Reference in New Issue
Block a user