Automatically pick batch size based on available GPU memory

This commit is contained in:
James Betker 2022-05-13 10:30:02 -06:00
parent 556172281d
commit 0570034eda
2 changed files with 20 additions and 2 deletions

View File

@ -160,12 +160,28 @@ def classify_audio_clip(clip):
return results[0][0] return results[0][0]
def pick_best_batch_size_for_gpu():
"""
Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give
you a good shot.
"""
free, available = torch.cuda.mem_get_info()
availableGb = available / (1024 ** 3)
if availableGb > 14:
return 16
elif availableGb > 10:
return 8
elif availableGb > 7:
return 4
return 1
class TextToSpeech: class TextToSpeech:
""" """
Main entry point into Tortoise. Main entry point into Tortoise.
""" """
def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True): def __init__(self, autoregressive_batch_size=None, models_dir='.models', enable_redaction=True):
""" """
Constructor Constructor
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@ -176,7 +192,7 @@ class TextToSpeech:
(but are still rendered by the model). This can be used for prompt engineering. (but are still rendered by the model). This can be used for prompt engineering.
Default is true. Default is true.
""" """
self.autoregressive_batch_size = autoregressive_batch_size self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
self.enable_redaction = enable_redaction self.enable_redaction = enable_redaction
if self.enable_redaction: if self.enable_redaction:
self.aligner = Wav2VecAlignment() self.aligner = Wav2VecAlignment()

View File

@ -148,6 +148,7 @@ def english_cleaners(text):
text = text.replace('"', '') text = text.replace('"', '')
return text return text
def lev_distance(s1, s2): def lev_distance(s1, s2):
if len(s1) > len(s2): if len(s1) > len(s2):
s1, s2 = s2, s1 s1, s2 = s2, s1
@ -163,6 +164,7 @@ def lev_distance(s1, s2):
distances = distances_ distances = distances_
return distances[-1] return distances[-1]
class VoiceBpeTokenizer: class VoiceBpeTokenizer:
def __init__(self, vocab_file='tortoise/data/tokenizer.json'): def __init__(self, vocab_file='tortoise/data/tokenizer.json'):
if vocab_file is not None: if vocab_file is not None: