From 97cd58e7eba06b91fc6b0c77029ad72022b8c231 Mon Sep 17 00:00:00 2001
From: mrq <mrq@ecker.tech>
Date: Sun, 12 Mar 2023 12:48:29 -0500
Subject: [PATCH] maybe solved that odd VRAM spike when doing the clvp pass

---
 tortoise/api.py          | 145 +++++++++++++++++++++------------------
 tortoise/models/clvp.py  |  15 ++--
 tortoise/utils/device.py |  23 +++++++
 3 files changed, 108 insertions(+), 75 deletions(-)

diff --git a/tortoise/api.py b/tortoise/api.py
index b663fa6..2df47ea 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -29,7 +29,7 @@ from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named
 from tortoise.utils.tokenizer import VoiceBpeTokenizer
 from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
 
-from tortoise.utils.device import get_device, get_device_name, get_device_batch_size
+from tortoise.utils.device import get_device, get_device_name, get_device_batch_size, print_stats, do_gc
 
 pbar = None
 STOP_SIGNAL = False
@@ -172,8 +172,8 @@ def format_conditioning(clip, cond_length=132300, device='cuda', sampling_rate=2
         rand_start = random.randint(0, gap)
         clip = clip[:, rand_start:rand_start + cond_length]
     mel_clip = TorchMelSpectrogram(sampling_rate=sampling_rate)(clip.unsqueeze(0)).squeeze(0)
-    return mel_clip.unsqueeze(0).to(device)
-
+    mel_clip = mel_clip.unsqueeze(0)
+    return migrate_to_device(mel_clip, device)
 
 def fix_autoregressive_output(codes, stop_token, complain=True):
     """
@@ -241,6 +241,25 @@ def classify_audio_clip(clip):
     results = F.softmax(classifier(clip), dim=-1)
     return results[0][0]
 
+def migrate_to_device( t, device ):
+    if t is None:
+        return t
+
+    if not hasattr(t, 'device'):
+        t.device = device
+        t.manually_track_device = True
+    elif t.device == device:
+        return t
+
+    if hasattr(t, 'manually_track_device') and t.manually_track_device:
+        t.device = device
+
+    t = t.to(device)
+    
+    do_gc()
+
+    return t
+
 class TextToSpeech:
     """
     Main entry point into Tortoise.
@@ -315,10 +334,11 @@ class TextToSpeech:
         self.rlg_diffusion = None
 
         if self.preloaded_tensors:
-            self.autoregressive = self.autoregressive.to(self.device)
-            self.diffusion = self.diffusion.to(self.device)
-            self.clvp = self.clvp.to(self.device)
-            self.vocoder = self.vocoder.to(self.device)
+            self.autoregressive = migrate_to_device( self.autoregressive, self.device )
+            self.diffusion = migrate_to_device( self.diffusion, self.device )
+            self.clvp = migrate_to_device( self.clvp, self.device )
+            self.vocoder = migrate_to_device( self.vocoder, self.device )
+
         self.loading = False
 
     def load_autoregressive_model(self, autoregressive_model_path):
@@ -341,7 +361,7 @@ class TextToSpeech:
         self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
         self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
         if self.preloaded_tensors:
-            self.autoregressive = self.autoregressive.to(self.device)
+            self.autoregressive = migrate_to_device( self.autoregressive, self.device )
 
         self.loading = False
         print(f"Loaded autoregressive model")
@@ -382,7 +402,7 @@ class TextToSpeech:
 
         self.vocoder.eval(inference=True)
         if self.preloaded_tensors:
-            self.vocoder = self.vocoder.to(self.device)
+            self.vocoder = migrate_to_device( self.vocoder, self.device )
         self.loading = False
         print(f"Loaded vocoder model")
 
@@ -393,7 +413,7 @@ class TextToSpeech:
         self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir)))
         
         if self.preloaded_tensors:
-            self.cvvp = self.cvvp.to(self.device)
+            self.cvvp = migrate_to_device( self.cvvp, self.device )
 
     def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, slices=1, max_chunk_size=None, force_cpu=False):
         """
@@ -411,7 +431,7 @@ class TextToSpeech:
             if not isinstance(voice_samples, list):
                 voice_samples = [voice_samples]
             
-            voice_samples = [v.to(device) for v in voice_samples]
+            voice_samples = [migrate_to_device(v, device)  for v in voice_samples]
 
             resampler = torchaudio.transforms.Resample(
                 self.input_sample_rate,
@@ -420,24 +440,19 @@ class TextToSpeech:
                 rolloff=0.85,
                 resampling_method="kaiser_window",
                 beta=8.555504641634386,
-            )
+            ).to(device)
 
             samples = []
             auto_conds = []
             for sample in voice_samples:
                 auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate))
-                samples.append(resampler(sample.cpu()).to(device)) # icky no good, easier to do the resampling on CPU than figure out how to do it on GPU
+                samples.append(resampler(sample)) 
 
             auto_conds = torch.stack(auto_conds, dim=1)
 
-
-            self.autoregressive = self.autoregressive.to(device)
+            self.autoregressive = migrate_to_device( self.autoregressive, device )
             auto_latent = self.autoregressive.get_conditioning(auto_conds)
-            if self.preloaded_tensors:
-                self.autoregressive = self.autoregressive.to(self.device)
-            else:
-                self.autoregressive = self.autoregressive.cpu()
-
+            self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
             
             diffusion_conds = []
             chunks = []
@@ -460,21 +475,14 @@ class TextToSpeech:
             for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
                 check_for_kill_signal()
                 chunk = pad_or_truncate(chunk, chunk_size)
-                cond_mel = wav_to_univnet_mel(chunk.to(device), do_normalization=False, device=device)
+                cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
                 diffusion_conds.append(cond_mel)
 
             diffusion_conds = torch.stack(diffusion_conds, dim=1)
 
-            self.diffusion = self.diffusion.to(device)
-            
+            self.diffusion = migrate_to_device( self.diffusion, device )
             diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
-
-            if self.preloaded_tensors:
-                self.diffusion = self.diffusion.to(self.device)
-            else:
-                self.diffusion = self.diffusion.cpu()
-
-
+            self.diffusion = migrate_to_device( self.diffusion, self.device if self.preloaded_tensors else 'cpu' )
 
         if return_mels:
             return auto_latent, diffusion_latent, auto_conds, diffusion_conds
@@ -587,7 +595,9 @@ class TextToSpeech:
         elif autoregressive_model != self.autoregressive_model_path:
             self.load_autoregressive_model(autoregressive_model)
 
-        text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
+        text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0)
+        text_tokens = migrate_to_device( text_tokens, self.device )
+
         text_tokens = F.pad(text_tokens, (0, 1))  # This may not be necessary.
         assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
 
@@ -615,9 +625,9 @@ class TextToSpeech:
             stop_mel_token = self.autoregressive.stop_mel_token
             calm_token = 83  # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
 
-            self.autoregressive = self.autoregressive.to(self.device)
-            auto_conditioning = auto_conditioning.to(self.device)
-            text_tokens = text_tokens.to(self.device)
+            self.autoregressive = migrate_to_device( self.autoregressive, self.device )
+            auto_conditioning = migrate_to_device( auto_conditioning, self.device )
+            text_tokens = migrate_to_device( text_tokens, self.device )
 
             with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
                 for b in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"):
@@ -636,24 +646,24 @@ class TextToSpeech:
                     samples.append(codes)
 
             if not self.preloaded_tensors:
-                self.autoregressive = self.autoregressive.cpu()
-                auto_conditioning = auto_conditioning.cpu()
+                self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
 
             clip_results = []
 
             if auto_conds is not None:
-                auto_conds = auto_conds.to(self.device)
+                auto_conditioning = migrate_to_device( auto_conditioning, self.device )
 
             with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
-                if not self.minor_optimizations:
-                    self.autoregressive = self.autoregressive.cpu()
-                    self.clvp = self.clvp.to(self.device)
+                if not self.preloaded_tensors:
+                    self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
+                    self.clvp = migrate_to_device( self.clvp, self.device )
 
                 if cvvp_amount > 0:
                     if self.cvvp is None:
                         self.load_cvvp()
-                    if not self.minor_optimizations:
-                        self.cvvp = self.cvvp.to(self.device)
+                    
+                    if not self.preloaded_tensors:
+                        self.cvvp = migrate_to_device( self.cvvp, self.device )
                 
                 desc="Computing best candidates"
                 if verbose:
@@ -662,6 +672,7 @@ class TextToSpeech:
                     else:
                         desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%"
 
+                
                 for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc):
                     check_for_kill_signal()
                     for i in range(batch.shape[0]):
@@ -683,30 +694,28 @@ class TextToSpeech:
                         clip_results.append(clvp)
 
             if not self.preloaded_tensors and auto_conds is not None:
-                auto_conds = auto_conds.cpu()
+                auto_conds = migrate_to_device( auto_conds, 'cpu' )
 
             clip_results = torch.cat(clip_results, dim=0)
             samples = torch.cat(samples, dim=0)
             best_results = samples[torch.topk(clip_results, k=k).indices]
             
             if not self.preloaded_tensors:
-                self.clvp = self.clvp.cpu()
-                if self.cvvp is not None:
-                    self.cvvp = self.cvvp.cpu()
-
-            del samples
+                self.clvp = migrate_to_device( self.clvp, 'cpu' )
+                self.cvvp = migrate_to_device( self.cvvp, 'cpu' )
+            
 
             if get_device_name() == "dml":
-                text_tokens = text_tokens.cpu()
-                best_results = best_results.cpu()
-                auto_conditioning = auto_conditioning.cpu()
-                self.autoregressive = self.autoregressive.cpu()
+                text_tokens = migrate_to_device( text_tokens, 'cpu' )
+                best_results = migrate_to_device( best_results, 'cpu' )
+                auto_conditioning = migrate_to_device( auto_conditioning, 'cpu' )
+                self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
             else:
-                #text_tokens = text_tokens.to(self.device)
-                #best_results = best_results.to(self.device)
                 auto_conditioning = auto_conditioning.to(self.device)
                 self.autoregressive = self.autoregressive.to(self.device)
 
+            del samples
+
             # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
             # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
             # results, but will increase memory usage.
@@ -715,21 +724,19 @@ class TextToSpeech:
                                                torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
                                                return_latent=True, clip_inputs=False)
             
-            diffusion_conditioning = diffusion_conditioning.to(self.device)
+            diffusion_conditioning = migrate_to_device( diffusion_conditioning, self.device )
 
             if get_device_name() == "dml":
-                self.autoregressive = self.autoregressive.to(self.device)
-                best_results = best_results.to(self.device)
-                best_latents = best_latents.to(self.device)
-
-                self.vocoder = self.vocoder.cpu()
+                self.autoregressive = migrate_to_device( self.autoregressive, self.device )
+                best_results = migrate_to_device( best_results, self.device )
+                best_latents = migrate_to_device( best_latents, self.device )
+                self.vocoder = migrate_to_device( self.vocoder, 'cpu' )
             else:
                 if not self.preloaded_tensors:
-                    self.autoregressive = self.autoregressive.cpu()
-
-                self.diffusion = self.diffusion.to(self.device)
-                self.vocoder = self.vocoder.to(self.device)
+                    self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
 
+                self.diffusion = migrate_to_device( self.diffusion, self.device )
+                self.vocoder = migrate_to_device( self.vocoder, self.device )
             
             del text_tokens
             del auto_conditioning
@@ -758,12 +765,14 @@ class TextToSpeech:
                 wav_candidates.append(wav)
             
             if not self.preloaded_tensors:
-                self.diffusion = self.diffusion.cpu()
-                self.vocoder = self.vocoder.cpu()
+                self.diffusion = migrate_to_device( self.diffusion, 'cpu' )
+                self.vocoder = migrate_to_device( self.vocoder, 'cpu' )
 
             def potentially_redact(clip, text):
                 if self.enable_redaction:
-                    return self.aligner.redact(clip.squeeze(1).to('cpu' if get_device_name() == "dml" else self.device), text, self.output_sample_rate).unsqueeze(1)
+                    t = clip.squeeze(1)
+                    t = migrate_to_device( t, 'cpu' if get_device_name() == "dml" else self.device)
+                    return self.aligner.redact(t, text, self.output_sample_rate).unsqueeze(1)
                 return clip
             wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
 
@@ -772,7 +781,7 @@ class TextToSpeech:
             else:
                 res = wav_candidates[0]
 
-            gc.collect()
+            do_gc()
 
             if return_deterministic_state:
                 return res, (deterministic_seed, text, voice_samples, conditioning_latents)
diff --git a/tortoise/models/clvp.py b/tortoise/models/clvp.py
index 71a744c..db70fe2 100644
--- a/tortoise/models/clvp.py
+++ b/tortoise/models/clvp.py
@@ -9,6 +9,8 @@ from tortoise.models.xtransformers import Encoder
 
 import tortoise.utils.torch_intermediary as ml
 
+from tortoise.utils.device import print_stats, do_gc
+
 def exists(val):
     return val is not None
 
@@ -124,14 +126,13 @@ class CLVP(nn.Module):
             text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
             speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
 
-        enc_text = self.text_transformer(text_emb, mask=text_mask)
-        enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
+        
+        text_latents = self.to_text_latent(masked_mean(self.text_transformer(text_emb, mask=text_mask), text_mask, dim=1))
 
-        text_latents = masked_mean(enc_text, text_mask, dim=1)
-        speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
-
-        text_latents = self.to_text_latent(text_latents)
-        speech_latents = self.to_speech_latent(speech_latents)
+        # on ROCm at least, allocated VRAM spikes here
+        do_gc()
+        speech_latents = self.to_speech_latent(masked_mean(self.speech_transformer(speech_emb, mask=voice_mask), voice_mask, dim=1))
+        do_gc()
 
         text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
 
diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py
index 366a3af..7b16e7f 100755
--- a/tortoise/utils/device.py
+++ b/tortoise/utils/device.py
@@ -5,6 +5,29 @@ import importlib
 DEVICE_OVERRIDE = None
 DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
 
+from inspect import currentframe, getframeinfo
+import gc
+
+def do_gc():
+    gc.collect()
+    try:
+        torch.cuda.empty_cache()
+    except Exception as e:
+        pass
+
+def print_stats(collect=False):
+    cf = currentframe().f_back
+    msg = f'{getframeinfo(cf).filename}:{cf.f_lineno}'
+
+    if collect:
+        do_gc()
+
+    tot = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
+    res = torch.cuda.memory_reserved(0) / (1024 ** 3)
+    alloc = torch.cuda.memory_allocated(0) / (1024 ** 3)
+    print("[{}] Total: {:.3f} | Reserved: {:.3f} | Allocated: {:.3f} | Free: {:.3f}".format( msg, tot, res, alloc, tot-res ))
+
+
 def has_dml():
     loader = importlib.find_loader('torch_directml')
     if loader is None: