added automagic offloading of models to the GPU and then back to the CPU when they're done during inference

mrq 2024-06-19 17:01:05 -05:00
parent 5d24631bfb
commit 73f271fb8a
7 changed files with 90 additions and 76 deletions
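In short: rather than keeping every model resident in VRAM for the whole session, the autoregressive, diffusion, CLVP, and vocoder models are parked on the CPU up front, and each pass pulls the model it needs onto the GPU and shoves it back to the CPU once the pass is done. A minimal standalone sketch of that pattern (simplified from the `ml.auto_unload` wrapper added below; the names `offload_after_use` and `device` here are only illustrative):

from contextlib import contextmanager

import torch

@contextmanager
def offload_after_use(model, gpu="cuda", cpu="cpu", enabled=True):
    # bring the model into VRAM only for the duration of the block
    model.to(gpu)
    yield model
    # pass finished: send the weights back to system RAM
    if enabled:
        model.to(cpu)

# usage sketch: only the model currently doing work occupies the GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4, 4)
with offload_after_use(model, gpu=device, enabled=True):
    out = model(torch.zeros(1, 4, device=device))
print(next(model.parameters()).device)  # back on "cpu" after the block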

View File

@@ -40,7 +40,7 @@ For training a LoRA, uncomment the `loras` block in your training YAML.
 - [ ] Reimplement redaction with the Wav2Vec2
 - [X] Implement training support (without DLAS)
 - [X] Feature parity with the VALL-E training setup with preparing a dataset ahead of time
-- [ ] Automagic offloading to CPU for unused models (for training and inferencing)
+- [X] Automagic offloading to CPU for unused models (for training and inferencing)
 - [X] Automagic handling of the original weights into compatible weights
 - [ ] Reimplement added features from my original fork:
 - [ ] "Better" conditioning latents calculating

View File

@@ -19,7 +19,7 @@ def main():
     parser.add_argument("--top-k", type=int, default=16)
     parser.add_argument("--repetition-penalty", type=float, default=1.0)
     #parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
-    parser.add_argument("--length-penalty", type=float, default=0.0)
+    parser.add_argument("--length-penalty", type=float, default=1.0)
     parser.add_argument("--beam-width", type=int, default=0)
     parser.add_argument("--diffusion-sampler", type=str, default="ddim")

View File

@@ -21,8 +21,6 @@ from .tokenizer import VoiceBpeTokenizer
 # Yuck
 from transformers import PreTrainedTokenizerFast
-from tokenizers import Tokenizer

 @dataclass()
 class BaseConfig:

@@ -472,17 +470,10 @@ class Inference:
     weight_dtype: str = "float32"
     amp: bool = False
+    auto_unload: bool = True
     normalize: bool = False # do NOT enable this unless you know exactly what you're doing
-    # legacy / backwards compat
-    use_vocos: bool = True
-    use_encodec: bool = True
-    use_dac: bool = True
-    # shit that doesn't work
-    recurrent_chunk_size: int = 0
-    recurrent_forward: bool = False

     @cached_property
     def dtype(self):
         if self.weight_dtype == "float16":

View File

@@ -8,6 +8,7 @@ from pathlib import Path
 from .emb.mel import encode_from_files as encode_mel, trim, trim_random
 from .utils import to_device
+from .utils import wrapper as ml

 from .config import cfg
 from .models import get_models, load_model
@@ -110,7 +111,7 @@ class TTS():
         top_k=0,
         repetition_penalty=1.0,
         #repetition_penalty_decay=0.0,
-        length_penalty=0.0,
+        length_penalty=1.0,
         beam_width=1,
         #mirostat_tau=0,
         #mirostat_eta=0.1,
@@ -151,6 +152,13 @@
         if vocoder is None:
             vocoder = load_model("vocoder", device=cfg.device)

+        # shove everything to cpu
+        if cfg.inference.auto_unload:
+            autoregressive = autoregressive.to("cpu")
+            diffusion = diffusion.to("cpu")
+            clvp = clvp.to("cpu")
+            vocoder = vocoder.to("cpu")
+
         wavs = []
         # other vars
         calm_token = 832
@@ -168,79 +176,82 @@
             text_lengths = torch.Tensor([ text.shape[0] ]).to(dtype=torch.int32)

             with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
+                with ml.auto_unload(autoregressive, enabled=cfg.inference.auto_unload):
                     # autoregressive pass
                     codes = autoregressive.inference_speech(
                         autoregressive_latents,
                         text_tokens,
                         do_sample=True,
                         top_p=top_p,
                         temperature=ar_temp,
                         num_return_sequences=1,
                         length_penalty=length_penalty,
                         repetition_penalty=repetition_penalty,
                         max_generate_length=max_ar_steps,
                     )

                     """
                     padding_needed = max_ar_steps - codes.shape[1]
                     codes = F.pad(codes, (0, padding_needed), value=autoregressive.stop_mel_token)
                     """

                     for i, code in enumerate( codes ):
                         stop_token_indices = (codes[i] == autoregressive.stop_mel_token).nonzero()
                         stm = stop_token_indices.min().item()

                         if len(stop_token_indices) == 0:
                             continue

                         codes[i][stop_token_indices] = 83
                         codes[i][stm:] = 83

                         if stm - 3 < codes[i].shape[0]:
                             codes[i][-3] = 45
                             codes[i][-2] = 45
                             codes[i][-1] = 248

                     wav_lengths = torch.tensor([codes.shape[-1] * autoregressive.mel_length_compression], device=text_tokens.device)

                     latents = autoregressive.forward(
                         autoregressive_latents,
                         text_tokens,
                         text_lengths,
                         codes,
                         wav_lengths,
                         return_latent=True,
                         clip_inputs=False
                     )

                     calm_tokens = 0
                     for k in range( codes.shape[-1] ):
                         if codes[0, k] == calm_token:
                             calm_tokens += 1
                         else:
                             calm_tokens = 0

                         if calm_tokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
                             latents = latents[:, :k]
                             break

                 # diffusion pass
+                with ml.auto_unload(diffusion, enabled=cfg.inference.auto_unload):
                     output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
                     output_shape = (latents.shape[0], 100, output_seq_len)
                     precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False)

                     noise = torch.randn(output_shape, device=latents.device) * diffusion_temp
                     mel = diffuser.sample_loop(
                         diffusion,
                         output_shape,
                         sampler=diffusion_sampler,
                         noise=noise,
                         model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
                         progress=True
                     )
                     mels = denormalize_tacotron_mel(mel)[:,:,:output_seq_len]

                 # vocoder pass
+                with ml.auto_unload(vocoder, enabled=cfg.inference.auto_unload):
                     waves = vocoder.inference(mels)

             for wav in waves:
                 if out_path is not None:
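A quick way to sanity-check the new behaviour after a generation (plain PyTorch, not part of this diff; assumes the model objects from above are still in scope):

import torch

def device_of(model: torch.nn.Module) -> str:
    # report where a model's weights currently live
    return str(next(model.parameters()).device)

# with cfg.inference.auto_unload left at its default of True, each model should
# report "cpu" between passes and only sit on a CUDA device while its own block runs
for name, model in [("autoregressive", autoregressive), ("diffusion", diffusion), ("vocoder", vocoder)]:
    print(name, device_of(model))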

View File

@@ -229,6 +229,10 @@ def run_eval(engines, eval_name, dl):
     else:
         _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")

+    diffusion = diffusion.to("cpu")
+    clvp = clvp.to("cpu")
+    vocoder = vocoder.to("cpu")
+
 def train():
     parser = argparse.ArgumentParser("TorToiSe TTS")

View File

@@ -77,6 +77,14 @@ def autocasts(input, from_dtype, to_dtype):
     else:
         yield input

+@contextmanager
+def auto_unload( model, gpu="cuda", cpu="cpu", enabled=True):
+    model.to(gpu)
+
+    yield model
+
+    if enabled:
+        model.to(cpu)
+
 # handles temporarily upcasting 'index tensors' so torch will stop bitching
 def autocast_forward( func ):
     def wrapper( self, input, *args, **kwargs ):
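One thing to note about the wrapper as written: the `yield` is not protected, so an exception raised inside the wrapped pass would leave the model sitting on the GPU. A hedged variant that always returns the model to the CPU (a sketch, not what this commit does) would use `try`/`finally`:

from contextlib import contextmanager

@contextmanager
def auto_unload(model, gpu="cuda", cpu="cpu", enabled=True):
    model.to(gpu)
    try:
        yield model
    finally:
        # runs even if the wrapped pass raised, so VRAM is always reclaimed
        if enabled:
            model.to(cpu)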

View File

@@ -240,7 +240,7 @@ with ui:
         with gr.Row():
             layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
             layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
-            layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
+            layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
         """
         with gr.Row():
             layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")