sane download bar

This commit is contained in:
mrq 2024-10-27 09:02:48 -05:00
parent a05faf0dfa
commit d1e811e6ea
4 changed files with 9 additions and 20 deletions

2
.gitignore vendored
View File

@ -3,6 +3,6 @@ __pycache__
/training
/venv
/*.egg-info
/vall_e/version.py
/tortoise_tts/version.py
/.cache
/voices

View File

@ -97,7 +97,7 @@ optimizations:
embedding: False
optimizers: True
bitsandbytes: True
bitsandbytes: False
dadaptation: False
bitnet: False
fp8: False

View File

@ -53,6 +53,9 @@ setup(
# HF bloat
"tokenizers",
"transformers",
"inflect",
"unidecode",
"vector_quantize_pytorch",
#
"rotary_embedding_torch",

View File

@ -46,21 +46,7 @@ DEFAULT_MODEL_URLS = {
# kludge, probably better to use HF's model downloader function
# to-do: write to a temp file then copy so downloads can be interrupted
def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
scale = 1
if unit == "KiB":
scale = (1024)
elif unit == "MiB":
scale = (1024 * 1024)
elif unit == "MiB":
scale = (1024 * 1024 * 1024)
elif unit == "KB":
scale = (1000)
elif unit == "MB":
scale = (1000 * 1000)
elif unit == "MB":
scale = (1000 * 1000 * 1000)
def download_model( save_path, chunkSize = 1024 ):
name = save_path.name
url = DEFAULT_MODEL_URLS[name] if name in DEFAULT_MODEL_URLS else None
if url is None:
@ -70,15 +56,15 @@ def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
save_path.parent.mkdir(parents=True, exist_ok=True)
r = requests.get(url, stream=True)
content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length']) // scale
content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length'])
with open(save_path, 'wb') as f:
bar = tqdm( unit=unit, total=content_length )
bar = tqdm( unit='B', unit_scale=True, unit_divisor=1024, total=content_length, desc=f"Downloading: {name}" )
for chunk in r.iter_content(chunk_size=chunkSize):
if not chunk:
continue
bar.update( len(chunk) / scale )
bar.update( len(chunk) )
f.write(chunk)
bar.close()