revert sageattn back to an optional dependency because it's not available on Windows; force resize_modules on by default because I broke something

mrq 2024-12-07 17:09:39 -06:00
parent 218d0e29fd
commit 1f54bf5b40
2 changed files with 81 additions and 76 deletions

View File

@@ -25,6 +25,13 @@ def write_version(version_core, pre_release=True):
 with open("README.md", "r") as f:
 	long_description = f.read()
 
+platform_dependencies = []
+
+if sys.platform.startswith("win"):
+	platform_dependencies += ["psutil"]
+else:
+	platform_dependencies += ["deepspeed>=0.7.7"]
+
 setup(
 	name="vall-e",
 	python_requires=">=3.10.0",
@@ -35,10 +42,8 @@ setup(
 	long_description=long_description,
 	long_description_content_type="text/markdown",
 	packages=find_packages(),
-	install_requires=(
-		# training backends
-		["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else ["psutil"])
-		+ [
+	install_requires=
+		platform_dependencies + [
 		# logging niceties
 		"coloredlogs>=15.0.1", # barely required
 		"humanize>=4.4.0", # not really required
@@ -74,7 +79,6 @@ setup(
 		# for the web UI
 		"gradio",
 		"nltk",
-		"sageattention==1.0.6",
 	],
 	extras_require = {
 		"all": [
@@ -88,6 +92,7 @@
 			# attention helpers
 			"xformers",
+			"sageattention==1.0.6",
 			# "flash-attn" --no-build-isolation # commented out right now because I want to query this for Volta freaks like me who can't use it
 			# other audio backend that doesn't prove fruitful

View File

@@ -698,7 +698,7 @@ class Trainer:
 	strict_loading: bool = False # sets strict_loading=True when loading the state dict
 	load_module_only: bool = False #
 	restart_step_count: bool = False # clears the training stats when loading a checkpoint
-	resize_modules: bool = False # automatically resizes
+	resize_modules: bool = True # automatically resizes
 	activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing
 	gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
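
For context, flipping resize_modules to True presumably means that weight shapes which disagree between the checkpoint and the model get reconciled automatically on load instead of failing under strict loading. A hypothetical illustration of that kind of resize, not this repo's implementation:

# Hypothetical sketch, not this repo's code: copy the overlapping region of a
# weight whose shape changed between checkpoint and model (e.g. an embedding
# that gained rows) rather than erroring on a strict shape mismatch.
import torch

def resize_weight(old: torch.Tensor, new_shape: torch.Size) -> torch.Tensor:
	resized = torch.zeros(new_shape, dtype=old.dtype, device=old.device)
	overlap = tuple(slice(0, min(o, n)) for o, n in zip(old.shape, new_shape))
	resized[overlap] = old[overlap]
	return resized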