Fix trainable not applying to discriminators

This commit is contained in:
James Betker 2020-09-05 20:31:26 -06:00
parent 21ae135f23
commit b1238d29cb

View File

@ -39,12 +39,12 @@ class ExtensibleTrainer(BaseModel):
self.netsD = {}
self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
for name, net in opt['networks'].items():
# Trainable is a required parameter, but the default is simply true. Set it here.
if 'trainable' not in net.keys():
net['trainable'] = True
if net['type'] == 'generator':
new_net = networks.define_G(net, None, opt['scale']).to(self.device)
if 'trainable' not in net.keys():
net['trainable'] = True
if not net['trainable']:
new_net.eval()
self.netsG[name] = new_net
elif net['type'] == 'discriminator':
new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device)
@ -52,6 +52,9 @@ class ExtensibleTrainer(BaseModel):
else:
raise NotImplementedError("Can only handle generators and discriminators")
if not net['trainable']:
new_net.eval()
# Initialize the train/eval steps
self.steps = []
for step_name, step in opt['steps'].items():