diff --git a/tortoise/api.py b/tortoise/api.py index 3b63549..dbd22cf 100755 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -242,21 +242,17 @@ class TextToSpeech: in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, layer_drop=0, unconditioned_percentage=0).cpu().eval() self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', models_dir))) - self.autoregressive = self.autoregressive.to(self.device) - self.diffusion = self.diffusion.to(self.device) self.clvp = CLVP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20, text_seq_len=350, text_heads=12, num_speech_tokens=8192, speech_enc_depth=20, speech_heads=12, speech_seq_len=430, use_xformers=True).cpu().eval() self.clvp.load_state_dict(torch.load(get_model_path('clvp2.pth', models_dir))) - self.clvp = self.clvp.to(self.device) self.cvvp = None # CVVP model is only loaded if used. self.vocoder = UnivNetGenerator().cpu() self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g']) self.vocoder.eval(inference=True) - self.vocoder = self.vocoder.to(self.device) # Random latent generators (RLGs) are loaded lazily. self.rlg_auto = None @@ -267,7 +263,6 @@ class TextToSpeech: self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0, speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval() self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir))) - self.cvvp = self.cvvp.to(self.device) def get_conditioning_latents(self, voice_samples, return_mels=False): """ @@ -285,7 +280,9 @@ class TextToSpeech: for vs in voice_samples: auto_conds.append(format_conditioning(vs, device=self.device)) auto_conds = torch.stack(auto_conds, dim=1) + self.autoregressive = self.autoregressive.to(self.device) auto_latent = self.autoregressive.get_conditioning(auto_conds) + self.autoregressive = self.autoregressive.cpu() diffusion_conds = [] for sample in voice_samples: @@ -296,7 +293,9 @@ class TextToSpeech: diffusion_conds.append(cond_mel) diffusion_conds = torch.stack(diffusion_conds, dim=1) + self.diffusion = self.diffusion.to(self.device) diffusion_latent = self.diffusion.get_conditioning(diffusion_conds) + self.diffusion = self.diffusion.cpu() if return_mels: return auto_latent, diffusion_latent, auto_conds, diffusion_conds @@ -414,7 +413,8 @@ class TextToSpeech: num_batches = num_autoregressive_samples // self.autoregressive_batch_size stop_mel_token = self.autoregressive.stop_mel_token calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" - + self.autoregressive = self.autoregressive.to(self.device) + for b in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"): codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens, do_sample=True, @@ -428,12 +428,15 @@ class TextToSpeech: padding_needed = max_mel_tokens - codes.shape[1] codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) samples.append(codes) + self.autoregressive = self.autoregressive.cpu() clip_results = [] + self.clvp = self.clvp.to(self.device) if cvvp_amount > 0: if self.cvvp is None: self.load_cvvp() - + self.cvvp = self.cvvp.to(self.device) + desc="Computing best candidates" if verbose: if self.cvvp is None: @@ -460,18 +463,25 @@ class TextToSpeech: clip_results = torch.cat(clip_results, dim=0) samples = torch.cat(samples, dim=0) best_results = samples[torch.topk(clip_results, k=k).indices] + self.clvp = self.clvp.cpu() + if self.cvvp is not None: + self.cvvp = self.cvvp.cpu() del samples # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these # results, but will increase memory usage. + self.autoregressive = self.autoregressive.to(self.device) best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1), torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results, torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device), return_latent=True, clip_inputs=False) + self.autoregressive = self.autoregressive.cpu() del auto_conditioning wav_candidates = [] + self.diffusion = self.diffusion.to(self.device) + self.vocoder = self.vocoder.to(self.device) for b in range(best_results.shape[0]): codes = best_results[b].unsqueeze(0) latents = best_latents[b].unsqueeze(0) @@ -491,6 +501,8 @@ class TextToSpeech: temperature=diffusion_temperature, verbose=verbose, progress=progress, desc="Transforming autoregressive outputs into audio..") wav = self.vocoder.inference(mel) wav_candidates.append(wav.cpu()) + self.diffusion = self.diffusion.cpu() + self.vocoder = self.vocoder.cpu() def potentially_redact(clip, text): if self.enable_redaction: