forked from mrq/tortoise-tts
Automatically pick batch size based on available GPU memory
This commit is contained in:
parent
cb7adf16af
commit
50690e4465
|
@ -160,12 +160,28 @@ def classify_audio_clip(clip):
|
||||||
return results[0][0]
|
return results[0][0]
|
||||||
|
|
||||||
|
|
||||||
|
def pick_best_batch_size_for_gpu():
    """
    Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give
    you a good shot.

    The decision is based on the second value reported by torch.cuda.mem_get_info(), i.e. the total memory of the
    current CUDA device (not the amount that is currently free).

    :return: 16 if the device has more than 14 GiB, 8 if more than 10 GiB, 4 if more than 7 GiB, otherwise 1.
             Also returns 1 when no CUDA device is available.
    """
    # mem_get_info() raises when CUDA cannot be initialised (CPU-only build, no driver, no device),
    # so fall back to the most conservative batch size instead of crashing the caller.
    if not torch.cuda.is_available():
        return 1

    # mem_get_info() returns (free_bytes, total_bytes); the thresholds below are applied to the total.
    _, total_bytes = torch.cuda.mem_get_info()
    total_gb = total_bytes / (1024 ** 3)
    if total_gb > 14:
        return 16
    if total_gb > 10:
        return 8
    if total_gb > 7:
        return 4
    return 1
|
||||||
|
|
||||||
|
|
||||||
class TextToSpeech:
|
class TextToSpeech:
|
||||||
"""
|
"""
|
||||||
Main entry point into Tortoise.
|
Main entry point into Tortoise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True):
|
def __init__(self, autoregressive_batch_size=None, models_dir='.models', enable_redaction=True):
|
||||||
"""
|
"""
|
||||||
Constructor
|
Constructor
|
||||||
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
||||||
|
@ -176,7 +192,7 @@ class TextToSpeech:
|
||||||
(but are still rendered by the model). This can be used for prompt engineering.
|
(but are still rendered by the model). This can be used for prompt engineering.
|
||||||
Default is true.
|
Default is true.
|
||||||
"""
|
"""
|
||||||
self.autoregressive_batch_size = autoregressive_batch_size
|
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
|
||||||
self.enable_redaction = enable_redaction
|
self.enable_redaction = enable_redaction
|
||||||
if self.enable_redaction:
|
if self.enable_redaction:
|
||||||
self.aligner = Wav2VecAlignment()
|
self.aligner = Wav2VecAlignment()
|
||||||
|
|
|
@ -148,6 +148,7 @@ def english_cleaners(text):
|
||||||
text = text.replace('"', '')
|
text = text.replace('"', '')
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def lev_distance(s1, s2):
|
def lev_distance(s1, s2):
|
||||||
if len(s1) > len(s2):
|
if len(s1) > len(s2):
|
||||||
s1, s2 = s2, s1
|
s1, s2 = s2, s1
|
||||||
|
@ -163,6 +164,7 @@ def lev_distance(s1, s2):
|
||||||
distances = distances_
|
distances = distances_
|
||||||
return distances[-1]
|
return distances[-1]
|
||||||
|
|
||||||
|
|
||||||
class VoiceBpeTokenizer:
|
class VoiceBpeTokenizer:
|
||||||
def __init__(self, vocab_file='tortoise/data/tokenizer.json'):
|
def __init__(self, vocab_file='tortoise/data/tokenizer.json'):
|
||||||
if vocab_file is not None:
|
if vocab_file is not None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user