add some snark about the kludge I had to fix, and the kludge I used to fix it
This commit is contained in:
parent
a09cf98c7f
commit
535549c3f3
|
@ -110,6 +110,8 @@ class Trainer:
|
|||
self.dataset_debugger = get_dataset_debugger(dataset_opt)
|
||||
if self.dataset_debugger is not None and resume_state is not None:
|
||||
self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
|
||||
# it will indefinitely try to train if your batch size is larger than your dataset
|
||||
# could just whine when generating the YAML rather than assert here
|
||||
if len(self.train_set) <= dataset_opt['batch_size']:
|
||||
raise Exception("dataset size is less than batch size, consider reducing your batch size, or increasing your dataset.")
|
||||
train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
|
||||
|
|
|
@ -34,6 +34,7 @@ def format_injector_name(name):
|
|||
# Works by loading all python modules in the injectors/ directory and sniffing out subclasses of Injector.
|
||||
# field will be properly populated.
|
||||
def find_registered_injectors(base_path="trainer/injectors"):
|
||||
# this has the same modification networks.py has, so be sure to mirror it
|
||||
path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}'))
|
||||
module_iter = pkgutil.walk_packages([path])
|
||||
results = {}
|
||||
|
|
|
@ -30,6 +30,12 @@ def register_model(func):
|
|||
func._dlas_registered_model = True
|
||||
return func
|
||||
|
||||
# this had some weird kludge that I don't understand needing to have a reference frame around the current working directory
|
||||
# it works better when you set it relative to this file instead
|
||||
# however, this has very different behavior when importing DLAS from outside the repo, rather than spawning a shell instance to a script within it
|
||||
# I can't be assed to deal with that headache at the moment, I just want something to work right now without needing to touch a shell
|
||||
|
||||
# inject.py has a similar loader scheme, be sure to mirror it if you touch this too
|
||||
def find_registered_model_fns(base_path='models'):
|
||||
found_fns = {}
|
||||
path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}'))
|
||||
|
|
Loading…
Reference in New Issue
Block a user