add some snark about the kludge I had to fix, and the kludge I used to fix it

This commit is contained in:
mrq 2023-02-17 19:20:19 +00:00
parent a09cf98c7f
commit 535549c3f3
3 changed files with 9 additions and 0 deletions

View File

@ -110,6 +110,8 @@ class Trainer:
self.dataset_debugger = get_dataset_debugger(dataset_opt) self.dataset_debugger = get_dataset_debugger(dataset_opt)
if self.dataset_debugger is not None and resume_state is not None: if self.dataset_debugger is not None and resume_state is not None:
self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {})) self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
# it will indefinitely try to train if your batch size is larger than your dataset
# could just whine when generating the YAML rather than assert here
if len(self.train_set) <= dataset_opt['batch_size']: if len(self.train_set) <= dataset_opt['batch_size']:
raise Exception("dataset size is less than batch size, consider reducing your batch size, or increasing your dataset.") raise Exception("dataset size is less than batch size, consider reducing your batch size, or increasing your dataset.")
train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size'])) train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))

View File

@ -34,6 +34,7 @@ def format_injector_name(name):
# Works by loading all python modules in the injectors/ directory and sniffing out subclasses of Injector. # Works by loading all python modules in the injectors/ directory and sniffing out subclasses of Injector.
# field will be properly populated. # field will be properly populated.
def find_registered_injectors(base_path="trainer/injectors"): def find_registered_injectors(base_path="trainer/injectors"):
# this has the same modification networks.py has, so be sure to mirror it
path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}')) path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}'))
module_iter = pkgutil.walk_packages([path]) module_iter = pkgutil.walk_packages([path])
results = {} results = {}

View File

@ -30,6 +30,12 @@ def register_model(func):
func._dlas_registered_model = True func._dlas_registered_model = True
return func return func
# this had some weird kludge that I don't understand needing to have a reference frame around the current working directory
# it works better when you set it relative to this file instead
# however, this has very different behavior when importing DLAS from outside the repo, rather than spawning a shell instance to a script within it
# I can't be assed to deal with that headache at the moment, I just want something to work right now without needing to touch a shell
# inject.py has a similar loader scheme, be sure to mirror it if you touch this too
def find_registered_model_fns(base_path='models'): def find_registered_model_fns(base_path='models'):
found_fns = {} found_fns = {}
path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}')) path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}'))