From 1f54bf5b40a540c894ad2c3e06f957de14a72170 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 7 Dec 2024 17:09:39 -0600 Subject: [PATCH] revert sageattn back to optional dependency because it's not on windows, force resize_modules on by default because I broke something --- setup.py | 155 ++++++++++++++++++++++++----------------------- vall_e/config.py | 2 +- 2 files changed, 81 insertions(+), 76 deletions(-) diff --git a/setup.py b/setup.py index 4d255ce..830af79 100755 --- a/setup.py +++ b/setup.py @@ -5,94 +5,99 @@ from datetime import datetime from setuptools import setup, find_packages def shell(*args): - out = subprocess.check_output(args) - return out.decode("ascii").strip() + out = subprocess.check_output(args) + return out.decode("ascii").strip() def write_version(version_core, pre_release=True): - if pre_release: - time = shell("git", "log", "-1", "--format=%cd", "--date=iso") - time = datetime.strptime(time, "%Y-%m-%d %H:%M:%S %z") - time = time.strftime("%Y%m%d%H%M%S") - version = f"{version_core}-dev{time}" - else: - version = version_core + if pre_release: + time = shell("git", "log", "-1", "--format=%cd", "--date=iso") + time = datetime.strptime(time, "%Y-%m-%d %H:%M:%S %z") + time = time.strftime("%Y%m%d%H%M%S") + version = f"{version_core}-dev{time}" + else: + version = version_core - with open(Path("vall_e", "version.py"), "w") as f: - f.write('__version__ = "{}"\n'.format(version)) + with open(Path("vall_e", "version.py"), "w") as f: + f.write('__version__ = "{}"\n'.format(version)) - return version + return version with open("README.md", "r") as f: - long_description = f.read() + long_description = f.read() + +platform_dependencies = [] + +if sys.platform.startswith("win"): + platform_dependencies += ["psutil"] +else: + platform_dependencies += ["deepspeed>=0.7.7"] setup( - name="vall-e", - python_requires=">=3.10.0", - version=write_version("0.0.1"), - description="An unofficial implementation of the audio LM VALL-E", - author="ecker", - author_email="mrq@ecker.tech", - long_description=long_description, - long_description_content_type="text/markdown", - packages=find_packages(), - install_requires=( - # training backends - ["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else ["psutil"]) - + [ - # logging niceties - "coloredlogs>=15.0.1", # barely required - "humanize>=4.4.0", # not really required - "matplotlib>=3.6.0", # only required for plotting - "pandas>=1.5.0", # not really required + name="vall-e", + python_requires=">=3.10.0", + version=write_version("0.0.1"), + description="An unofficial implementation of the audio LM VALL-E", + author="ecker", + author_email="mrq@ecker.tech", + long_description=long_description, + long_description_content_type="text/markdown", + packages=find_packages(), + install_requires= + platform_dependencies + [ + # logging niceties + "coloredlogs>=15.0.1", # barely required + "humanize>=4.4.0", # not really required + "matplotlib>=3.6.0", # only required for plotting + "pandas>=1.5.0", # not really required - # boiler plate niceties - #"diskcache>=5.4.0", - "einops>=0.6.0", # could be replaced - "tqdm", + # boiler plate niceties + #"diskcache>=5.4.0", + "einops>=0.6.0", # could be replaced + "tqdm", - # HF bloat - "tokenizers", - "transformers", - "safetensors", + # HF bloat + "tokenizers", + "transformers", + "safetensors", - # training bloat - "auraloss[all]", # [all] is needed for MelSTFTLoss - "h5py", - "prodigyopt @ git+https://github.com/konstmish/prodigy", + # training bloat + "auraloss[all]", # [all] is needed for MelSTFTLoss + "h5py", + "prodigyopt @ git+https://github.com/konstmish/prodigy", - # practically the reason to use python - "numpy", - "torch>=1.13.0", - "torchaudio>=0.13.0", - "torchmetrics", + # practically the reason to use python + "numpy", + "torch>=1.13.0", + "torchaudio>=0.13.0", + "torchmetrics", - # core foundations - "phonemizer>=2.1.0", - "encodec>=0.1.1", - "vocos", + # core foundations + "phonemizer>=2.1.0", + "encodec>=0.1.1", + "vocos", - # for the web UI - "gradio", - "nltk", - "sageattention==1.0.6", - ], - extras_require = { - "all": [ - # retnet backend (even though two internal copies exist) - "torchscale @ git+https://git.ecker.tech/mrq/torchscale", - # bitnet - "bitnet", - # mamba - "causal-conv1d", - "mamba-ssm", + # for the web UI + "gradio", + "nltk", + ], + extras_require = { + "all": [ + # retnet backend (even though two internal copies exist) + "torchscale @ git+https://git.ecker.tech/mrq/torchscale", + # bitnet + "bitnet", + # mamba + "causal-conv1d", + "mamba-ssm", - # attention helpers - "xformers", - # "flash-attn" --no-build-isolation # commented out right now because I want to query this for Volta freaks like me who can't use it - - # other audio backend that doesn't prove fruitful - "descript-audio-codec", - ] - }, - url="https://git.ecker.tech/mrq/vall-e", + # attention helpers + "xformers", + "sageattention==1.0.6", + # "flash-attn" --no-build-isolation # commented out right now because I want to query this for Volta freaks like me who can't use it + + # other audio backend that doesn't prove fruitful + "descript-audio-codec", + ] + }, + url="https://git.ecker.tech/mrq/vall-e", ) diff --git a/vall_e/config.py b/vall_e/config.py index b79d086..6fd6547 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -698,7 +698,7 @@ class Trainer: strict_loading: bool = False # sets strict_loading=True when loading the state dict load_module_only: bool = False # restart_step_count: bool = False # clears the training stats when loading a checkpoint - resize_modules: bool = False # automatically resizes + resize_modules: bool = True # automatically resizes activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training