Necessary fixes to get it to work

This commit is contained in:
mrq 2023-02-17 02:03:00 +00:00
parent 49e23b226b
commit 94d0f16608
11 changed files with 10 additions and 16 deletions

0
.gitignore vendored Executable file → Normal file
View File

0
MANIFEST.in Executable file → Normal file
View File

0
README.old.md Executable file → Normal file
View File

0
codes/__init__.py Executable file → Normal file
View File

View File

@ -3,6 +3,7 @@ import inspect
import pkgutil import pkgutil
import re import re
import sys import sys
import os
import torch.nn import torch.nn
@ -33,7 +34,8 @@ def format_injector_name(name):
# Works by loading all python modules in the injectors/ directory and sniffing out subclasses of Injector. # Works by loading all python modules in the injectors/ directory and sniffing out subclasses of Injector.
# field will be properly populated. # field will be properly populated.
def find_registered_injectors(base_path="trainer/injectors"): def find_registered_injectors(base_path="trainer/injectors"):
module_iter = pkgutil.walk_packages([base_path]) path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}'))
module_iter = pkgutil.walk_packages([path])
results = {} results = {}
for mod in module_iter: for mod in module_iter:
if mod.ispkg: if mod.ispkg:

View File

@ -30,14 +30,12 @@ def register_model(func):
func._dlas_registered_model = True func._dlas_registered_model = True
return func return func
def find_registered_model_fns(base_path='models'): def find_registered_model_fns(base_path='models'):
found_fns = {} found_fns = {}
module_iter = pkgutil.walk_packages([base_path]) path = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), f'../{base_path}'))
module_iter = pkgutil.walk_packages([path])
for mod in module_iter: for mod in module_iter:
if os.name == 'nt':
if os.path.join(os.getcwd(), base_path) not in mod.module_finder.path:
continue # I have no idea why this is necessary - I think it's a bug in the latest PyWindows release.
if mod.ispkg: if mod.ispkg:
EXCLUSION_LIST = ['flownet2'] EXCLUSION_LIST = ['flownet2']
if mod.name not in EXCLUSION_LIST: if mod.name not in EXCLUSION_LIST:

0
experiments/EXAMPLE_gpt.yml Executable file → Normal file
View File

0
experiments/bpe_lowercase_asr_256.json Executable file → Normal file
View File

0
experiments/train_diffusion_vocoder_22k_level.yml Executable file → Normal file
View File

0
requirements.txt Executable file → Normal file
View File

14
setup.py Executable file → Normal file
View File

@ -1,19 +1,13 @@
import setuptools import setuptools
from pip.req import parse_requirements
with open("README.old.md", "r", encoding="utf-8") as fh: with open("README.old.md", "r", encoding="utf-8") as fh:
long_description = fh.read() long_description = fh.read()
# kludge
packages = setuptools.find_packages()
for i in range(len(packages)):
packages[i] = packages[i].replace("codes", "dlas")
setuptools.setup( setuptools.setup(
name="DL-Art-School", name="DL-Art-School",
packages=packages, packages=setuptools.find_packages(),
package_dir={
"dlas": "./codes"
},
version="0.0.1", version="0.0.1",
author="James Betker", author="James Betker",
author_email="james@adamant.ai", author_email="james@adamant.ai",
@ -24,7 +18,7 @@ setuptools.setup(
project_urls={}, project_urls={},
scripts=[], scripts=[],
include_package_data=True, include_package_data=True,
install_requires=[], install_requires=parse_requirements('requirements.txt', session='hack'),
classifiers=[ classifiers=[
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License", "License :: OSI Approved :: Apache Software License",