diff --git a/tortoise/api.py b/tortoise/api.py
index b663fa6..2df47ea 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -29,7 +29,7 @@ from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
 from tortoise.utils.tokenizer import VoiceBpeTokenizer
 from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
 
-from tortoise.utils.device import get_device, get_device_name, get_device_batch_size
+from tortoise.utils.device import get_device, get_device_name, get_device_batch_size, print_stats, do_gc
 
 pbar = None
 STOP_SIGNAL = False
@@ -172,8 +172,8 @@ def format_conditioning(clip, cond_length=132300, device='cuda', sampling_rate=22050):
         rand_start = random.randint(0, gap)
         clip = clip[:, rand_start:rand_start + cond_length]
     mel_clip = TorchMelSpectrogram(sampling_rate=sampling_rate)(clip.unsqueeze(0)).squeeze(0)
-    return mel_clip.unsqueeze(0).to(device)
-
+    mel_clip = mel_clip.unsqueeze(0)
+    return migrate_to_device(mel_clip, device)
 
 def fix_autoregressive_output(codes, stop_token, complain=True):
     """
@@ -241,6 +241,25 @@ def classify_audio_clip(clip):
     results = F.softmax(classifier(clip), dim=-1)
     return results[0][0]
 
+def migrate_to_device( t, device ):
+    if t is None:
+        return t
+
+    if not hasattr(t, 'device'):
+        t.device = device
+        t.manually_track_device = True
+    elif t.device == device:
+        return t
+
+    if hasattr(t, 'manually_track_device') and t.manually_track_device:
+        t.device = device
+
+    t = t.to(device)
+
+    do_gc()
+
+    return t
+
 class TextToSpeech:
     """
     Main entry point into Tortoise.
@@ -315,10 +334,11 @@ class TextToSpeech:
         self.rlg_diffusion = None
 
         if self.preloaded_tensors:
-            self.autoregressive = self.autoregressive.to(self.device)
-            self.diffusion = self.diffusion.to(self.device)
-            self.clvp = self.clvp.to(self.device)
-            self.vocoder = self.vocoder.to(self.device)
+            self.autoregressive = migrate_to_device( self.autoregressive, self.device )
+            self.diffusion = migrate_to_device( self.diffusion, self.device )
+            self.clvp = migrate_to_device( self.clvp, self.device )
+            self.vocoder = migrate_to_device( self.vocoder, self.device )
+
         self.loading = False
 
     def load_autoregressive_model(self, autoregressive_model_path):
@@ -341,7 +361,7 @@
         self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
         self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
         if self.preloaded_tensors:
-            self.autoregressive = self.autoregressive.to(self.device)
+            self.autoregressive = migrate_to_device( self.autoregressive, self.device )
 
         self.loading = False
         print(f"Loaded autoregressive model")
@@ -382,7 +402,7 @@
         self.vocoder.eval(inference=True)
 
         if self.preloaded_tensors:
-            self.vocoder = self.vocoder.to(self.device)
+            self.vocoder = migrate_to_device( self.vocoder, self.device )
 
         self.loading = False
         print(f"Loaded vocoder model")
@@ -393,7 +413,7 @@
         self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir)))
 
         if self.preloaded_tensors:
-            self.cvvp = self.cvvp.to(self.device)
+            self.cvvp = migrate_to_device( self.cvvp, self.device )
 
     def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, slices=1, max_chunk_size=None, force_cpu=False):
         """
@@ -411,7 +431,7 @@
         if not isinstance(voice_samples, list):
             voice_samples = [voice_samples]
 
-        voice_samples = [v.to(device) for v in voice_samples]
+        voice_samples = [migrate_to_device(v, device) for v in voice_samples]
 
         resampler = torchaudio.transforms.Resample(
             self.input_sample_rate,
@@ -420,24 +440,19 @@ class TextToSpeech:
             rolloff=0.85,
             resampling_method="kaiser_window",
             beta=8.555504641634386,
-        )
+        ).to(device)
 
         samples = []
         auto_conds = []
         for sample in voice_samples:
             auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate))
-            samples.append(resampler(sample.cpu()).to(device))  # icky no good, easier to do the resampling on CPU than figure out how to do it on GPU
+            samples.append(resampler(sample))
 
         auto_conds = torch.stack(auto_conds, dim=1)
-
-        self.autoregressive = self.autoregressive.to(device)
+        self.autoregressive = migrate_to_device( self.autoregressive, device )
         auto_latent = self.autoregressive.get_conditioning(auto_conds)
-        if self.preloaded_tensors:
-            self.autoregressive = self.autoregressive.to(self.device)
-        else:
-            self.autoregressive = self.autoregressive.cpu()
-
+        self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
 
         diffusion_conds = []
         chunks = []
@@ -460,21 +475,14 @@ class TextToSpeech:
         for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
             check_for_kill_signal()
             chunk = pad_or_truncate(chunk, chunk_size)
-            cond_mel = wav_to_univnet_mel(chunk.to(device), do_normalization=False, device=device)
+            cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
             diffusion_conds.append(cond_mel)
 
         diffusion_conds = torch.stack(diffusion_conds, dim=1)
-        self.diffusion = self.diffusion.to(device)
-
+        self.diffusion = migrate_to_device( self.diffusion, device )
         diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
-
-        if self.preloaded_tensors:
-            self.diffusion = self.diffusion.to(self.device)
-        else:
-            self.diffusion = self.diffusion.cpu()
-
-
+        self.diffusion = migrate_to_device( self.diffusion, self.device if self.preloaded_tensors else 'cpu' )
 
         if return_mels:
             return auto_latent, diffusion_latent, auto_conds, diffusion_conds
@@ -587,7 +595,9 @@ class TextToSpeech:
         elif autoregressive_model != self.autoregressive_model_path:
             self.load_autoregressive_model(autoregressive_model)
 
-        text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
+        text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0)
+        text_tokens = migrate_to_device( text_tokens, self.device )
+
         text_tokens = F.pad(text_tokens, (0, 1))  # This may not be necessary.
 
         assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
@@ -615,9 +625,9 @@ class TextToSpeech:
             stop_mel_token = self.autoregressive.stop_mel_token
             calm_token = 83  # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
 
-            self.autoregressive = self.autoregressive.to(self.device)
-            auto_conditioning = auto_conditioning.to(self.device)
-            text_tokens = text_tokens.to(self.device)
+            self.autoregressive = migrate_to_device( self.autoregressive, self.device )
+            auto_conditioning = migrate_to_device( auto_conditioning, self.device )
+            text_tokens = migrate_to_device( text_tokens, self.device )
 
             with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
                 for b in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"):
@@ -636,24 +646,24 @@
                     samples.append(codes)
 
             if not self.preloaded_tensors:
-                self.autoregressive = self.autoregressive.cpu()
-                auto_conditioning = auto_conditioning.cpu()
+                self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
 
             clip_results = []
 
             if auto_conds is not None:
-                auto_conds = auto_conds.to(self.device)
+                auto_conditioning = migrate_to_device( auto_conditioning, self.device )
 
             with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
-                if not self.minor_optimizations:
-                    self.autoregressive = self.autoregressive.cpu()
-                    self.clvp = self.clvp.to(self.device)
+                if not self.preloaded_tensors:
+                    self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
+                    self.clvp = migrate_to_device( self.clvp, self.device )
 
                 if cvvp_amount > 0:
                     if self.cvvp is None:
                         self.load_cvvp()
-                    if not self.minor_optimizations:
-                        self.cvvp = self.cvvp.to(self.device)
+
+                    if not self.preloaded_tensors:
+                        self.cvvp = migrate_to_device( self.cvvp, self.device )
 
                 desc="Computing best candidates"
                 if verbose:
@@ -662,6 +672,7 @@
                 else:
                     desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%"
 
+
                 for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc):
                     check_for_kill_signal()
                     for i in range(batch.shape[0]):
@@ -683,30 +694,28 @@
                         clip_results.append(clvp)
 
                 if not self.preloaded_tensors and auto_conds is not None:
-                    auto_conds = auto_conds.cpu()
+                    auto_conds = migrate_to_device( auto_conds, 'cpu' )
 
                 clip_results = torch.cat(clip_results, dim=0)
                 samples = torch.cat(samples, dim=0)
                 best_results = samples[torch.topk(clip_results, k=k).indices]
 
             if not self.preloaded_tensors:
-                self.clvp = self.clvp.cpu()
-                if self.cvvp is not None:
-                    self.cvvp = self.cvvp.cpu()
-
-            del samples
+                self.clvp = migrate_to_device( self.clvp, 'cpu' )
+                self.cvvp = migrate_to_device( self.cvvp, 'cpu' )
+
             if get_device_name() == "dml":
-                text_tokens = text_tokens.cpu()
-                best_results = best_results.cpu()
-                auto_conditioning = auto_conditioning.cpu()
-                self.autoregressive = self.autoregressive.cpu()
+                text_tokens = migrate_to_device( text_tokens, 'cpu' )
+                best_results = migrate_to_device( best_results, 'cpu' )
+                auto_conditioning = migrate_to_device( auto_conditioning, 'cpu' )
+                self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
             else:
-                #text_tokens = text_tokens.to(self.device)
-                #best_results = best_results.to(self.device)
                 auto_conditioning = auto_conditioning.to(self.device)
                 self.autoregressive = self.autoregressive.to(self.device)
 
+            del samples
+
             # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
             # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
             # results, but will increase memory usage.
@@ -715,21 +724,19 @@
                 torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
                 return_latent=True, clip_inputs=False)
 
-            diffusion_conditioning = diffusion_conditioning.to(self.device)
+            diffusion_conditioning = migrate_to_device( diffusion_conditioning, self.device )
 
             if get_device_name() == "dml":
-                self.autoregressive = self.autoregressive.to(self.device)
-                best_results = best_results.to(self.device)
-                best_latents = best_latents.to(self.device)
-
-                self.vocoder = self.vocoder.cpu()
+                self.autoregressive = migrate_to_device( self.autoregressive, self.device )
+                best_results = migrate_to_device( best_results, self.device )
+                best_latents = migrate_to_device( best_latents, self.device )
+                self.vocoder = migrate_to_device( self.vocoder, 'cpu' )
             else:
                 if not self.preloaded_tensors:
-                    self.autoregressive = self.autoregressive.cpu()
-
-                self.diffusion = self.diffusion.to(self.device)
-                self.vocoder = self.vocoder.to(self.device)
+                    self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
+                self.diffusion = migrate_to_device( self.diffusion, self.device )
+                self.vocoder = migrate_to_device( self.vocoder, self.device )
 
             del text_tokens
             del auto_conditioning
@@ -758,12 +765,14 @@
                 wav_candidates.append(wav)
 
             if not self.preloaded_tensors:
-                self.diffusion = self.diffusion.cpu()
-                self.vocoder = self.vocoder.cpu()
+                self.diffusion = migrate_to_device( self.diffusion, 'cpu' )
+                self.vocoder = migrate_to_device( self.vocoder, 'cpu' )
 
             def potentially_redact(clip, text):
                 if self.enable_redaction:
-                    return self.aligner.redact(clip.squeeze(1).to('cpu' if get_device_name() == "dml" else self.device), text, self.output_sample_rate).unsqueeze(1)
+                    t = clip.squeeze(1)
+                    t = migrate_to_device( t, 'cpu' if get_device_name() == "dml" else self.device)
+                    return self.aligner.redact(t, text, self.output_sample_rate).unsqueeze(1)
                 return clip
             wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
@@ -772,7 +781,7 @@
         else:
             res = wav_candidates[0]
 
-        gc.collect()
+        do_gc()
 
         if return_deterministic_state:
            return res, (deterministic_seed, text, voice_samples, conditioning_latents)
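Note: the api.py changes above funnel every device move through migrate_to_device(), which wraps .to() with a None guard, an early return when the object already reports the target device, manual device tracking for objects that lack a .device attribute, and a do_gc() after each move. A minimal sketch of the intended lifecycle follows; the run_with_offload wrapper is hypothetical, not part of this patch, and assumes migrate_to_device and do_gc are in scope as defined above:

    import torch

    def run_with_offload(model, inputs, device='cuda', preloaded=False):
        # Pull the module onto the compute device for the duration of the call.
        model = migrate_to_device(model, device)
        with torch.no_grad():
            out = model(migrate_to_device(inputs, device))
        # Mirror the patch's pattern: keep the module resident only when tensors
        # are preloaded, otherwise offload back to the CPU. Each move triggers
        # do_gc(), so cached VRAM is actually released rather than left pooled.
        model = migrate_to_device(model, device if preloaded else 'cpu')
        return out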
diff --git a/tortoise/models/clvp.py b/tortoise/models/clvp.py
index 71a744c..db70fe2 100644
--- a/tortoise/models/clvp.py
+++ b/tortoise/models/clvp.py
@@ -9,6 +9,8 @@ from tortoise.models.xtransformers import Encoder
 
 import tortoise.utils.torch_intermediary as ml
 
+from tortoise.utils.device import print_stats, do_gc
+
 def exists(val):
     return val is not None
@@ -124,14 +126,13 @@ class CLVP(nn.Module):
         text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
         speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
 
-        enc_text = self.text_transformer(text_emb, mask=text_mask)
-        enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
+
+        text_latents = self.to_text_latent(masked_mean(self.text_transformer(text_emb, mask=text_mask), text_mask, dim=1))
 
-        text_latents = masked_mean(enc_text, text_mask, dim=1)
-        speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
-
-        text_latents = self.to_text_latent(text_latents)
-        speech_latents = self.to_speech_latent(speech_latents)
+        # on ROCm at least, allocated VRAM spikes here
+        do_gc()
+        speech_latents = self.to_speech_latent(masked_mean(self.speech_transformer(speech_emb, mask=voice_mask), voice_mask, dim=1))
+        do_gc()
 
         text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
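Note: the clvp.py rewrite folds the transformer call, masked_mean(), and the latent projection into one expression, so the full-sequence encoder output is never bound to a name and can be freed as soon as the pooling consumes it; do_gc() then returns that memory before the speech branch allocates. A simplified illustration of the difference, with masked_mean stood in by an explicit masked average and arbitrary shapes (this is not the patch's code):

    import torch

    def pooled_latents_named(encoder, proj, emb, mask):
        enc = encoder(emb)  # (B, T, D) activation stays referenced by `enc`...
        pooled = (enc * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True)
        return proj(pooled)  # ...until the function returns

    def pooled_latents_fused(encoder, proj, emb, mask):
        # The encoder output is a temporary here: its refcount drops to zero as
        # soon as the pooling consumes it, so the allocator can reuse it at once.
        return proj((encoder(emb) * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True))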
diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py
index 366a3af..7b16e7f 100755
--- a/tortoise/utils/device.py
+++ b/tortoise/utils/device.py
@@ -5,6 +5,29 @@ import importlib
 
 DEVICE_OVERRIDE = None
 DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
+from inspect import currentframe, getframeinfo
+import gc
+
+def do_gc():
+    gc.collect()
+    try:
+        torch.cuda.empty_cache()
+    except Exception as e:
+        pass
+
+def print_stats(collect=False):
+    cf = currentframe().f_back
+    msg = f'{getframeinfo(cf).filename}:{cf.f_lineno}'
+
+    if collect:
+        do_gc()
+
+    tot = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
+    res = torch.cuda.memory_reserved(0) / (1024 ** 3)
+    alloc = torch.cuda.memory_allocated(0) / (1024 ** 3)
+    print("[{}] Total: {:.3f} | Reserved: {:.3f} | Allocated: {:.3f} | Free: {:.3f}".format( msg, tot, res, alloc, tot-res ))
+
+
 def has_dml():
     loader = importlib.find_loader('torch_directml')
     if loader is None:
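Note: the device.py helpers are self-contained, so they can bracket any suspect allocation when chasing VRAM spikes like the CLVP one above. An illustrative snippet, with arbitrary tensor shapes; print_stats() reports the caller's file:line via the currentframe().f_back lookup:

    import torch
    from tortoise.utils.device import print_stats, do_gc

    x = torch.randn(1, 4096, 4096, device='cuda')
    print_stats()               # "[caller.py:NN] Total: ... | Reserved: ... | Allocated: ... | Free: ..."
    y = x @ x.transpose(1, 2)   # the allocation under investigation
    print_stats(collect=True)   # runs do_gc() first, then reports again
    del x, y
    do_gc()                     # gc.collect() + torch.cuda.empty_cache()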