1
1
forked from mrq/tortoise-tts

Allow running on CPU

This commit is contained in:
Johan Nordberg 2022-06-11 20:03:14 +09:00
parent 5d96b486fb
commit 5c7a50820c
3 changed files with 35 additions and 31 deletions

View File

@ -101,7 +101,7 @@ def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusi
conditioning_free=cond_free, conditioning_free_k=cond_free_k)
def format_conditioning(clip, cond_length=132300):
def format_conditioning(clip, cond_length=132300, device='cuda'):
"""
Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
"""
@ -112,7 +112,7 @@ def format_conditioning(clip, cond_length=132300):
rand_start = random.randint(0, gap)
clip = clip[:, rand_start:rand_start + cond_length]
mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
return mel_clip.unsqueeze(0).cuda()
return mel_clip.unsqueeze(0).to(device)
def fix_autoregressive_output(codes, stop_token, complain=True):
@ -181,14 +181,15 @@ def pick_best_batch_size_for_gpu():
Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give
you a good shot.
"""
free, available = torch.cuda.mem_get_info()
availableGb = available / (1024 ** 3)
if availableGb > 14:
return 16
elif availableGb > 10:
return 8
elif availableGb > 7:
return 4
if torch.cuda.is_available():
_, available = torch.cuda.mem_get_info()
availableGb = available / (1024 ** 3)
if availableGb > 14:
return 16
elif availableGb > 10:
return 8
elif availableGb > 7:
return 4
return 1
@ -197,7 +198,7 @@ class TextToSpeech:
Main entry point into Tortoise.
"""
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True):
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None):
"""
Constructor
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@ -207,10 +208,12 @@ class TextToSpeech:
:param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output
(but are still rendered by the model). This can be used for prompt engineering.
Default is true.
:param device: Device to use when running the model. If omitted, the device will be automatically chosen.
"""
self.models_dir = models_dir
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
self.enable_redaction = enable_redaction
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if self.enable_redaction:
self.aligner = Wav2VecAlignment()
@ -240,7 +243,7 @@ class TextToSpeech:
self.cvvp = None # CVVP model is only loaded if used.
self.vocoder = UnivNetGenerator().cpu()
self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir))['model_g'])
self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])
self.vocoder.eval(inference=True)
# Random latent generators (RLGs) are loaded lazily.
@ -261,15 +264,15 @@ class TextToSpeech:
:param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
"""
with torch.no_grad():
voice_samples = [v.to('cuda') for v in voice_samples]
voice_samples = [v.to(self.device) for v in voice_samples]
auto_conds = []
if not isinstance(voice_samples, list):
voice_samples = [voice_samples]
for vs in voice_samples:
auto_conds.append(format_conditioning(vs))
auto_conds.append(format_conditioning(vs, self.device))
auto_conds = torch.stack(auto_conds, dim=1)
self.autoregressive = self.autoregressive.cuda()
self.autoregressive = self.autoregressive.to(self.device)
auto_latent = self.autoregressive.get_conditioning(auto_conds)
self.autoregressive = self.autoregressive.cpu()
@ -278,11 +281,11 @@ class TextToSpeech:
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
sample = torchaudio.functional.resample(sample, 22050, 24000)
sample = pad_or_truncate(sample, 102400)
cond_mel = wav_to_univnet_mel(sample.to('cuda'), do_normalization=False)
cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False, device=self.device)
diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1)
self.diffusion = self.diffusion.cuda()
self.diffusion = self.diffusion.to(self.device)
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
self.diffusion = self.diffusion.cpu()
@ -380,7 +383,7 @@ class TextToSpeech:
"""
deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
@ -391,8 +394,8 @@ class TextToSpeech:
auto_conditioning, diffusion_conditioning = conditioning_latents
else:
auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents()
auto_conditioning = auto_conditioning.cuda()
diffusion_conditioning = diffusion_conditioning.cuda()
auto_conditioning = auto_conditioning.to(self.device)
diffusion_conditioning = diffusion_conditioning.to(self.device)
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
@ -401,7 +404,7 @@ class TextToSpeech:
num_batches = num_autoregressive_samples // self.autoregressive_batch_size
stop_mel_token = self.autoregressive.stop_mel_token
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
self.autoregressive = self.autoregressive.cuda()
self.autoregressive = self.autoregressive.to(self.device)
if verbose:
print("Generating autoregressive samples..")
for b in tqdm(range(num_batches), disable=not verbose):
@ -420,11 +423,11 @@ class TextToSpeech:
self.autoregressive = self.autoregressive.cpu()
clip_results = []
self.clvp = self.clvp.cuda()
self.clvp = self.clvp.to(self.device)
if cvvp_amount > 0:
if self.cvvp is None:
self.load_cvvp()
self.cvvp = self.cvvp.cuda()
self.cvvp = self.cvvp.to(self.device)
if verbose:
if self.cvvp is None:
print("Computing best candidates using CLVP")
@ -457,7 +460,7 @@ class TextToSpeech:
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
# results, but will increase memory usage.
self.autoregressive = self.autoregressive.cuda()
self.autoregressive = self.autoregressive.to(self.device)
best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
@ -468,8 +471,8 @@ class TextToSpeech:
if verbose:
print("Transforming autoregressive outputs into audio..")
wav_candidates = []
self.diffusion = self.diffusion.cuda()
self.vocoder = self.vocoder.cuda()
self.diffusion = self.diffusion.to(self.device)
self.vocoder = self.vocoder.to(self.device)
for b in range(best_results.shape[0]):
codes = best_results[b].unsqueeze(0)
latents = best_latents[b].unsqueeze(0)

View File

@ -180,9 +180,9 @@ class TacotronSTFT(torch.nn.Module):
return mel_output
def wav_to_univnet_mel(wav, do_normalization=False):
def wav_to_univnet_mel(wav, do_normalization=False, device='cuda'):
stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
stft = stft.cuda()
stft = stft.to(device)
mel = stft.mel_spectrogram(wav)
if do_normalization:
mel = normalize_tacotron_mel(mel)

View File

@ -49,17 +49,18 @@ class Wav2VecAlignment:
"""
Uses wav2vec2 to perform audio<->text alignment.
"""
def __init__(self):
def __init__(self, device='cuda'):
self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
self.device = device
def align(self, audio, expected_text, audio_sample_rate=24000):
orig_len = audio.shape[-1]
with torch.no_grad():
self.model = self.model.cuda()
audio = audio.to('cuda')
self.model = self.model.to(self.device)
audio = audio.to(self.device)
audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
logits = self.model(clip_norm).logits