forked from mrq/tortoise-tts
Allow running on CPU
This commit is contained in:
parent
5d96b486fb
commit
5c7a50820c
|
@ -101,7 +101,7 @@ def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusi
|
||||||
conditioning_free=cond_free, conditioning_free_k=cond_free_k)
|
conditioning_free=cond_free, conditioning_free_k=cond_free_k)
|
||||||
|
|
||||||
|
|
||||||
def format_conditioning(clip, cond_length=132300):
|
def format_conditioning(clip, cond_length=132300, device='cuda'):
|
||||||
"""
|
"""
|
||||||
Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
|
Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
|
||||||
"""
|
"""
|
||||||
|
@ -112,7 +112,7 @@ def format_conditioning(clip, cond_length=132300):
|
||||||
rand_start = random.randint(0, gap)
|
rand_start = random.randint(0, gap)
|
||||||
clip = clip[:, rand_start:rand_start + cond_length]
|
clip = clip[:, rand_start:rand_start + cond_length]
|
||||||
mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
|
mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
|
||||||
return mel_clip.unsqueeze(0).cuda()
|
return mel_clip.unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
|
||||||
def fix_autoregressive_output(codes, stop_token, complain=True):
|
def fix_autoregressive_output(codes, stop_token, complain=True):
|
||||||
|
@ -181,14 +181,15 @@ def pick_best_batch_size_for_gpu():
|
||||||
Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give
|
Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give
|
||||||
you a good shot.
|
you a good shot.
|
||||||
"""
|
"""
|
||||||
free, available = torch.cuda.mem_get_info()
|
if torch.cuda.is_available():
|
||||||
availableGb = available / (1024 ** 3)
|
_, available = torch.cuda.mem_get_info()
|
||||||
if availableGb > 14:
|
availableGb = available / (1024 ** 3)
|
||||||
return 16
|
if availableGb > 14:
|
||||||
elif availableGb > 10:
|
return 16
|
||||||
return 8
|
elif availableGb > 10:
|
||||||
elif availableGb > 7:
|
return 8
|
||||||
return 4
|
elif availableGb > 7:
|
||||||
|
return 4
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
@ -197,7 +198,7 @@ class TextToSpeech:
|
||||||
Main entry point into Tortoise.
|
Main entry point into Tortoise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True):
|
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None):
|
||||||
"""
|
"""
|
||||||
Constructor
|
Constructor
|
||||||
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
||||||
|
@ -207,10 +208,12 @@ class TextToSpeech:
|
||||||
:param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output
|
:param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output
|
||||||
(but are still rendered by the model). This can be used for prompt engineering.
|
(but are still rendered by the model). This can be used for prompt engineering.
|
||||||
Default is true.
|
Default is true.
|
||||||
|
:param device: Device to use when running the model. If omitted, the device will be automatically chosen.
|
||||||
"""
|
"""
|
||||||
self.models_dir = models_dir
|
self.models_dir = models_dir
|
||||||
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
|
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
|
||||||
self.enable_redaction = enable_redaction
|
self.enable_redaction = enable_redaction
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
if self.enable_redaction:
|
if self.enable_redaction:
|
||||||
self.aligner = Wav2VecAlignment()
|
self.aligner = Wav2VecAlignment()
|
||||||
|
|
||||||
|
@ -240,7 +243,7 @@ class TextToSpeech:
|
||||||
self.cvvp = None # CVVP model is only loaded if used.
|
self.cvvp = None # CVVP model is only loaded if used.
|
||||||
|
|
||||||
self.vocoder = UnivNetGenerator().cpu()
|
self.vocoder = UnivNetGenerator().cpu()
|
||||||
self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir))['model_g'])
|
self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])
|
||||||
self.vocoder.eval(inference=True)
|
self.vocoder.eval(inference=True)
|
||||||
|
|
||||||
# Random latent generators (RLGs) are loaded lazily.
|
# Random latent generators (RLGs) are loaded lazily.
|
||||||
|
@ -261,15 +264,15 @@ class TextToSpeech:
|
||||||
:param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
|
:param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
|
||||||
"""
|
"""
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
voice_samples = [v.to('cuda') for v in voice_samples]
|
voice_samples = [v.to(self.device) for v in voice_samples]
|
||||||
|
|
||||||
auto_conds = []
|
auto_conds = []
|
||||||
if not isinstance(voice_samples, list):
|
if not isinstance(voice_samples, list):
|
||||||
voice_samples = [voice_samples]
|
voice_samples = [voice_samples]
|
||||||
for vs in voice_samples:
|
for vs in voice_samples:
|
||||||
auto_conds.append(format_conditioning(vs))
|
auto_conds.append(format_conditioning(vs, self.device))
|
||||||
auto_conds = torch.stack(auto_conds, dim=1)
|
auto_conds = torch.stack(auto_conds, dim=1)
|
||||||
self.autoregressive = self.autoregressive.cuda()
|
self.autoregressive = self.autoregressive.to(self.device)
|
||||||
auto_latent = self.autoregressive.get_conditioning(auto_conds)
|
auto_latent = self.autoregressive.get_conditioning(auto_conds)
|
||||||
self.autoregressive = self.autoregressive.cpu()
|
self.autoregressive = self.autoregressive.cpu()
|
||||||
|
|
||||||
|
@ -278,11 +281,11 @@ class TextToSpeech:
|
||||||
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
|
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
|
||||||
sample = torchaudio.functional.resample(sample, 22050, 24000)
|
sample = torchaudio.functional.resample(sample, 22050, 24000)
|
||||||
sample = pad_or_truncate(sample, 102400)
|
sample = pad_or_truncate(sample, 102400)
|
||||||
cond_mel = wav_to_univnet_mel(sample.to('cuda'), do_normalization=False)
|
cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False, device=self.device)
|
||||||
diffusion_conds.append(cond_mel)
|
diffusion_conds.append(cond_mel)
|
||||||
diffusion_conds = torch.stack(diffusion_conds, dim=1)
|
diffusion_conds = torch.stack(diffusion_conds, dim=1)
|
||||||
|
|
||||||
self.diffusion = self.diffusion.cuda()
|
self.diffusion = self.diffusion.to(self.device)
|
||||||
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
|
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
|
||||||
self.diffusion = self.diffusion.cpu()
|
self.diffusion = self.diffusion.cpu()
|
||||||
|
|
||||||
|
@ -380,7 +383,7 @@ class TextToSpeech:
|
||||||
"""
|
"""
|
||||||
deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
|
deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
|
||||||
|
|
||||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
|
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
|
||||||
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
|
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
|
||||||
assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
|
assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
|
||||||
|
|
||||||
|
@ -391,8 +394,8 @@ class TextToSpeech:
|
||||||
auto_conditioning, diffusion_conditioning = conditioning_latents
|
auto_conditioning, diffusion_conditioning = conditioning_latents
|
||||||
else:
|
else:
|
||||||
auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents()
|
auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents()
|
||||||
auto_conditioning = auto_conditioning.cuda()
|
auto_conditioning = auto_conditioning.to(self.device)
|
||||||
diffusion_conditioning = diffusion_conditioning.cuda()
|
diffusion_conditioning = diffusion_conditioning.to(self.device)
|
||||||
|
|
||||||
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
|
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
|
||||||
|
|
||||||
|
@ -401,7 +404,7 @@ class TextToSpeech:
|
||||||
num_batches = num_autoregressive_samples // self.autoregressive_batch_size
|
num_batches = num_autoregressive_samples // self.autoregressive_batch_size
|
||||||
stop_mel_token = self.autoregressive.stop_mel_token
|
stop_mel_token = self.autoregressive.stop_mel_token
|
||||||
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
|
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
|
||||||
self.autoregressive = self.autoregressive.cuda()
|
self.autoregressive = self.autoregressive.to(self.device)
|
||||||
if verbose:
|
if verbose:
|
||||||
print("Generating autoregressive samples..")
|
print("Generating autoregressive samples..")
|
||||||
for b in tqdm(range(num_batches), disable=not verbose):
|
for b in tqdm(range(num_batches), disable=not verbose):
|
||||||
|
@ -420,11 +423,11 @@ class TextToSpeech:
|
||||||
self.autoregressive = self.autoregressive.cpu()
|
self.autoregressive = self.autoregressive.cpu()
|
||||||
|
|
||||||
clip_results = []
|
clip_results = []
|
||||||
self.clvp = self.clvp.cuda()
|
self.clvp = self.clvp.to(self.device)
|
||||||
if cvvp_amount > 0:
|
if cvvp_amount > 0:
|
||||||
if self.cvvp is None:
|
if self.cvvp is None:
|
||||||
self.load_cvvp()
|
self.load_cvvp()
|
||||||
self.cvvp = self.cvvp.cuda()
|
self.cvvp = self.cvvp.to(self.device)
|
||||||
if verbose:
|
if verbose:
|
||||||
if self.cvvp is None:
|
if self.cvvp is None:
|
||||||
print("Computing best candidates using CLVP")
|
print("Computing best candidates using CLVP")
|
||||||
|
@ -457,7 +460,7 @@ class TextToSpeech:
|
||||||
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
|
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
|
||||||
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
|
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
|
||||||
# results, but will increase memory usage.
|
# results, but will increase memory usage.
|
||||||
self.autoregressive = self.autoregressive.cuda()
|
self.autoregressive = self.autoregressive.to(self.device)
|
||||||
best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
|
best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
|
||||||
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
|
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
|
||||||
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
|
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
|
||||||
|
@ -468,8 +471,8 @@ class TextToSpeech:
|
||||||
if verbose:
|
if verbose:
|
||||||
print("Transforming autoregressive outputs into audio..")
|
print("Transforming autoregressive outputs into audio..")
|
||||||
wav_candidates = []
|
wav_candidates = []
|
||||||
self.diffusion = self.diffusion.cuda()
|
self.diffusion = self.diffusion.to(self.device)
|
||||||
self.vocoder = self.vocoder.cuda()
|
self.vocoder = self.vocoder.to(self.device)
|
||||||
for b in range(best_results.shape[0]):
|
for b in range(best_results.shape[0]):
|
||||||
codes = best_results[b].unsqueeze(0)
|
codes = best_results[b].unsqueeze(0)
|
||||||
latents = best_latents[b].unsqueeze(0)
|
latents = best_latents[b].unsqueeze(0)
|
||||||
|
|
|
@ -180,9 +180,9 @@ class TacotronSTFT(torch.nn.Module):
|
||||||
return mel_output
|
return mel_output
|
||||||
|
|
||||||
|
|
||||||
def wav_to_univnet_mel(wav, do_normalization=False):
|
def wav_to_univnet_mel(wav, do_normalization=False, device='cuda'):
|
||||||
stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
|
stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
|
||||||
stft = stft.cuda()
|
stft = stft.to(device)
|
||||||
mel = stft.mel_spectrogram(wav)
|
mel = stft.mel_spectrogram(wav)
|
||||||
if do_normalization:
|
if do_normalization:
|
||||||
mel = normalize_tacotron_mel(mel)
|
mel = normalize_tacotron_mel(mel)
|
||||||
|
|
|
@ -49,17 +49,18 @@ class Wav2VecAlignment:
|
||||||
"""
|
"""
|
||||||
Uses wav2vec2 to perform audio<->text alignment.
|
Uses wav2vec2 to perform audio<->text alignment.
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self, device='cuda'):
|
||||||
self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
|
self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
|
||||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
|
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
|
||||||
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
|
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def align(self, audio, expected_text, audio_sample_rate=24000):
|
def align(self, audio, expected_text, audio_sample_rate=24000):
|
||||||
orig_len = audio.shape[-1]
|
orig_len = audio.shape[-1]
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.model = self.model.cuda()
|
self.model = self.model.to(self.device)
|
||||||
audio = audio.to('cuda')
|
audio = audio.to(self.device)
|
||||||
audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
|
audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
|
||||||
clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
|
clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
|
||||||
logits = self.model(clip_norm).logits
|
logits = self.model(clip_norm).logits
|
||||||
|
|
Loading…
Reference in New Issue
Block a user