forked from mrq/tortoise-tts
Optimize GPU memory usage during inference
Inference now runs under `torch.no_grad()`, so gradients, which otherwise take up most of the video memory, are no longer stored.
This commit is contained in:
parent
ea8c825ee0
commit
cc38333249
|
@ -225,30 +225,31 @@ class TextToSpeech:
|
||||||
properties.
|
properties.
|
||||||
:param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
|
:param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
|
||||||
"""
|
"""
|
||||||
voice_samples = [v.to('cuda') for v in voice_samples]
|
with torch.no_grad():
|
||||||
|
voice_samples = [v.to('cuda') for v in voice_samples]
|
||||||
|
|
||||||
auto_conds = []
|
auto_conds = []
|
||||||
if not isinstance(voice_samples, list):
|
if not isinstance(voice_samples, list):
|
||||||
voice_samples = [voice_samples]
|
voice_samples = [voice_samples]
|
||||||
for vs in voice_samples:
|
for vs in voice_samples:
|
||||||
auto_conds.append(format_conditioning(vs))
|
auto_conds.append(format_conditioning(vs))
|
||||||
auto_conds = torch.stack(auto_conds, dim=1)
|
auto_conds = torch.stack(auto_conds, dim=1)
|
||||||
self.autoregressive = self.autoregressive.cuda()
|
self.autoregressive = self.autoregressive.cuda()
|
||||||
auto_latent = self.autoregressive.get_conditioning(auto_conds)
|
auto_latent = self.autoregressive.get_conditioning(auto_conds)
|
||||||
self.autoregressive = self.autoregressive.cpu()
|
self.autoregressive = self.autoregressive.cpu()
|
||||||
|
|
||||||
diffusion_conds = []
|
diffusion_conds = []
|
||||||
for sample in voice_samples:
|
for sample in voice_samples:
|
||||||
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
|
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
|
||||||
sample = torchaudio.functional.resample(sample, 22050, 24000)
|
sample = torchaudio.functional.resample(sample, 22050, 24000)
|
||||||
sample = pad_or_truncate(sample, 102400)
|
sample = pad_or_truncate(sample, 102400)
|
||||||
cond_mel = wav_to_univnet_mel(sample.to('cuda'), do_normalization=False)
|
cond_mel = wav_to_univnet_mel(sample.to('cuda'), do_normalization=False)
|
||||||
diffusion_conds.append(cond_mel)
|
diffusion_conds.append(cond_mel)
|
||||||
diffusion_conds = torch.stack(diffusion_conds, dim=1)
|
diffusion_conds = torch.stack(diffusion_conds, dim=1)
|
||||||
|
|
||||||
self.diffusion = self.diffusion.cuda()
|
self.diffusion = self.diffusion.cuda()
|
||||||
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
|
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
|
||||||
self.diffusion = self.diffusion.cpu()
|
self.diffusion = self.diffusion.cpu()
|
||||||
|
|
||||||
if return_mels:
|
if return_mels:
|
||||||
return auto_latent, diffusion_latent, auto_conds, diffusion_conds
|
return auto_latent, diffusion_latent, auto_conds, diffusion_conds
|
||||||
|
|
Loading…
Reference in New Issue
Block a user