backported old fork features: KV cache (which, looking back, seems like a spook), DDIM sampling, etc.

This commit is contained in:
mrq 2024-06-19 14:49:24 -05:00
parent 268ba17485
commit 99be487482
7 changed files with 125 additions and 57 deletions

View File

@ -37,10 +37,15 @@ For training a LoRA, uncomment the `loras` block in your training YAML.
- [X] Reimplement original inferencing through TorToiSe (as done with `api.py`)
- [ ] Reimplement candidate selection with the CLVP
- [ ] Reimplement redaction with the Wav2Vec2
- [X] Implement training support (without DLAS)
- [X] Feature parity with the VALL-E training setup by preparing a dataset ahead of time
- [ ] Automagic offloading to CPU for unused models (for training and inferencing)
- [X] Automagic handling of the original weights into compatible weights
- [ ] Reimplement added features from my original fork:
- [ ] "Better" conditioning latents calculating
- [x] Use of KV-cache for the AR
- [x] Re-enable DDIM sampler
- [ ] Extend the original inference routine with additional features:
- [ ] non-float32 / mixed precision for the entire stack
- [x] BitsAndBytes support
@ -48,10 +53,13 @@ For training a LoRA, uncomment the `loras` block in your training YAML.
- [x] LoRAs
- [x] Web UI
- [ ] Feature parity with [ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning)
- Although I feel a lot of its features go about things the wrong way.
- [ ] Additional samplers for the autoregressive model
- [ ] Additional samplers for the diffusion model
- [ ] BigVGAN in place of the original vocoder
- [ ] XFormers / flash_attention_2 for the autoregressive model
- Beyond HF's internal handling of alternative attention implementations
- Both the AR and diffusion models also implement their own attention...
- [ ] Some vector embedding store to pick the "best" utterance
- [ ] Documentation

View File

@ -22,6 +22,8 @@ def main():
parser.add_argument("--length-penalty", type=float, default=0.0)
parser.add_argument("--beam-width", type=int, default=0)
parser.add_argument("--diffusion-sampler", type=str, default="ddim")
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--amp", action="store_true")
@ -56,6 +58,8 @@ def main():
#repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,
beam_width=args.beam_width,
diffusion_sampler=args.diffusion_sampler
)
"""
language=args.language,
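For reference, a minimal sketch of the flag-to-kwarg mapping: argparse converts `--diffusion-sampler` into `args.diffusion_sampler`, which the hunk above forwards as the `diffusion_sampler` keyword argument of the inference call. Everything outside that mapping is illustrative.

```python
import argparse

# Minimal, self-contained sketch of the mapping shown in the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument("--diffusion-sampler", type=str, default="ddim")

args = parser.parse_args(["--diffusion-sampler", "p"])
assert args.diffusion_sampler == "p"  # "p" selects the original sampler, "ddim" the DDIM one
```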

View File

@ -114,6 +114,9 @@ class TTS():
beam_width=1,
#mirostat_tau=0,
#mirostat_eta=0.1,
diffusion_sampler="ddim",
out_path=None
):
lines = text.split("\n")
@ -222,9 +225,10 @@ class TTS():
precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False)
noise = torch.randn(output_shape, device=latents.device) * diffusion_temp
mel = diffuser.p_sample_loop(
mel = diffuser.sample_loop(
diffusion,
output_shape,
sampler=diffusion_sampler,
noise=noise,
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
progress=True
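A hedged usage sketch of the extended `inference()` signature: the import path and constructor are assumptions, and only the `diffusion_sampler` and `out_path` keyword arguments come from the hunk above.

```python
# Illustrative only: import path and constructor are assumptions, not the repo's exact API.
from tortoise_tts.api import TTS  # hypothetical import path

tts = TTS()  # hypothetical: load default config/weights
wav = tts.inference(
    text="Hello there.",       # split on newlines internally (see `lines = text.split("\n")`)
    diffusion_sampler="ddim",  # "ddim" (the new default) or "p" for the original p_sample_loop
    out_path="./output.wav",
)
```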

View File

@ -42,6 +42,16 @@ def normalization(channels):
return GroupNorm32(groups, channels)
AVAILABLE_ATTENTIONS = ["mem_efficient", "math", "sdpa"]
try:
from xformers.ops import LowerTriangularMask
from xformers.ops.fmha import memory_efficient_attention
AVAILABLE_ATTENTIONS.append("xformers")
except Exception as e:
print("Error while importing `xformers`", e)
class QKVAttentionLegacy(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
@ -51,13 +61,14 @@ class QKVAttentionLegacy(nn.Module):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv, mask=None, rel_pos=None):
def forward(self, qkv, mask=None, rel_pos=None, mode="xformers"):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
@ -73,11 +84,11 @@ class QKVAttentionLegacy(nn.Module):
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
weight = weight * mask
a = torch.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
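A small sketch of how the availability list can drive the new `mode` argument on `QKVAttentionLegacy.forward`; the helper below is illustrative, and only `AVAILABLE_ATTENTIONS` and the `mode` kwarg come from the diff.

```python
# Illustrative helper (not part of the repo): fall back gracefully when xformers
# failed to import and "xformers" was never appended to AVAILABLE_ATTENTIONS.
def pick_attention_mode(preferred="xformers", available=None):
    available = AVAILABLE_ATTENTIONS if available is None else available
    if preferred in available:
        return preferred
    return "sdpa" if "sdpa" in available else "math"

# attn = QKVAttentionLegacy(n_heads)
# out = attn(qkv, mask=mask, mode=pick_attention_mode())
```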

View File

@ -9,6 +9,8 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm.auto import tqdm
from torch import autocast
from .arch_utils import normalization, AttentionBlock
@ -493,6 +495,16 @@ class GaussianDiffusion:
)
return out
def sample_loop(self, *args, **kwargs):
# YUCK: string-based dispatch so existing call sites only need to pass a `sampler` kwarg
sampler = kwargs.pop("sampler", "ddim").lower()
if sampler == 'p':
return self.p_sample_loop(*args, **kwargs)
if sampler == 'ddim':
return self.ddim_sample_loop(*args, **kwargs)
raise RuntimeError(f"Sampler not implemented: {sampler}")
def p_sample(
self,
model,
@ -780,9 +792,6 @@ class GaussianDiffusion:
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices, disable=not progress)
for i in indices:
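Put differently, the new wrapper keeps the old call sites working: the two calls below are equivalent (variable names are placeholders taken from the inference hunk earlier).

```python
# Equivalent calls through the new dispatcher (placeholder variable names):
mel = diffuser.sample_loop(diffusion, output_shape, sampler="p", noise=noise,
                           model_kwargs=model_kwargs, progress=True)
mel = diffuser.p_sample_loop(diffusion, output_shape, noise=noise,
                             model_kwargs=model_kwargs, progress=True)
# Omitting `sampler` defaults to "ddim"; values other than "p"/"ddim" raise RuntimeError.
```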

View File

@ -11,6 +11,7 @@ from transformers.utils.model_parallel_utils import get_device_map, assert_devic
from .arch_utils import AttentionBlock
from transformers import LogitsWarper
from transformers import GPT2Config, GPT2Model
AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
@ -81,15 +82,16 @@ class ResBlock(nn.Module):
def forward(self, x):
return F.relu(self.net(x) + x)
class GPT2InferenceModel(GPT2PreTrainedModel):
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=True):
super().__init__(config)
self.transformer = gpt
self.text_pos_embedding = text_pos_emb
self.embeddings = embeddings
self.lm_head = nn.Sequential(norm, linear)
self.kv_cache = kv_cache
# Model parallel
self.model_parallel = False
self.device_map = None
@ -123,8 +125,11 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
self.cached_mel_emb = mel_emb
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
if not self.kv_cache:
past = None
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
@ -278,38 +283,6 @@ class LearnedPositionEmbeddings(nn.Module):
def get_fixed_embedding(self, ind, dev):
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, attention_implementation="eager"):
"""
GPT-2 implemented by the HuggingFace library.
"""
from transformers import GPT2Config, GPT2Model
gpt_config = GPT2Config(
vocab_size=256, # Unused.
n_positions=max_mel_seq_len+max_text_seq_len,
n_ctx=max_mel_seq_len+max_text_seq_len,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
use_cache=not checkpointing,
attention_implementation=attention_implementation
)
gpt = GPT2Model(gpt_config)
if checkpointing:
gpt.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
# Override the built in positional embeddings
del gpt.wpe
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
# Built-in token embeddings are unused.
del gpt.wte
return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
None, None
class MelEncoder(nn.Module):
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
super().__init__()
@ -341,6 +314,7 @@ class UnifiedVoice(nn.Module):
model_dim=1024, # 512
heads=16, # 8
max_text_tokens=402, # 120
max_prompt_tokens=2, # XTTS2 uses 70
max_mel_tokens=604, # 250
max_conditioning_inputs=2, # 1
mel_length_compression=1024,
@ -392,17 +366,48 @@ class UnifiedVoice(nn.Module):
self.heads = heads
self.max_mel_tokens = max_mel_tokens
self.max_text_tokens = max_text_tokens
self.max_prompt_tokens = max_prompt_tokens
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
if use_mel_codes_as_input:
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
else:
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing, attention_implementation)
max_mel_seq_len = self.max_mel_tokens+2+self.max_conditioning_inputs
max_text_seq_len = self.max_text_tokens+2
gpt_config = GPT2Config(
vocab_size=256, # Unused.
n_positions=max_mel_seq_len+max_text_seq_len,
n_ctx=max_mel_seq_len+max_text_seq_len,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
use_cache=not checkpointing,
attention_implementation=attention_implementation
)
self.gpt = GPT2Model(gpt_config)
if checkpointing:
self.gpt.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
del self.gpt.wpe
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) # Override the built in positional embeddings
del self.gpt.wte
self.gpt.wte = None # Built-in token embeddings are unused.
self.mel_pos_embedding = LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
self.text_pos_embedding = LearnedPositionEmbeddings(max_text_seq_len, model_dim)
self.mel_layer_pos_embedding = None
self.text_layer_pos_embedding = None
if train_solo_embeddings:
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
@ -421,6 +426,42 @@ class UnifiedVoice(nn.Module):
for module in embeddings:
module.weight.data.normal_(mean=0.0, std=.02)
def post_init_gpt2_config(self, kv_cache=True, use_deepspeed=False):
seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens
self.inference_model = GPT2InferenceModel(
GPT2Config(
vocab_size=self.max_mel_tokens,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=self.model_dim,
n_layer=self.layers,
n_head=self.heads,
gradient_checkpointing=False,
use_cache=True,
attn_implementation=self.attention_implementation,
),
self.gpt,
self.mel_pos_embedding,
self.mel_embedding,
self.final_norm,
self.mel_head,
kv_cache=kv_cache
)
# technically should already be done on the framework side, but my old fork had this here anyways
if use_deepspeed:
import deepspeed
self.ds_engine = deepspeed.init_inference(
model=self.inference_model,
mp_size=1,
replace_with_kernel_inject=True,
# dtype=torch.float32
)
self.inference_model = self.ds_engine.module
self.inference_model.eval()
self.gpt.wte = self.mel_embedding
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1,0), value=start_token)
tar = F.pad(input, (0,1), value=stop_token)
@ -547,23 +588,11 @@ class UnifiedVoice(nn.Module):
return loss_text.mean(), loss_mel.mean(), mel_logits
def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
max_generate_length=None, typical_sampling=False, typical_mass=.9, kv_cache=True, **hf_generate_kwargs):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
if not hasattr(self, 'inference_model'):
# TODO: Decouple gpt_config from this inference model.
gpt_config = GPT2Config(
vocab_size=self.max_mel_tokens,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=self.model_dim,
n_layer=self.layers,
n_head=self.heads,
gradient_checkpointing=False,
use_cache=True,
attn_implementation=self.attention_implementation,
)
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
self.gpt.wte = self.mel_embedding
self.post_init_gpt2_config(kv_cache=kv_cache)
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
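A hedged sketch of how the KV-cache switch now flows through inference; the constructor arguments and input tensors below are placeholders, and only `post_init_gpt2_config`, `inference_speech`, and the `kv_cache` flag come from the diff above.

```python
# Placeholders throughout; layer/dim/head counts are assumptions, not repo values.
model = UnifiedVoice(layers=30, model_dim=1024, heads=16)

# Build the HF inference wrapper up front (also wraps it in DeepSpeed if requested)...
model.post_init_gpt2_config(kv_cache=True, use_deepspeed=False)

# ...or let inference_speech() build it lazily on the first call:
# codes = model.inference_speech(conditioning_latent, text_tokens, kv_cache=True)
# With kv_cache=False, prepare_inputs_for_generation() discards `past` each step,
# so the full sequence is recomputed per token instead of reusing cached key/values.
```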

View File

@ -96,6 +96,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
parser.add_argument("--diffusion-sampler", type=str, default=kwargs["diffusion-sampler"])
"""
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
@ -125,6 +126,8 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
#repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,
beam_width=args.beam_width,
diffusion_sampler=args.diffusion_sampler,
)
wav = wav.squeeze(0).cpu().numpy()