From 94d0f16608ecb787c0ddeb65e8472edc6edfec76 Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 17 Feb 2023 02:03:00 +0000 Subject: [PATCH] Necessary fixes to get it to work --- .gitignore | 0 MANIFEST.in | 0 README.old.md | 0 codes/__init__.py | 0 codes/trainer/inject.py | 4 +++- codes/trainer/networks.py | 8 +++----- experiments/EXAMPLE_gpt.yml | 0 experiments/bpe_lowercase_asr_256.json | 0 experiments/train_diffusion_vocoder_22k_level.yml | 0 requirements.txt | 0 setup.py | 14 ++++---------- 11 files changed, 10 insertions(+), 16 deletions(-) mode change 100755 => 100644 .gitignore mode change 100755 => 100644 MANIFEST.in mode change 100755 => 100644 README.old.md mode change 100755 => 100644 codes/__init__.py mode change 100755 => 100644 experiments/EXAMPLE_gpt.yml mode change 100755 => 100644 experiments/bpe_lowercase_asr_256.json mode change 100755 => 100644 experiments/train_diffusion_vocoder_22k_level.yml mode change 100755 => 100644 requirements.txt mode change 100755 => 100644 setup.py diff --git a/.gitignore b/.gitignore old mode 100755 new mode 100644 diff --git a/MANIFEST.in b/MANIFEST.in old mode 100755 new mode 100644 diff --git a/README.old.md b/README.old.md old mode 100755 new mode 100644 diff --git a/codes/__init__.py b/codes/__init__.py old mode 100755 new mode 100644 diff --git a/codes/trainer/inject.py b/codes/trainer/inject.py index a37f9351..b536d13e 100644 --- a/codes/trainer/inject.py +++ b/codes/trainer/inject.py @@ -3,6 +3,7 @@ import inspect import pkgutil import re import sys +import os import torch.nn @@ -33,7 +34,8 @@ def format_injector_name(name): # Works by loading all python modules in the injectors/ directory and sniffing out subclasses of Injector. # field will be properly populated. def find_registered_injectors(base_path="trainer/injectors"): - module_iter = pkgutil.walk_packages([base_path]) + path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}')) + module_iter = pkgutil.walk_packages([path]) results = {} for mod in module_iter: if mod.ispkg: diff --git a/codes/trainer/networks.py b/codes/trainer/networks.py index 723389ff..3108f26e 100644 --- a/codes/trainer/networks.py +++ b/codes/trainer/networks.py @@ -30,14 +30,12 @@ def register_model(func): func._dlas_registered_model = True return func - def find_registered_model_fns(base_path='models'): found_fns = {} - module_iter = pkgutil.walk_packages([base_path]) + path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}')) + + module_iter = pkgutil.walk_packages([path]) for mod in module_iter: - if os.name == 'nt': - if os.path.join(os.getcwd(), base_path) not in mod.module_finder.path: - continue # I have no idea why this is necessary - I think it's a bug in the latest PyWindows release. if mod.ispkg: EXCLUSION_LIST = ['flownet2'] if mod.name not in EXCLUSION_LIST: diff --git a/experiments/EXAMPLE_gpt.yml b/experiments/EXAMPLE_gpt.yml old mode 100755 new mode 100644 diff --git a/experiments/bpe_lowercase_asr_256.json b/experiments/bpe_lowercase_asr_256.json old mode 100755 new mode 100644 diff --git a/experiments/train_diffusion_vocoder_22k_level.yml b/experiments/train_diffusion_vocoder_22k_level.yml old mode 100755 new mode 100644 diff --git a/requirements.txt b/requirements.txt old mode 100755 new mode 100644 diff --git a/setup.py b/setup.py old mode 100755 new mode 100644 index cba4ff04..16c19702 --- a/setup.py +++ b/setup.py @@ -1,19 +1,13 @@ import setuptools +from pip.req import parse_requirements + with open("README.old.md", "r", encoding="utf-8") as fh: long_description = fh.read() -# kludge -packages = setuptools.find_packages() -for i in range(len(packages)): - packages[i] = packages[i].replace("codes", "dlas") - setuptools.setup( name="DL-Art-School", - packages=packages, - package_dir={ - "dlas": "./codes" - }, + packages=setuptools.find_packages(), version="0.0.1", author="James Betker", author_email="james@adamant.ai", @@ -24,7 +18,7 @@ setuptools.setup( project_urls={}, scripts=[], include_package_data=True, - install_requires=[], + install_requires=parse_requirements('requirements.txt', session='hack'), classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License",