Merge pull request #97 from jnordberg/cpu-support

CPU support
This commit is contained in:
James Betker 2022-06-12 23:12:03 -06:00 committed by GitHub
commit 29c1d9e561
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 44 additions and 33 deletions

View File

@ -79,6 +79,12 @@ advanced_group.add_argument(
help='Normally text enclosed in brackets are automatically redacted from the spoken output ' help='Normally text enclosed in brackets are automatically redacted from the spoken output '
'(but are still rendered by the model), this can be used for prompt engineering. ' '(but are still rendered by the model), this can be used for prompt engineering. '
'Set this to disable this behavior.') 'Set this to disable this behavior.')
advanced_group.add_argument(
'--device', type=str, default=None,
help='Device to use for inference.')
advanced_group.add_argument(
'--batch-size', type=int, default=None,
help='Batch size to use for inference. If omitted, the batch size is set based on available GPU memory.')
tuning_group = parser.add_argument_group('tuning options (overrides preset settings)') tuning_group = parser.add_argument_group('tuning options (overrides preset settings)')
tuning_group.add_argument( tuning_group.add_argument(
@ -200,10 +206,11 @@ if args.play:
seed = int(time.time()) if args.seed is None else args.seed seed = int(time.time()) if args.seed is None else args.seed
if not args.quiet: if not args.quiet:
print('Loading tts...') print('Loading tts...')
tts = TextToSpeech(models_dir=args.models_dir, enable_redaction=not args.disable_redaction) tts = TextToSpeech(models_dir=args.models_dir, enable_redaction=not args.disable_redaction,
device=args.device, autoregressive_batch_size=args.batch_size)
gen_settings = { gen_settings = {
'use_deterministic_seed': seed, 'use_deterministic_seed': seed,
'varbose': not args.quiet, 'verbose': not args.quiet,
'k': args.candidates, 'k': args.candidates,
'preset': args.preset, 'preset': args.preset,
} }

View File

@ -101,7 +101,7 @@ def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusi
conditioning_free=cond_free, conditioning_free_k=cond_free_k) conditioning_free=cond_free, conditioning_free_k=cond_free_k)
def format_conditioning(clip, cond_length=132300): def format_conditioning(clip, cond_length=132300, device='cuda'):
""" """
Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models. Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
""" """
@ -112,7 +112,7 @@ def format_conditioning(clip, cond_length=132300):
rand_start = random.randint(0, gap) rand_start = random.randint(0, gap)
clip = clip[:, rand_start:rand_start + cond_length] clip = clip[:, rand_start:rand_start + cond_length]
mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0) mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
return mel_clip.unsqueeze(0).cuda() return mel_clip.unsqueeze(0).to(device)
def fix_autoregressive_output(codes, stop_token, complain=True): def fix_autoregressive_output(codes, stop_token, complain=True):
@ -181,7 +181,8 @@ def pick_best_batch_size_for_gpu():
Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give
you a good shot. you a good shot.
""" """
free, available = torch.cuda.mem_get_info() if torch.cuda.is_available():
_, available = torch.cuda.mem_get_info()
availableGb = available / (1024 ** 3) availableGb = available / (1024 ** 3)
if availableGb > 14: if availableGb > 14:
return 16 return 16
@ -197,7 +198,7 @@ class TextToSpeech:
Main entry point into Tortoise. Main entry point into Tortoise.
""" """
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True): def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None):
""" """
Constructor Constructor
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@ -207,10 +208,12 @@ class TextToSpeech:
:param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output :param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output
(but are still rendered by the model). This can be used for prompt engineering. (but are still rendered by the model). This can be used for prompt engineering.
Default is true. Default is true.
:param device: Device to use when running the model. If omitted, the device will be automatically chosen.
""" """
self.models_dir = models_dir self.models_dir = models_dir
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
self.enable_redaction = enable_redaction self.enable_redaction = enable_redaction
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if self.enable_redaction: if self.enable_redaction:
self.aligner = Wav2VecAlignment() self.aligner = Wav2VecAlignment()
@ -240,7 +243,7 @@ class TextToSpeech:
self.cvvp = None # CVVP model is only loaded if used. self.cvvp = None # CVVP model is only loaded if used.
self.vocoder = UnivNetGenerator().cpu() self.vocoder = UnivNetGenerator().cpu()
self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir))['model_g']) self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])
self.vocoder.eval(inference=True) self.vocoder.eval(inference=True)
# Random latent generators (RLGs) are loaded lazily. # Random latent generators (RLGs) are loaded lazily.
@ -261,15 +264,15 @@ class TextToSpeech:
:param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data. :param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
""" """
with torch.no_grad(): with torch.no_grad():
voice_samples = [v.to('cuda') for v in voice_samples] voice_samples = [v.to(self.device) for v in voice_samples]
auto_conds = [] auto_conds = []
if not isinstance(voice_samples, list): if not isinstance(voice_samples, list):
voice_samples = [voice_samples] voice_samples = [voice_samples]
for vs in voice_samples: for vs in voice_samples:
auto_conds.append(format_conditioning(vs)) auto_conds.append(format_conditioning(vs, device=self.device))
auto_conds = torch.stack(auto_conds, dim=1) auto_conds = torch.stack(auto_conds, dim=1)
self.autoregressive = self.autoregressive.cuda() self.autoregressive = self.autoregressive.to(self.device)
auto_latent = self.autoregressive.get_conditioning(auto_conds) auto_latent = self.autoregressive.get_conditioning(auto_conds)
self.autoregressive = self.autoregressive.cpu() self.autoregressive = self.autoregressive.cpu()
@ -278,11 +281,11 @@ class TextToSpeech:
# The diffuser operates at a sample rate of 24000 (except for the latent inputs) # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
sample = torchaudio.functional.resample(sample, 22050, 24000) sample = torchaudio.functional.resample(sample, 22050, 24000)
sample = pad_or_truncate(sample, 102400) sample = pad_or_truncate(sample, 102400)
cond_mel = wav_to_univnet_mel(sample.to('cuda'), do_normalization=False) cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False, device=self.device)
diffusion_conds.append(cond_mel) diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1) diffusion_conds = torch.stack(diffusion_conds, dim=1)
self.diffusion = self.diffusion.cuda() self.diffusion = self.diffusion.to(self.device)
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds) diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
self.diffusion = self.diffusion.cpu() self.diffusion = self.diffusion.cpu()
@ -380,7 +383,7 @@ class TextToSpeech:
""" """
deterministic_seed = self.deterministic_state(seed=use_deterministic_seed) deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda() text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.' assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
@ -391,8 +394,8 @@ class TextToSpeech:
auto_conditioning, diffusion_conditioning = conditioning_latents auto_conditioning, diffusion_conditioning = conditioning_latents
else: else:
auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents() auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents()
auto_conditioning = auto_conditioning.cuda() auto_conditioning = auto_conditioning.to(self.device)
diffusion_conditioning = diffusion_conditioning.cuda() diffusion_conditioning = diffusion_conditioning.to(self.device)
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k) diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
@ -401,7 +404,7 @@ class TextToSpeech:
num_batches = num_autoregressive_samples // self.autoregressive_batch_size num_batches = num_autoregressive_samples // self.autoregressive_batch_size
stop_mel_token = self.autoregressive.stop_mel_token stop_mel_token = self.autoregressive.stop_mel_token
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
self.autoregressive = self.autoregressive.cuda() self.autoregressive = self.autoregressive.to(self.device)
if verbose: if verbose:
print("Generating autoregressive samples..") print("Generating autoregressive samples..")
for b in tqdm(range(num_batches), disable=not verbose): for b in tqdm(range(num_batches), disable=not verbose):
@ -420,11 +423,11 @@ class TextToSpeech:
self.autoregressive = self.autoregressive.cpu() self.autoregressive = self.autoregressive.cpu()
clip_results = [] clip_results = []
self.clvp = self.clvp.cuda() self.clvp = self.clvp.to(self.device)
if cvvp_amount > 0: if cvvp_amount > 0:
if self.cvvp is None: if self.cvvp is None:
self.load_cvvp() self.load_cvvp()
self.cvvp = self.cvvp.cuda() self.cvvp = self.cvvp.to(self.device)
if verbose: if verbose:
if self.cvvp is None: if self.cvvp is None:
print("Computing best candidates using CLVP") print("Computing best candidates using CLVP")
@ -457,7 +460,7 @@ class TextToSpeech:
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
# results, but will increase memory usage. # results, but will increase memory usage.
self.autoregressive = self.autoregressive.cuda() self.autoregressive = self.autoregressive.to(self.device)
best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1), best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results, torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device), torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
@ -468,8 +471,8 @@ class TextToSpeech:
if verbose: if verbose:
print("Transforming autoregressive outputs into audio..") print("Transforming autoregressive outputs into audio..")
wav_candidates = [] wav_candidates = []
self.diffusion = self.diffusion.cuda() self.diffusion = self.diffusion.to(self.device)
self.vocoder = self.vocoder.cuda() self.vocoder = self.vocoder.to(self.device)
for b in range(best_results.shape[0]): for b in range(best_results.shape[0]):
codes = best_results[b].unsqueeze(0) codes = best_results[b].unsqueeze(0)
latents = best_latents[b].unsqueeze(0) latents = best_latents[b].unsqueeze(0)

View File

@ -180,9 +180,9 @@ class TacotronSTFT(torch.nn.Module):
return mel_output return mel_output
def wav_to_univnet_mel(wav, do_normalization=False): def wav_to_univnet_mel(wav, do_normalization=False, device='cuda'):
stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000) stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
stft = stft.cuda() stft = stft.to(device)
mel = stft.mel_spectrogram(wav) mel = stft.mel_spectrogram(wav)
if do_normalization: if do_normalization:
mel = normalize_tacotron_mel(mel) mel = normalize_tacotron_mel(mel)

View File

@ -49,17 +49,18 @@ class Wav2VecAlignment:
""" """
Uses wav2vec2 to perform audio<->text alignment. Uses wav2vec2 to perform audio<->text alignment.
""" """
def __init__(self): def __init__(self, device='cuda'):
self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu() self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h") self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols') self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
self.device = device
def align(self, audio, expected_text, audio_sample_rate=24000): def align(self, audio, expected_text, audio_sample_rate=24000):
orig_len = audio.shape[-1] orig_len = audio.shape[-1]
with torch.no_grad(): with torch.no_grad():
self.model = self.model.cuda() self.model = self.model.to(self.device)
audio = audio.to('cuda') audio = audio.to(self.device)
audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000) audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7) clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
logits = self.model(clip_norm).logits logits = self.model(clip_norm).logits