added automagic offloading of models to the GPU, then back to the CPU when they're done, during inference
parent 5d24631bfb
commit 73f271fb8a
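In short: instead of keeping the autoregressive model, the diffusion model, CLVP, and the vocoder all resident in VRAM for the whole call, each one is now pulled onto the GPU only for its own pass and evicted afterwards. A minimal sketch of the idea — not the commit's exact helper (that is the `auto_unload` context manager added below), and `offloaded_pass`/`run_pass` are hypothetical names with a `try`/`finally` added for illustration:

```python
import torch

def offloaded_pass(model, run_pass, *args, gpu="cuda", cpu="cpu", enabled=True):
    """Run one model's inference pass on the GPU, then evict it to system RAM."""
    model.to(gpu)
    try:
        return run_pass(model, *args)
    finally:
        if enabled:
            model.to(cpu)  # free the VRAM for the next model in the pipeline
```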
@@ -40,7 +40,7 @@ For training a LoRA, uncomment the `loras` block in your training YAML.
 - [ ] Reimplement redaction with the Wav2Vec2
 - [X] Implement training support (without DLAS)
 - [X] Feature parity with the VALL-E training setup with preparing a dataset ahead of time
-- [ ] Automagic offloading to CPU for unused models (for training and inferencing)
+- [X] Automagic offloading to CPU for unused models (for training and inferencing)
 - [X] Automagic handling of the original weights into compatible weights
 - [ ] Reimplement added features from my original fork:
   - [ ] "Better" conditioning latents calculating
@@ -19,7 +19,7 @@ def main():
     parser.add_argument("--top-k", type=int, default=16)
     parser.add_argument("--repetition-penalty", type=float, default=1.0)
     #parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
-    parser.add_argument("--length-penalty", type=float, default=0.0)
+    parser.add_argument("--length-penalty", type=float, default=1.0)
     parser.add_argument("--beam-width", type=int, default=0)

     parser.add_argument("--diffusion-sampler", type=str, default="ddim")
@@ -21,8 +21,6 @@ from .tokenizer import VoiceBpeTokenizer

 # Yuck
 from transformers import PreTrainedTokenizerFast
-from tokenizers import Tokenizer
-

 @dataclass()
 class BaseConfig:
@@ -472,17 +470,10 @@ class Inference:
     weight_dtype: str = "float32"
     amp: bool = False

+    auto_unload: bool = True

     normalize: bool = False # do NOT enable this unless you know exactly what you're doing
-    # legacy / backwards compat
-    use_vocos: bool = True
-    use_encodec: bool = True
-    use_dac: bool = True
-
-    # shit that doesn't work
-    recurrent_chunk_size: int = 0
-    recurrent_forward: bool = False

     @cached_property
     def dtype(self):
         if self.weight_dtype == "float16":
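The new `auto_unload` flag defaults to on. If you would rather keep every model resident in VRAM, flipping it off at runtime should look roughly like this (the `tortoise_tts.config` import path is an assumption, inferred from the `from .config import cfg` usage below; adjust to your install):

```python
from tortoise_tts.config import cfg  # assumed package name

cfg.inference.auto_unload = False  # keep all models on the GPU between passes
```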
@@ -8,6 +8,7 @@ from pathlib import Path

 from .emb.mel import encode_from_files as encode_mel, trim, trim_random
 from .utils import to_device
+from .utils import wrapper as ml

 from .config import cfg
 from .models import get_models, load_model
@@ -110,7 +111,7 @@ class TTS():
         top_k=0,
         repetition_penalty=1.0,
         #repetition_penalty_decay=0.0,
-        length_penalty=0.0,
+        length_penalty=1.0,
         beam_width=1,
         #mirostat_tau=0,
         #mirostat_eta=0.1,
@@ -151,6 +152,13 @@ class TTS():
         if vocoder is None:
             vocoder = load_model("vocoder", device=cfg.device)

+        # shove everything to cpu
+        if cfg.inference.auto_unload:
+            autoregressive = autoregressive.to("cpu")
+            diffusion = diffusion.to("cpu")
+            clvp = clvp.to("cpu")
+            vocoder = vocoder.to("cpu")
+
         wavs = []
         # other vars
         calm_token = 832
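With everything shoved to the CPU up front, peak VRAM should be bounded by the largest single model plus its activations rather than the sum of all four. A quick, standard-PyTorch way to eyeball the difference before and after this commit:

```python
import torch

torch.cuda.reset_peak_memory_stats()
# ... run a normal inference call here ...
print(f"peak VRAM: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
```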
@@ -168,79 +176,82 @@ class TTS():
             text_lengths = torch.Tensor([ text.shape[0] ]).to(dtype=torch.int32)

             with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
+                with ml.auto_unload(autoregressive, enabled=cfg.inference.auto_unload):
                     # autoregressive pass
                     codes = autoregressive.inference_speech(
                         autoregressive_latents,
                         text_tokens,
                         do_sample=True,
                         top_p=top_p,
                         temperature=ar_temp,
                         num_return_sequences=1,
                         length_penalty=length_penalty,
                         repetition_penalty=repetition_penalty,
                         max_generate_length=max_ar_steps,
                     )

                     """
                     padding_needed = max_ar_steps - codes.shape[1]
                     codes = F.pad(codes, (0, padding_needed), value=autoregressive.stop_mel_token)
                     """

                     for i, code in enumerate( codes ):
                         stop_token_indices = (codes[i] == autoregressive.stop_mel_token).nonzero()
                         stm = stop_token_indices.min().item()

                         if len(stop_token_indices) == 0:
                             continue

                         codes[i][stop_token_indices] = 83
                         codes[i][stm:] = 83

                         if stm - 3 < codes[i].shape[0]:
                             codes[i][-3] = 45
                             codes[i][-2] = 45
                             codes[i][-1] = 248

                     wav_lengths = torch.tensor([codes.shape[-1] * autoregressive.mel_length_compression], device=text_tokens.device)

                     latents = autoregressive.forward(
                         autoregressive_latents,
                         text_tokens,
                         text_lengths,
                         codes,
                         wav_lengths,
                         return_latent=True,
                         clip_inputs=False
                     )

                     calm_tokens = 0
                     for k in range( codes.shape[-1] ):
                         if codes[0, k] == calm_token:
                             calm_tokens += 1
                         else:
                             calm_tokens = 0
                         if calm_tokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
                             latents = latents[:, :k]
                             break

                 # diffusion pass
+                with ml.auto_unload(diffusion, enabled=cfg.inference.auto_unload):
                     output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
                     output_shape = (latents.shape[0], 100, output_seq_len)
                     precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False)

                     noise = torch.randn(output_shape, device=latents.device) * diffusion_temp
                     mel = diffuser.sample_loop(
                         diffusion,
                         output_shape,
                         sampler=diffusion_sampler,
                         noise=noise,
                         model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
                         progress=True
                     )
                     mels = denormalize_tacotron_mel(mel)[:,:,:output_seq_len]

                 # vocoder pass
+                with ml.auto_unload(vocoder, enabled=cfg.inference.auto_unload):
                     waves = vocoder.inference(mels)

             for wav in waves:
                 if out_path is not None:
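A note on the `output_seq_len` arithmetic above: each autoregressive latent covers 4 mel frames on the 22.05 kHz grid, and the `* 24000 // 22050` factor rescales that frame count to the diffusion model's 24 kHz output grid. For example, 100 latents → 100 * 4 * 24000 // 22050 = 435 output frames.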
@@ -229,6 +229,10 @@ def run_eval(engines, eval_name, dl):
     else:
         _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")

+    diffusion = diffusion.to("cpu")
+    clvp = clvp.to("cpu")
+    vocoder = vocoder.to("cpu")
+

 def train():
     parser = argparse.ArgumentParser("TorToiSe TTS")
@@ -77,6 +77,14 @@ def autocasts(input, from_dtype, to_dtype):
     else:
         yield input

+@contextmanager
+def auto_unload( model, gpu="cuda", cpu="cpu", enabled=True):
+    model.to(gpu)
+    yield model
+
+    if enabled:
+        model.to(cpu)
+
 # handles temporarily upcasting 'index tensors' so torch will stop bitching
 def autocast_forward( func ):
     def wrapper( self, input, *args, **kwargs ):
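For reference, the new helper reads as a plain context manager: the model lands on `gpu` for the duration of the block and is moved to `cpu` on the way out when `enabled`. A toy usage sketch, assuming a CUDA machine, with an `nn.Linear` standing in for a real model and an assumed import path mirroring the `from .utils import wrapper as ml` added above:

```python
import torch
from torch import nn

from tortoise_tts.utils import wrapper as ml  # assumed package path

model = nn.Linear(16, 16)               # stand-in for e.g. the vocoder
x = torch.randn(1, 16, device="cuda")

with ml.auto_unload(model, enabled=True) as m:
    y = m(x)                            # model is on the GPU here
# model is back on the CPU here; its VRAM is free for the next pass
```

One caveat worth knowing: the helper's body is not wrapped in `try`/`finally`, so an exception raised inside the block will leave the model parked on the GPU.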
@@ -240,7 +240,7 @@ with ui:
         with gr.Row():
             layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
             layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
-            layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
+            layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
         """
         with gr.Row():
             layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
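For context on why the length-penalty default moved from 0.0 to 1.0: in Hugging Face-style beam search (which these sampling knobs ultimately feed), `length_penalty` is an exponent on sequence length when normalizing a beam's summed log-probability, so 0.0 means no normalization and 1.0 means a plain per-token average. A toy illustration, not the project's code:

```python
def beam_score(sum_log_probs: float, length: int, length_penalty: float) -> float:
    # Hugging Face-style length normalization; larger values favor longer sequences
    return sum_log_probs / (length ** length_penalty)

# same total log-probability, two settings:
print(beam_score(-10.0, 10, 0.0))  # -10.0 — old default, length is ignored
print(beam_score(-10.0, 10, 1.0))  # -1.0  — new default, per-token average
```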