forked from mrq/DL-Art-School
add some snark about the kludge I had to fix, and the kludge I used to fix it
This commit is contained in:
parent
a09cf98c7f
commit
535549c3f3
|
@ -110,6 +110,8 @@ class Trainer:
|
||||||
self.dataset_debugger = get_dataset_debugger(dataset_opt)
|
self.dataset_debugger = get_dataset_debugger(dataset_opt)
|
||||||
if self.dataset_debugger is not None and resume_state is not None:
|
if self.dataset_debugger is not None and resume_state is not None:
|
||||||
self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
|
self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
|
||||||
|
# it will indefinitely try to train if your batch size is larger than your dataset
|
||||||
|
# could just whine when generating the YAML rather than assert here
|
||||||
if len(self.train_set) <= dataset_opt['batch_size']:
|
if len(self.train_set) <= dataset_opt['batch_size']:
|
||||||
raise Exception("dataset size is less than batch size, consider reducing your batch size, or increasing your dataset.")
|
raise Exception("dataset size is less than batch size, consider reducing your batch size, or increasing your dataset.")
|
||||||
train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
|
train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
|
||||||
|
|
|
@ -34,6 +34,7 @@ def format_injector_name(name):
|
||||||
# Works by loading all python modules in the injectors/ directory and sniffing out subclasses of Injector.
|
# Works by loading all python modules in the injectors/ directory and sniffing out subclasses of Injector.
|
||||||
# field will be properly populated.
|
# field will be properly populated.
|
||||||
def find_registered_injectors(base_path="trainer/injectors"):
|
def find_registered_injectors(base_path="trainer/injectors"):
|
||||||
|
# this has the same modification networks.py has, so be sure to mirror it
|
||||||
path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}'))
|
path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}'))
|
||||||
module_iter = pkgutil.walk_packages([path])
|
module_iter = pkgutil.walk_packages([path])
|
||||||
results = {}
|
results = {}
|
||||||
|
|
|
@ -30,6 +30,12 @@ def register_model(func):
|
||||||
func._dlas_registered_model = True
|
func._dlas_registered_model = True
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
# this had some weird kludge that I don't understand needing to have a reference frame around the current working directory
|
||||||
|
# it works better when you set it relative to this file instead
|
||||||
|
# however, this has very different behavior when importing DLAS from outside the repo, rather than spawning a shell instance to a script within it
|
||||||
|
# I can't be assed to deal with that headache at the moment, I just want something to work right now without needing to touch a shell
|
||||||
|
|
||||||
|
# inject.py has a similar loader scheme, be sure to mirror it if you touch this too
|
||||||
def find_registered_model_fns(base_path='models'):
|
def find_registered_model_fns(base_path='models'):
|
||||||
found_fns = {}
|
found_fns = {}
|
||||||
path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}'))
|
path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}'))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user