sane download bar
This commit is contained in:
parent
a05faf0dfa
commit
d1e811e6ea
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -3,6 +3,6 @@ __pycache__
|
||||||
/training
|
/training
|
||||||
/venv
|
/venv
|
||||||
/*.egg-info
|
/*.egg-info
|
||||||
/vall_e/version.py
|
/tortoise_tts/version.py
|
||||||
/.cache
|
/.cache
|
||||||
/voices
|
/voices
|
||||||
|
|
|
@ -97,7 +97,7 @@ optimizations:
|
||||||
embedding: False
|
embedding: False
|
||||||
optimizers: True
|
optimizers: True
|
||||||
|
|
||||||
bitsandbytes: True
|
bitsandbytes: False
|
||||||
dadaptation: False
|
dadaptation: False
|
||||||
bitnet: False
|
bitnet: False
|
||||||
fp8: False
|
fp8: False
|
||||||
|
|
3
setup.py
3
setup.py
|
@ -53,6 +53,9 @@ setup(
|
||||||
# HF bloat
|
# HF bloat
|
||||||
"tokenizers",
|
"tokenizers",
|
||||||
"transformers",
|
"transformers",
|
||||||
|
"inflect",
|
||||||
|
"unidecode",
|
||||||
|
"vector_quantize_pytorch",
|
||||||
|
|
||||||
#
|
#
|
||||||
"rotary_embedding_torch",
|
"rotary_embedding_torch",
|
||||||
|
|
|
@ -46,21 +46,7 @@ DEFAULT_MODEL_URLS = {
|
||||||
|
|
||||||
# kludge, probably better to use HF's model downloader function
|
# kludge, probably better to use HF's model downloader function
|
||||||
# to-do: write to a temp file then copy so downloads can be interrupted
|
# to-do: write to a temp file then copy so downloads can be interrupted
|
||||||
def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
|
def download_model( save_path, chunkSize = 1024 ):
|
||||||
scale = 1
|
|
||||||
if unit == "KiB":
|
|
||||||
scale = (1024)
|
|
||||||
elif unit == "MiB":
|
|
||||||
scale = (1024 * 1024)
|
|
||||||
elif unit == "MiB":
|
|
||||||
scale = (1024 * 1024 * 1024)
|
|
||||||
elif unit == "KB":
|
|
||||||
scale = (1000)
|
|
||||||
elif unit == "MB":
|
|
||||||
scale = (1000 * 1000)
|
|
||||||
elif unit == "MB":
|
|
||||||
scale = (1000 * 1000 * 1000)
|
|
||||||
|
|
||||||
name = save_path.name
|
name = save_path.name
|
||||||
url = DEFAULT_MODEL_URLS[name] if name in DEFAULT_MODEL_URLS else None
|
url = DEFAULT_MODEL_URLS[name] if name in DEFAULT_MODEL_URLS else None
|
||||||
if url is None:
|
if url is None:
|
||||||
|
@ -70,15 +56,15 @@ def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
|
||||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
r = requests.get(url, stream=True)
|
r = requests.get(url, stream=True)
|
||||||
content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length']) // scale
|
content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length'])
|
||||||
|
|
||||||
with open(save_path, 'wb') as f:
|
with open(save_path, 'wb') as f:
|
||||||
bar = tqdm( unit=unit, total=content_length )
|
bar = tqdm( unit='B', unit_scale=True, unit_divisor=1024, total=content_length, desc=f"Downloading: {name}" )
|
||||||
for chunk in r.iter_content(chunk_size=chunkSize):
|
for chunk in r.iter_content(chunk_size=chunkSize):
|
||||||
if not chunk:
|
if not chunk:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
bar.update( len(chunk) / scale )
|
bar.update( len(chunk) )
|
||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
bar.close()
|
bar.close()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user