Merge pull request #36 from e0xextazy/main

Optimizing graphics card memory
This commit is contained in:
James Betker 2022-05-11 21:46:16 -06:00 committed by GitHub
commit 5c60c5d4f2

View File

@@ -225,30 +225,31 @@ class TextToSpeech:
properties. properties.
:param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data. :param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
""" """
voice_samples = [v.to('cuda') for v in voice_samples] with torch.no_grad():
voice_samples = [v.to('cuda') for v in voice_samples]
auto_conds = [] auto_conds = []
if not isinstance(voice_samples, list): if not isinstance(voice_samples, list):
voice_samples = [voice_samples] voice_samples = [voice_samples]
for vs in voice_samples: for vs in voice_samples:
auto_conds.append(format_conditioning(vs)) auto_conds.append(format_conditioning(vs))
auto_conds = torch.stack(auto_conds, dim=1) auto_conds = torch.stack(auto_conds, dim=1)
self.autoregressive = self.autoregressive.cuda() self.autoregressive = self.autoregressive.cuda()
auto_latent = self.autoregressive.get_conditioning(auto_conds) auto_latent = self.autoregressive.get_conditioning(auto_conds)
self.autoregressive = self.autoregressive.cpu() self.autoregressive = self.autoregressive.cpu()
diffusion_conds = [] diffusion_conds = []
for sample in voice_samples: for sample in voice_samples:
# The diffuser operates at a sample rate of 24000 (except for the latent inputs) # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
sample = torchaudio.functional.resample(sample, 22050, 24000) sample = torchaudio.functional.resample(sample, 22050, 24000)
sample = pad_or_truncate(sample, 102400) sample = pad_or_truncate(sample, 102400)
cond_mel = wav_to_univnet_mel(sample.to('cuda'), do_normalization=False) cond_mel = wav_to_univnet_mel(sample.to('cuda'), do_normalization=False)
diffusion_conds.append(cond_mel) diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1) diffusion_conds = torch.stack(diffusion_conds, dim=1)
self.diffusion = self.diffusion.cuda() self.diffusion = self.diffusion.cuda()
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds) diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
self.diffusion = self.diffusion.cpu() self.diffusion = self.diffusion.cpu()
if return_mels: if return_mels:
return auto_latent, diffusion_latent, auto_conds, diffusion_conds return auto_latent, diffusion_latent, auto_conds, diffusion_conds