add some snark about the kludge I had to fix, and the kludge I used to fix it

This commit is contained in:
mrq 2023-02-17 19:20:19 +00:00
parent a09cf98c7f
commit 535549c3f3
3 changed files with 9 additions and 0 deletions

View File

@ -110,6 +110,8 @@ class Trainer:
self.dataset_debugger = get_dataset_debugger(dataset_opt)
if self.dataset_debugger is not None and resume_state is not None:
self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
# training will loop indefinitely if your batch size is larger than your dataset, so bail out early
# (this could instead just be flagged when generating the YAML, rather than raising here)
if len(self.train_set) <= dataset_opt['batch_size']:
raise Exception("dataset size is less than batch size, consider reducing your batch size, or increasing your dataset.")
train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))

View File

@ -34,6 +34,7 @@ def format_injector_name(name):
# Works by loading all python modules in the injectors/ directory and sniffing out subclasses of Injector.
# field will be properly populated.
def find_registered_injectors(base_path="trainer/injectors"):
# the path below is resolved relative to this file, the same modification networks.py has, so be sure to mirror any change there
path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}'))
module_iter = pkgutil.walk_packages([path])
results = {}

View File

@ -30,6 +30,12 @@ def register_model(func):
func._dlas_registered_model = True
return func
# this used to rely on a weird kludge (which I don't understand the need for) that framed the path around the current working directory
# it works better when the path is set relative to this file instead
# however, this behaves very differently when importing DLAS from outside the repo, rather than spawning a shell instance to run a script within it
# I can't be bothered to deal with that headache at the moment; I just want something that works right now without needing to touch a shell
# inject.py has a similar loader scheme, so be sure to mirror any change there if you touch this
def find_registered_model_fns(base_path='models'):
found_fns = {}
path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}'))