diff --git a/README.md b/README.md
index c169d06..709e754 100755
--- a/README.md
+++ b/README.md
@@ -1,13 +1,9 @@
 # AI Voice Cloning for Retards and Savants
 
-This [rentry](https://rentry.org/AI-Voice-Cloning/) aims to serve as both a foolproof guide for setting up AI voice cloning tools for legitimate, local use on Windows (with an Nvidia GPU), as well as a stepping stone for anons that genuinely want to play around with [TorToiSe](https://github.com/neonbjb/tortoise-tts).
+This [rentry](https://rentry.org/AI-Voice-Cloning/) aims to serve as both a foolproof guide for setting up AI voice cloning tools for legitimate, local use on Windows, as well as a stepping stone for anons that genuinely want to play around with [TorToiSe](https://github.com/neonbjb/tortoise-tts).
 
 Similar to my own findings for Stable Diffusion image generation, this rentry may appear a little disheveled as I note my new findings with TorToiSe. Please keep this in mind if the guide seems to shift a bit or sound confusing.
 
->\>B-but what about the colab notebook/hugging space instance??
-
-I link those a bit later on as alternatives for Windows+AMD users. You're free to skip the installation section and jump after that.
-
 >\>Ugh... why bother when I can just abuse 11.AI?
 
 I very much encourage (You) to use 11.AI while it's still viable to use. For the layman, it's easier to go through the hoops of coughing up the $5 or abusing the free trial over actually setting up a TorToiSe environment and dealing with its quirks.
@@ -39,16 +35,15 @@ My fork boasts the following additions, fixes, and optimizations:
   - additionally, regenerating them if the script detects they're out of date
 * uses the entire audio sample instead of the first four seconds of each sound file for better reproducing
 * activated unused DDIM sampler
-* ease of setup for the most inexperienced Windows users
 * use of some optimizations like `kv_cache`ing for the autoregression sample pass, and keeping data on GPU
+* compatibility with DirectML
+* easy install scripts
 * and more!
 
 ## Installing
 
 Outside of the very small prerequisites, everything needed to get TorToiSe working is included in the repo.
 
-For Windows users with an AMD GPU, ~~tough luck, as ROCm drivers are not (easily) available for Windows, and requires inane patches with PyTorch.~~ you're almost in luck, as hardware acceleration for any\* device is possible with PyTorch-DirectML. **!**NOTE**!**: DirectML support is currently being worked on, so for now, consider using the [Colab notebook](https://colab.research.google.com/drive/1wVVqUPqwiDBUVeWWOUNglpGhU3hg_cbR?usp=sharing), or the [Hugging Face space](https://huggingface.co/spaces/mdnestor/tortoise), for `tortoise-tts`. **!**NOTE**!**: these two do not use this repo's fork.
-
 ### Pre-Requirements
 
 Windows:
@@ -71,16 +66,22 @@ After installing Python, open the Start Menu and search for `Command Prompt`. Ty
 Paste `git clone https://git.ecker.tech/mrq/tortoise-tts` to download TorToiSe and additional scripts, then hit Enter. Inexperienced users can just download the repo as a ZIP and extract it.
 
 Afterwards, run the setup script, depending on your GPU, to automatically set things up.
-* ~~AMD: `setup-directml.bat`~~
+* AMD: `setup-directml.bat`
 * NVIDIA: `setup-cuda.bat`
 
 If you've done everything right, you shouldn't have any errors.
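+
+To sanity check the setup, you can confirm PyTorch actually sees your card from inside `tortoise-venv` before firing up anything heavier. This is a quick, hand-rolled test (not part of the setup scripts), assuming the AMD path installed the `torch-directml` package; NVIDIA users go through `torch.cuda` instead:
+
+```python
+import torch
+
+try:
+    # AMD: DirectML exposes the GPU through the torch_directml package
+    import torch_directml
+    device = torch_directml.device()
+except ImportError:
+    # NVIDIA: the usual CUDA check
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# a tiny tensor op that should run on the GPU without erroring
+print(device, torch.ones(2, 2).to(device) * 2)
+```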
 
 ##### Note on DirectML Support
 
-At first, I thought it was just one simple problem that needed to be fixed, but as I picked at it and did a new install (having CUDA enabled too caused some things to silently "work" despite using DML instead), more problems cropped up, exposing that PyTorch-DirectML isn't quite ready yet.
+PyTorch-DirectML is very, very experimental and is still not production quality. There are some headaches with the hairy, kludgy patches it requires.
 
-I doubt even if I sucked off a wizard, there'd still be other problems cropping up.
+These patches rely on transferring tensors between the GPU and CPU as a hotfix, so performance definitely takes a hit; the monkeypatches in `tortoise/utils/device.py` below (and the sketch at the end of this patch) show the pattern.
+
+Both the conditional latent computation and the vocoder pass have to be done entirely on the CPU because of some quirks with DirectML.
+
+On my 6800XT, VRAM usage climbs through almost the entire 16GiB, so be wary of OOMing. Low VRAM flags may NOT have any additional impact anyways, given the constant copying.
+
+For AMD users, I might still suggest using Linux+ROCm, as it's (relatively) headache free, but I did have stability problems.
 
 #### Linux
 
diff --git a/start.bat b/start.bat
index 33bc1b3..a5159e1 100755
--- a/start.bat
+++ b/start.bat
@@ -1,4 +1,4 @@
 call .\tortoise-venv\Scripts\activate.bat
-python .\app.py
+accelerate launch --num_cpu_threads_per_process=6 app.py
 deactivate
 pause
\ No newline at end of file
diff --git a/tortoise/api.py b/tortoise/api.py
index 7b99857..cd2c384 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -176,7 +176,10 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_la
                                  model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
                                  verbose=verbose, progress=progress, desc=desc)
 
-    return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
+    mel = denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
+    if get_device_name() == "dml":
+        mel = mel.cpu()
+    return mel
 
 
 def classify_audio_clip(clip):
@@ -449,6 +452,9 @@ class TextToSpeech:
         :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
                  Sample rate is 24kHz.
""" + if get_device_name() == "dml": + half_p = False + self.diffusion.enable_fp16 = half_p deterministic_seed = self.deterministic_state(seed=use_deterministic_seed) @@ -477,6 +483,8 @@ class TextToSpeech: with torch.no_grad(): samples = [] num_batches = num_autoregressive_samples // self.autoregressive_batch_size + if num_autoregressive_samples < self.autoregressive_batch_size: + num_autoregressive_samples = 1 stop_mel_token = self.autoregressive.stop_mel_token calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" @@ -553,16 +561,31 @@ class TextToSpeech: if not self.minor_optimizations: self.autoregressive = self.autoregressive.to(self.device) + if get_device_name() == "dml": + text_tokens = text_tokens.cpu() + best_results = best_results.cpu() + auto_conditioning = auto_conditioning.cpu() + self.autoregressive = self.autoregressive.cpu() + best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1), torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results, torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device), return_latent=True, clip_inputs=False) + if get_device_name() == "dml": + self.autoregressive = self.autoregressive.to(self.device) + best_results = best_results.to(self.device) + best_latents = best_latents.to(self.device) + if not self.minor_optimizations: self.autoregressive = self.autoregressive.cpu() self.diffusion = self.diffusion.to(self.device) self.vocoder = self.vocoder.to(self.device) + if get_device_name() == "dml": + self.vocoder = self.vocoder.cpu() + + del text_tokens del auto_conditioning wav_candidates = [] @@ -584,6 +607,7 @@ class TextToSpeech: mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature, verbose=verbose, progress=progress, desc="Transforming autoregressive outputs into audio..", sampler=diffusion_sampler, input_sample_rate=self.input_sample_rate, output_sample_rate=self.output_sample_rate) + wav = self.vocoder.inference(mel) wav_candidates.append(wav.cpu()) diff --git a/tortoise/models/diffusion_decoder.py b/tortoise/models/diffusion_decoder.py index 551016b..b383914 100755 --- a/tortoise/models/diffusion_decoder.py +++ b/tortoise/models/diffusion_decoder.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from torch import autocast from tortoise.models.arch_util import normalization, AttentionBlock - +from tortoise.utils.device import get_device_name def is_latent(t): return t.dtype == torch.float @@ -141,7 +141,7 @@ class DiffusionTts(nn.Module): in_tokens=8193, out_channels=200, # mean and variance dropout=0, - use_fp16=True, + use_fp16=False, num_heads=16, # Parameters for regularization. layer_drop=.1, @@ -302,7 +302,8 @@ class DiffusionTts(nn.Module): unused_params.extend(list(lyr.parameters())) else: # First and last blocks will have autocast disabled for improved precision. 
-                    with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
+                    # was: x.device.type (autocast does not recognize the DML device type)
+                    with autocast(device_type='cuda', enabled=self.enable_fp16 and i != 0):
                         x = lyr(x, time_emb)
 
         x = x.float()
diff --git a/tortoise/models/vocoder.py b/tortoise/models/vocoder.py
old mode 100644
new mode 100755
diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py
index cb83926..db90969 100755
--- a/tortoise/utils/device.py
+++ b/tortoise/utils/device.py
@@ -1,37 +1,9 @@
 import torch
 
 def has_dml():
-    """
-    # huggingface's transformer/GPT2 model will just lead to a long track of problems
-    # I will suck off a wizard if he gets this remedied somehow
-    """
-    """
-    # Note 1:
-    # self.inference_model.generate will lead to this error in torch.LongTensor.new:
-    # RuntimeError: new(): expected key in DispatchKeySet(CPU, CUDA, HIP, XLA, MPS, IPU, XPU, HPU, Lazy, Meta) but got: PrivateUse1
-    # Patching "./venv/lib/site-packages/transformers/generation_utils.py:1906" with:
-    #   unfinished_sequences = input_ids.new_tensor(input_ids.shape[0], device=input_ids.device).fill_(1)
-    # "fixes" it, but meets another error/crash about an unimplemented functions.........
-    """
-    """
-    # Note 2:
-    # torch.load() will gripe about something CUDA not existing
-    # remedy this with passing map_location="cpu"
-    """
-    """
-    # Note 3:
-    # stft requires device='cpu' or it'll crash about some error about an unimplemented function I do not remember
-    """
-    """
-    # Note 4:
-    # 'Tensor.multinominal' and 'Tensor.repeat_interleave' throws errors about being unimplemented and falls back to CPU and crashes
-    """
-    return False
-    """
     import importlib
     loader = importlib.find_loader('torch_directml')
     return loader is not None
-    """
 
 def get_device_name():
     name = 'cpu'
@@ -68,4 +40,23 @@ def get_device_batch_size():
         return 8
     elif availableGb > 7:
         return 4
-    return 1
\ No newline at end of file
+    return 1
+
+# ops DirectML has not implemented yet get detoured through the CPU
+if has_dml():
+    _cumsum = torch.cumsum
+    _repeat_interleave = torch.repeat_interleave
+    _multinomial = torch.multinomial
+
+    _Tensor_new = torch.Tensor.new
+    _Tensor_cumsum = torch.Tensor.cumsum
+    _Tensor_repeat_interleave = torch.Tensor.repeat_interleave
+    _Tensor_multinomial = torch.Tensor.multinomial
+
+    torch.cumsum = lambda input, *args, **kwargs: ( _cumsum(input.to("cpu"), *args, **kwargs).to(input.device) )
+    torch.repeat_interleave = lambda input, *args, **kwargs: ( _repeat_interleave(input.to("cpu"), *args, **kwargs).to(input.device) )
+    torch.multinomial = lambda input, *args, **kwargs: ( _multinomial(input.to("cpu"), *args, **kwargs).to(input.device) )
+
+    torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) )
+    torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) )
+    torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) )
+    torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) )
\ No newline at end of file
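
All of the `device.py` monkeypatches above follow the same pattern: detour an op DirectML hasn't implemented through the CPU, then copy the result back to the tensor's original device. A minimal standalone sketch of that pattern (the `cpu_fallback` helper is illustrative, not part of the repo):

```python
import torch

def cpu_fallback(fn):
    """Run `fn` on the CPU, then copy the result back to the input's device.

    Only the first tensor argument gets moved, matching the hand-written
    lambdas in tortoise/utils/device.py; extra tensor arguments would need
    the same treatment.
    """
    def wrapped(input, *args, **kwargs):
        return fn(input.cpu(), *args, **kwargs).to(input.device)
    return wrapped

# e.g. torch.multinomial has no DML kernel, so detour it through the CPU
torch.multinomial = cpu_fallback(torch.multinomial)
```

This is also where the performance hit comes from: every wrapped call pays two extra device copies.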