Merge pull request from e0xextazy/main

Optimizing graphics card memory
This commit is contained in:
James Betker 2022-05-11 21:46:16 -06:00 committed by GitHub
commit 945bd88f21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -225,30 +225,31 @@ class TextToSpeech:
properties.
:param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
"""
voice_samples = [v.to('cuda') for v in voice_samples]
with torch.no_grad():
voice_samples = [v.to('cuda') for v in voice_samples]
auto_conds = []
if not isinstance(voice_samples, list):
voice_samples = [voice_samples]
for vs in voice_samples:
auto_conds.append(format_conditioning(vs))
auto_conds = torch.stack(auto_conds, dim=1)
self.autoregressive = self.autoregressive.cuda()
auto_latent = self.autoregressive.get_conditioning(auto_conds)
self.autoregressive = self.autoregressive.cpu()
auto_conds = []
if not isinstance(voice_samples, list):
voice_samples = [voice_samples]
for vs in voice_samples:
auto_conds.append(format_conditioning(vs))
auto_conds = torch.stack(auto_conds, dim=1)
self.autoregressive = self.autoregressive.cuda()
auto_latent = self.autoregressive.get_conditioning(auto_conds)
self.autoregressive = self.autoregressive.cpu()
diffusion_conds = []
for sample in voice_samples:
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
sample = torchaudio.functional.resample(sample, 22050, 24000)
sample = pad_or_truncate(sample, 102400)
cond_mel = wav_to_univnet_mel(sample.to('cuda'), do_normalization=False)
diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1)
diffusion_conds = []
for sample in voice_samples:
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
sample = torchaudio.functional.resample(sample, 22050, 24000)
sample = pad_or_truncate(sample, 102400)
cond_mel = wav_to_univnet_mel(sample.to('cuda'), do_normalization=False)
diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1)
self.diffusion = self.diffusion.cuda()
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
self.diffusion = self.diffusion.cpu()
self.diffusion = self.diffusion.cuda()
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
self.diffusion = self.diffusion.cpu()
if return_mels:
return auto_latent, diffusion_latent, auto_conds, diffusion_conds