From 535549c3f3b75956c457e7014a541636331b5e37 Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 17 Feb 2023 19:20:19 +0000 Subject: [PATCH] add some snark about the kludge I had to fix, and the kludge I used to fix it --- codes/train.py | 2 ++ codes/trainer/inject.py | 1 + codes/trainer/networks.py | 6 ++++++ 3 files changed, 9 insertions(+) diff --git a/codes/train.py b/codes/train.py index b9aa82c4..cf63dda2 100644 --- a/codes/train.py +++ b/codes/train.py @@ -110,6 +110,8 @@ class Trainer: self.dataset_debugger = get_dataset_debugger(dataset_opt) if self.dataset_debugger is not None and resume_state is not None: self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {})) + # it will indefinitely try to train if your batch size is larger than your dataset + # could just whine when generating the YAML rather than assert here if len(self.train_set) <= dataset_opt['batch_size']: raise Exception("dataset size is less than batch size, consider reducing your batch size, or increasing your dataset.") train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size'])) diff --git a/codes/trainer/inject.py b/codes/trainer/inject.py index b536d13e..da521419 100644 --- a/codes/trainer/inject.py +++ b/codes/trainer/inject.py @@ -34,6 +34,7 @@ def format_injector_name(name): # Works by loading all python modules in the injectors/ directory and sniffing out subclasses of Injector. # field will be properly populated. def find_registered_injectors(base_path="trainer/injectors"): + # this has the same modification networks.py has, so be sure to mirror it path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}')) module_iter = pkgutil.walk_packages([path]) results = {} diff --git a/codes/trainer/networks.py b/codes/trainer/networks.py index 3108f26e..63a63f34 100644 --- a/codes/trainer/networks.py +++ b/codes/trainer/networks.py @@ -30,6 +30,12 @@ def register_model(func): func._dlas_registered_model = True return func +# this had some weird kludge that I don't understand needing to have a reference frame around the current working directory +# it works better when you set it relative to this file instead +# however, this has very different behavior when importing DLAS from outside the repo, rather than spawning a shell instance to a script within it +# I can't be assed to deal with that headache at the moment, I just want something to work right now without needing to touch a shell + +# inject.py has a similar loader scheme, be sure to mirror it if you touch this too def find_registered_model_fns(base_path='models'): found_fns = {} path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}'))