Don't let duplicate keys be used for injectors and losses

This commit is contained in:
James Betker 2020-09-29 16:59:44 -06:00
parent 0b5a033503
commit dc8f3b24de

View File

@ -27,13 +27,17 @@ class ConfigurableStep(Module):
self.injectors = [] self.injectors = []
if 'injectors' in self.step_opt.keys(): if 'injectors' in self.step_opt.keys():
injector_names = []
for inj_name, injector in self.step_opt['injectors'].items(): for inj_name, injector in self.step_opt['injectors'].items():
assert inj_name not in injector_names # Repeated names are always an error case.
injector_names.append(inj_name)
self.injectors.append(create_injector(injector, env)) self.injectors.append(create_injector(injector, env))
losses = [] losses = []
self.weights = {} self.weights = {}
if 'losses' in self.step_opt.keys(): if 'losses' in self.step_opt.keys():
for loss_name, loss in self.step_opt['losses'].items(): for loss_name, loss in self.step_opt['losses'].items():
assert loss_name not in self.weights.keys() # Repeated names are always an error case.
losses.append((loss_name, create_loss(loss, env))) losses.append((loss_name, create_loss(loss, env)))
self.weights[loss_name] = loss['weight'] self.weights[loss_name] = loss['weight']
self.losses = OrderedDict(losses) self.losses = OrderedDict(losses)