forked from mrq/tortoise-tts
disable BNB for inferencing by default because I'm pretty sure it makes zero difference (can be force-enabled with env vars if you're relying on this for some reason)
This commit is contained in:
parent
f025470d60
commit
2f7d9ab932
|
@ -82,7 +82,9 @@ def get_device_vram( name=get_device_name() ):
|
|||
return available / (1024 ** 3)
|
||||
|
||||
def get_device_batch_size(name=None):
|
||||
name = get_device_name()
|
||||
if not name:
|
||||
name = get_device_name()
|
||||
|
||||
vram = get_device_vram(name)
|
||||
|
||||
if vram > 14:
|
||||
|
|
|
@ -22,17 +22,19 @@ import os
|
|||
|
||||
USE_STABLE_EMBEDDING = False
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
OVERRIDE_LINEAR = False
|
||||
OVERRIDE_EMBEDDING = True
|
||||
OVERRIDE_ADAM = True
|
||||
OVERRIDE_ADAMW = True
|
||||
OVERRIDE_EMBEDDING = False
|
||||
OVERRIDE_ADAM = False
|
||||
OVERRIDE_ADAMW = False
|
||||
|
||||
USE_STABLE_EMBEDDING = os.environ.get('BITSANDBYTES_USE_STABLE_EMBEDDING', '1' if USE_STABLE_EMBEDDING else '0') == '1'
|
||||
OVERRIDE_LINEAR = os.environ.get('BITSANDBYTES_OVERRIDE_LINEAR', '1' if OVERRIDE_LINEAR else '0') == '1'
|
||||
OVERRIDE_EMBEDDING = os.environ.get('BITSANDBYTES_OVERRIDE_EMBEDDING', '1' if OVERRIDE_EMBEDDING else '0') == '1'
|
||||
OVERRIDE_ADAM = os.environ.get('BITSANDBYTES_OVERRIDE_ADAM', '1' if OVERRIDE_ADAM else '0') == '1'
|
||||
OVERRIDE_ADAMW = os.environ.get('BITSANDBYTES_OVERRIDE_ADAMW', '1' if OVERRIDE_ADAMW else '0') == '1'
|
||||
|
||||
if OVERRIDE_LINEAR or OVERRIDE_EMBEDDING or OVERRIDE_ADAM or OVERRIDE_ADAMW:
|
||||
import bitsandbytes as bnb
|
||||
except Exception as e:
|
||||
OVERRIDE_LINEAR = False
|
||||
OVERRIDE_EMBEDDING = False
|
||||
|
|
Loading…
Reference in New Issue
Block a user