forked from mrq/tortoise-tts
Automatically pick batch size based on available GPU memory
This commit is contained in:
parent
cb7adf16af
commit
50690e4465
|
@ -160,12 +160,28 @@ def classify_audio_clip(clip):
|
|||
return results[0][0]
|
||||
|
||||
|
||||
def pick_best_batch_size_for_gpu():
|
||||
"""
|
||||
Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give
|
||||
you a good shot.
|
||||
"""
|
||||
free, available = torch.cuda.mem_get_info()
|
||||
availableGb = available / (1024 ** 3)
|
||||
if availableGb > 14:
|
||||
return 16
|
||||
elif availableGb > 10:
|
||||
return 8
|
||||
elif availableGb > 7:
|
||||
return 4
|
||||
return 1
|
||||
|
||||
|
||||
class TextToSpeech:
|
||||
"""
|
||||
Main entry point into Tortoise.
|
||||
"""
|
||||
|
||||
def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True):
|
||||
def __init__(self, autoregressive_batch_size=None, models_dir='.models', enable_redaction=True):
|
||||
"""
|
||||
Constructor
|
||||
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
||||
|
@ -176,7 +192,7 @@ class TextToSpeech:
|
|||
(but are still rendered by the model). This can be used for prompt engineering.
|
||||
Default is true.
|
||||
"""
|
||||
self.autoregressive_batch_size = autoregressive_batch_size
|
||||
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
|
||||
self.enable_redaction = enable_redaction
|
||||
if self.enable_redaction:
|
||||
self.aligner = Wav2VecAlignment()
|
||||
|
|
|
@ -148,6 +148,7 @@ def english_cleaners(text):
|
|||
text = text.replace('"', '')
|
||||
return text
|
||||
|
||||
|
||||
def lev_distance(s1, s2):
|
||||
if len(s1) > len(s2):
|
||||
s1, s2 = s2, s1
|
||||
|
@ -163,6 +164,7 @@ def lev_distance(s1, s2):
|
|||
distances = distances_
|
||||
return distances[-1]
|
||||
|
||||
|
||||
class VoiceBpeTokenizer:
|
||||
def __init__(self, vocab_file='tortoise/data/tokenizer.json'):
|
||||
if vocab_file is not None:
|
||||
|
|
Loading…
Reference in New Issue
Block a user