revert sageattn back to optional dependency because it's not on windows, force resize_modules on by default because I broke something
This commit is contained in:
parent
218d0e29fd
commit
1f54bf5b40
153
setup.py
153
setup.py
|
@ -5,94 +5,99 @@ from datetime import datetime
|
||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
def shell(*args):
|
def shell(*args):
|
||||||
out = subprocess.check_output(args)
|
out = subprocess.check_output(args)
|
||||||
return out.decode("ascii").strip()
|
return out.decode("ascii").strip()
|
||||||
|
|
||||||
def write_version(version_core, pre_release=True):
|
def write_version(version_core, pre_release=True):
|
||||||
if pre_release:
|
if pre_release:
|
||||||
time = shell("git", "log", "-1", "--format=%cd", "--date=iso")
|
time = shell("git", "log", "-1", "--format=%cd", "--date=iso")
|
||||||
time = datetime.strptime(time, "%Y-%m-%d %H:%M:%S %z")
|
time = datetime.strptime(time, "%Y-%m-%d %H:%M:%S %z")
|
||||||
time = time.strftime("%Y%m%d%H%M%S")
|
time = time.strftime("%Y%m%d%H%M%S")
|
||||||
version = f"{version_core}-dev{time}"
|
version = f"{version_core}-dev{time}"
|
||||||
else:
|
else:
|
||||||
version = version_core
|
version = version_core
|
||||||
|
|
||||||
with open(Path("vall_e", "version.py"), "w") as f:
|
with open(Path("vall_e", "version.py"), "w") as f:
|
||||||
f.write('__version__ = "{}"\n'.format(version))
|
f.write('__version__ = "{}"\n'.format(version))
|
||||||
|
|
||||||
return version
|
return version
|
||||||
|
|
||||||
with open("README.md", "r") as f:
|
with open("README.md", "r") as f:
|
||||||
long_description = f.read()
|
long_description = f.read()
|
||||||
|
|
||||||
|
platform_dependencies = []
|
||||||
|
|
||||||
|
if sys.platform.startswith("win"):
|
||||||
|
platform_dependencies += ["psutil"]
|
||||||
|
else:
|
||||||
|
platform_dependencies += ["deepspeed>=0.7.7"]
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="vall-e",
|
name="vall-e",
|
||||||
python_requires=">=3.10.0",
|
python_requires=">=3.10.0",
|
||||||
version=write_version("0.0.1"),
|
version=write_version("0.0.1"),
|
||||||
description="An unofficial implementation of the audio LM VALL-E",
|
description="An unofficial implementation of the audio LM VALL-E",
|
||||||
author="ecker",
|
author="ecker",
|
||||||
author_email="mrq@ecker.tech",
|
author_email="mrq@ecker.tech",
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
install_requires=(
|
install_requires=
|
||||||
# training backends
|
platform_dependencies + [
|
||||||
["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else ["psutil"])
|
# logging niceties
|
||||||
+ [
|
"coloredlogs>=15.0.1", # barely required
|
||||||
# logging niceties
|
"humanize>=4.4.0", # not really required
|
||||||
"coloredlogs>=15.0.1", # barely required
|
"matplotlib>=3.6.0", # only required for plotting
|
||||||
"humanize>=4.4.0", # not really required
|
"pandas>=1.5.0", # not really required
|
||||||
"matplotlib>=3.6.0", # only required for plotting
|
|
||||||
"pandas>=1.5.0", # not really required
|
|
||||||
|
|
||||||
# boiler plate niceties
|
# boiler plate niceties
|
||||||
#"diskcache>=5.4.0",
|
#"diskcache>=5.4.0",
|
||||||
"einops>=0.6.0", # could be replaced
|
"einops>=0.6.0", # could be replaced
|
||||||
"tqdm",
|
"tqdm",
|
||||||
|
|
||||||
# HF bloat
|
# HF bloat
|
||||||
"tokenizers",
|
"tokenizers",
|
||||||
"transformers",
|
"transformers",
|
||||||
"safetensors",
|
"safetensors",
|
||||||
|
|
||||||
# training bloat
|
# training bloat
|
||||||
"auraloss[all]", # [all] is needed for MelSTFTLoss
|
"auraloss[all]", # [all] is needed for MelSTFTLoss
|
||||||
"h5py",
|
"h5py",
|
||||||
"prodigyopt @ git+https://github.com/konstmish/prodigy",
|
"prodigyopt @ git+https://github.com/konstmish/prodigy",
|
||||||
|
|
||||||
# practically the reason to use python
|
# practically the reason to use python
|
||||||
"numpy",
|
"numpy",
|
||||||
"torch>=1.13.0",
|
"torch>=1.13.0",
|
||||||
"torchaudio>=0.13.0",
|
"torchaudio>=0.13.0",
|
||||||
"torchmetrics",
|
"torchmetrics",
|
||||||
|
|
||||||
# core foundations
|
# core foundations
|
||||||
"phonemizer>=2.1.0",
|
"phonemizer>=2.1.0",
|
||||||
"encodec>=0.1.1",
|
"encodec>=0.1.1",
|
||||||
"vocos",
|
"vocos",
|
||||||
|
|
||||||
# for the web UI
|
# for the web UI
|
||||||
"gradio",
|
"gradio",
|
||||||
"nltk",
|
"nltk",
|
||||||
"sageattention==1.0.6",
|
],
|
||||||
],
|
extras_require = {
|
||||||
extras_require = {
|
"all": [
|
||||||
"all": [
|
# retnet backend (even though two internal copies exist)
|
||||||
# retnet backend (even though two internal copies exist)
|
"torchscale @ git+https://git.ecker.tech/mrq/torchscale",
|
||||||
"torchscale @ git+https://git.ecker.tech/mrq/torchscale",
|
# bitnet
|
||||||
# bitnet
|
"bitnet",
|
||||||
"bitnet",
|
# mamba
|
||||||
# mamba
|
"causal-conv1d",
|
||||||
"causal-conv1d",
|
"mamba-ssm",
|
||||||
"mamba-ssm",
|
|
||||||
|
|
||||||
# attention helpers
|
# attention helpers
|
||||||
"xformers",
|
"xformers",
|
||||||
# "flash-attn" --no-build-isolation # commented out right now because I want to query this for Volta freaks like me who can't use it
|
"sageattention==1.0.6",
|
||||||
|
# "flash-attn" --no-build-isolation # commented out right now because I want to query this for Volta freaks like me who can't use it
|
||||||
|
|
||||||
# other audio backend that doesn't prove fruitful
|
# other audio backend that doesn't prove fruitful
|
||||||
"descript-audio-codec",
|
"descript-audio-codec",
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
url="https://git.ecker.tech/mrq/vall-e",
|
url="https://git.ecker.tech/mrq/vall-e",
|
||||||
)
|
)
|
||||||
|
|
|
@ -698,7 +698,7 @@ class Trainer:
|
||||||
strict_loading: bool = False # sets strict_loading=True when loading the state dict
|
strict_loading: bool = False # sets strict_loading=True when loading the state dict
|
||||||
load_module_only: bool = False #
|
load_module_only: bool = False #
|
||||||
restart_step_count: bool = False # clears the training stats when loading a checkpoint
|
restart_step_count: bool = False # clears the training stats when loading a checkpoint
|
||||||
resize_modules: bool = False # automatically resizes
|
resize_modules: bool = True # automatically resizes
|
||||||
|
|
||||||
activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing
|
activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing
|
||||||
gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
|
gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
|
||||||
|
|
Loading…
Reference in New Issue
Block a user