diff --git a/tortoise/api.py b/tortoise/api.py index 65c7d6e..9fff11f 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -225,30 +225,31 @@ class TextToSpeech: properties. :param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data. """ - voice_samples = [v.to('cuda') for v in voice_samples] + with torch.no_grad(): + voice_samples = [v.to('cuda') for v in voice_samples] - auto_conds = [] - if not isinstance(voice_samples, list): - voice_samples = [voice_samples] - for vs in voice_samples: - auto_conds.append(format_conditioning(vs)) - auto_conds = torch.stack(auto_conds, dim=1) - self.autoregressive = self.autoregressive.cuda() - auto_latent = self.autoregressive.get_conditioning(auto_conds) - self.autoregressive = self.autoregressive.cpu() + auto_conds = [] + if not isinstance(voice_samples, list): + voice_samples = [voice_samples] + for vs in voice_samples: + auto_conds.append(format_conditioning(vs)) + auto_conds = torch.stack(auto_conds, dim=1) + self.autoregressive = self.autoregressive.cuda() + auto_latent = self.autoregressive.get_conditioning(auto_conds) + self.autoregressive = self.autoregressive.cpu() - diffusion_conds = [] - for sample in voice_samples: - # The diffuser operates at a sample rate of 24000 (except for the latent inputs) - sample = torchaudio.functional.resample(sample, 22050, 24000) - sample = pad_or_truncate(sample, 102400) - cond_mel = wav_to_univnet_mel(sample.to('cuda'), do_normalization=False) - diffusion_conds.append(cond_mel) - diffusion_conds = torch.stack(diffusion_conds, dim=1) + diffusion_conds = [] + for sample in voice_samples: + # The diffuser operates at a sample rate of 24000 (except for the latent inputs) + sample = torchaudio.functional.resample(sample, 22050, 24000) + sample = pad_or_truncate(sample, 102400) + cond_mel = wav_to_univnet_mel(sample.to('cuda'), do_normalization=False) + diffusion_conds.append(cond_mel) + diffusion_conds = torch.stack(diffusion_conds, dim=1) - self.diffusion = self.diffusion.cuda() - diffusion_latent = self.diffusion.get_conditioning(diffusion_conds) - self.diffusion = self.diffusion.cpu() + self.diffusion = self.diffusion.cuda() + diffusion_latent = self.diffusion.get_conditioning(diffusion_conds) + self.diffusion = self.diffusion.cpu() if return_mels: return auto_latent, diffusion_latent, auto_conds, diffusion_conds