From cc3833324966ae3c09ac1a0b68878c1bae9be47a Mon Sep 17 00:00:00 2001
From: Mark Baushenko <e0xextazy@gmail.com>
Date: Wed, 11 May 2022 16:35:11 +0300
Subject: [PATCH] Optimizing graphics card memory

During inference it does not store gradients, which take up most of the video memory
---
 tortoise/api.py | 43 ++++++++++++++++++++++---------------------
 1 file changed, 22 insertions(+), 21 deletions(-)

diff --git a/tortoise/api.py b/tortoise/api.py
index 65c7d6e..9fff11f 100644
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -225,30 +225,31 @@ class TextToSpeech:
         properties.
         :param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
         """
-        voice_samples = [v.to('cuda') for v in voice_samples]
+        with torch.no_grad():
+            voice_samples = [v.to('cuda') for v in voice_samples]
 
-        auto_conds = []
-        if not isinstance(voice_samples, list):
-            voice_samples = [voice_samples]
-        for vs in voice_samples:
-            auto_conds.append(format_conditioning(vs))
-        auto_conds = torch.stack(auto_conds, dim=1)
-        self.autoregressive = self.autoregressive.cuda()
-        auto_latent = self.autoregressive.get_conditioning(auto_conds)
-        self.autoregressive = self.autoregressive.cpu()
+            auto_conds = []
+            if not isinstance(voice_samples, list):
+                voice_samples = [voice_samples]
+            for vs in voice_samples:
+                auto_conds.append(format_conditioning(vs))
+            auto_conds = torch.stack(auto_conds, dim=1)
+            self.autoregressive = self.autoregressive.cuda()
+            auto_latent = self.autoregressive.get_conditioning(auto_conds)
+            self.autoregressive = self.autoregressive.cpu()
 
-        diffusion_conds = []
-        for sample in voice_samples:
-            # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
-            sample = torchaudio.functional.resample(sample, 22050, 24000)
-            sample = pad_or_truncate(sample, 102400)
-            cond_mel = wav_to_univnet_mel(sample.to('cuda'), do_normalization=False)
-            diffusion_conds.append(cond_mel)
-        diffusion_conds = torch.stack(diffusion_conds, dim=1)
+            diffusion_conds = []
+            for sample in voice_samples:
+                # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
+                sample = torchaudio.functional.resample(sample, 22050, 24000)
+                sample = pad_or_truncate(sample, 102400)
+                cond_mel = wav_to_univnet_mel(sample.to('cuda'), do_normalization=False)
+                diffusion_conds.append(cond_mel)
+            diffusion_conds = torch.stack(diffusion_conds, dim=1)
 
-        self.diffusion = self.diffusion.cuda()
-        diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
-        self.diffusion = self.diffusion.cpu()
+            self.diffusion = self.diffusion.cuda()
+            diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
+            self.diffusion = self.diffusion.cpu()
 
         if return_mels:
             return auto_latent, diffusion_latent, auto_conds, diffusion_conds