revert sageattn back to optional dependency because it's not on windows, force resize_modules on by default because I broke something

This commit is contained in:
mrq 2024-12-07 17:09:39 -06:00
parent 218d0e29fd
commit 1f54bf5b40
2 changed files with 81 additions and 76 deletions

155
setup.py
View File

@ -5,94 +5,99 @@ from datetime import datetime
from setuptools import setup, find_packages from setuptools import setup, find_packages
def shell(*args): def shell(*args):
out = subprocess.check_output(args) out = subprocess.check_output(args)
return out.decode("ascii").strip() return out.decode("ascii").strip()
def write_version(version_core, pre_release=True): def write_version(version_core, pre_release=True):
if pre_release: if pre_release:
time = shell("git", "log", "-1", "--format=%cd", "--date=iso") time = shell("git", "log", "-1", "--format=%cd", "--date=iso")
time = datetime.strptime(time, "%Y-%m-%d %H:%M:%S %z") time = datetime.strptime(time, "%Y-%m-%d %H:%M:%S %z")
time = time.strftime("%Y%m%d%H%M%S") time = time.strftime("%Y%m%d%H%M%S")
version = f"{version_core}-dev{time}" version = f"{version_core}-dev{time}"
else: else:
version = version_core version = version_core
with open(Path("vall_e", "version.py"), "w") as f: with open(Path("vall_e", "version.py"), "w") as f:
f.write('__version__ = "{}"\n'.format(version)) f.write('__version__ = "{}"\n'.format(version))
return version return version
with open("README.md", "r") as f: with open("README.md", "r") as f:
long_description = f.read() long_description = f.read()
platform_dependencies = []
if sys.platform.startswith("win"):
platform_dependencies += ["psutil"]
else:
platform_dependencies += ["deepspeed>=0.7.7"]
setup( setup(
name="vall-e", name="vall-e",
python_requires=">=3.10.0", python_requires=">=3.10.0",
version=write_version("0.0.1"), version=write_version("0.0.1"),
description="An unofficial implementation of the audio LM VALL-E", description="An unofficial implementation of the audio LM VALL-E",
author="ecker", author="ecker",
author_email="mrq@ecker.tech", author_email="mrq@ecker.tech",
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
packages=find_packages(), packages=find_packages(),
install_requires=( install_requires=
# training backends platform_dependencies + [
["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else ["psutil"]) # logging niceties
+ [ "coloredlogs>=15.0.1", # barely required
# logging niceties "humanize>=4.4.0", # not really required
"coloredlogs>=15.0.1", # barely required "matplotlib>=3.6.0", # only required for plotting
"humanize>=4.4.0", # not really required "pandas>=1.5.0", # not really required
"matplotlib>=3.6.0", # only required for plotting
"pandas>=1.5.0", # not really required
# boiler plate niceties # boiler plate niceties
#"diskcache>=5.4.0", #"diskcache>=5.4.0",
"einops>=0.6.0", # could be replaced "einops>=0.6.0", # could be replaced
"tqdm", "tqdm",
# HF bloat # HF bloat
"tokenizers", "tokenizers",
"transformers", "transformers",
"safetensors", "safetensors",
# training bloat # training bloat
"auraloss[all]", # [all] is needed for MelSTFTLoss "auraloss[all]", # [all] is needed for MelSTFTLoss
"h5py", "h5py",
"prodigyopt @ git+https://github.com/konstmish/prodigy", "prodigyopt @ git+https://github.com/konstmish/prodigy",
# practically the reason to use python # practically the reason to use python
"numpy", "numpy",
"torch>=1.13.0", "torch>=1.13.0",
"torchaudio>=0.13.0", "torchaudio>=0.13.0",
"torchmetrics", "torchmetrics",
# core foundations # core foundations
"phonemizer>=2.1.0", "phonemizer>=2.1.0",
"encodec>=0.1.1", "encodec>=0.1.1",
"vocos", "vocos",
# for the web UI # for the web UI
"gradio", "gradio",
"nltk", "nltk",
"sageattention==1.0.6", ],
], extras_require = {
extras_require = { "all": [
"all": [ # retnet backend (even though two internal copies exist)
# retnet backend (even though two internal copies exist) "torchscale @ git+https://git.ecker.tech/mrq/torchscale",
"torchscale @ git+https://git.ecker.tech/mrq/torchscale", # bitnet
# bitnet "bitnet",
"bitnet", # mamba
# mamba "causal-conv1d",
"causal-conv1d", "mamba-ssm",
"mamba-ssm",
# attention helpers # attention helpers
"xformers", "xformers",
# "flash-attn" --no-build-isolation # commented out right now because I want to query this for Volta freaks like me who can't use it "sageattention==1.0.6",
# "flash-attn" --no-build-isolation # commented out right now because I want to query this for Volta freaks like me who can't use it
# other audio backend that doesn't prove fruitful
"descript-audio-codec", # other audio backend that doesn't prove fruitful
] "descript-audio-codec",
}, ]
url="https://git.ecker.tech/mrq/vall-e", },
url="https://git.ecker.tech/mrq/vall-e",
) )

View File

@ -698,7 +698,7 @@ class Trainer:
strict_loading: bool = False # sets strict_loading=True when loading the state dict strict_loading: bool = False # sets strict_loading=True when loading the state dict
load_module_only: bool = False # load_module_only: bool = False #
restart_step_count: bool = False # clears the training stats when loading a checkpoint restart_step_count: bool = False # clears the training stats when loading a checkpoint
resize_modules: bool = False # automatically resizes resize_modules: bool = True # automatically resizes
activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing
gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training