backported old fork features (kv_cache (which looking back seems like a spook), ddim sampling, etc)

This commit is contained in:
parent 268ba17485
commit 99be487482
@@ -37,10 +37,15 @@ For training a LoRA, uncomment the `loras` block in your training YAML.
- [X] Reimplement original inferencing through TorToiSe (as done with `api.py`)
- [ ] Reimplement candidate selection with the CLVP
- [ ] Reimplement redaction with the Wav2Vec2
- [X] Implement training support (without DLAS)
- [X] Feature parity with the VALL-E training setup with preparing a dataset ahead of time
- [ ] Automagic offloading to CPU for unused models (for training and inferencing)
- [X] Automagic handling of the original weights into compatible weights
- [ ] Reimplement added features from my original fork:
  - [ ] "Better" conditioning latents calculating
  - [x] Use of KV-cache for the AR
  - [x] Re-enable DDIM sampler
- [ ] Extend the original inference routine with additional features:
  - [ ] non-float32 / mixed precision for the entire stack
  - [x] BitsAndBytes support
@@ -48,10 +53,13 @@ For training a LoRA, uncomment the `loras` block in your training YAML.
- [x] LoRAs
- [x] Web UI
- [ ] Feature parity with [ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning)
  - Although I feel a lot of its features are the wrong way to go about it.
- [ ] Additional samplers for the autoregressive model
- [ ] Additional samplers for the diffusion model
- [ ] BigVGAN in place of the original vocoder
- [ ] XFormers / flash_attention_2 for the autoregressive model
  - Beyond HF's internal implementation of handling alternative attention
  - Both the AR and diffusion models also do their own attention...
- [ ] Some vector embedding store to find the "best" utterance to pick
- [ ] Documentation
@@ -22,6 +22,8 @@ def main():
	parser.add_argument("--length-penalty", type=float, default=0.0)
	parser.add_argument("--beam-width", type=int, default=0)

	parser.add_argument("--diffusion-sampler", type=str, default="ddim")

	parser.add_argument("--yaml", type=Path, default=None)
	parser.add_argument("--device", type=str, default=None)
	parser.add_argument("--amp", action="store_true")
@@ -56,6 +58,8 @@ def main():
		#repetition_penalty_decay=args.repetition_penalty_decay,
		length_penalty=args.length_penalty,
		beam_width=args.beam_width,

		diffusion_sampler=args.diffusion_sampler
	)
	"""
		language=args.language,
@@ -114,6 +114,9 @@ class TTS():
		beam_width=1,
		#mirostat_tau=0,
		#mirostat_eta=0.1,

		diffusion_sampler="ddim",

		out_path=None
	):
		lines = text.split("\n")
@@ -222,9 +225,10 @@ class TTS():
		precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False)

		noise = torch.randn(output_shape, device=latents.device) * diffusion_temp
		mel = diffuser.p_sample_loop(
		mel = diffuser.sample_loop(
			diffusion,
			output_shape,
			sampler=diffusion_sampler,
			noise=noise,
			model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
			progress=True
@@ -42,6 +42,16 @@ def normalization(channels):
	return GroupNorm32(groups, channels)


AVAILABLE_ATTENTIONS = ["mem_efficient", "math", "sdpa"]

try:
	from xformers.ops import LowerTriangularMask
	from xformers.ops.fmha import memory_efficient_attention

	AVAILABLE_ATTENTIONS.append("xformers")
except Exception as e:
	print("Error while importing `xformers`", e)

class QKVAttentionLegacy(nn.Module):
	"""
	A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
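(Not part of the diff above — a standalone sketch of the optional-backend probing pattern behind `AVAILABLE_ATTENTIONS`: try to import an accelerated attention implementation and fall back quietly if it is missing. The variable names below are illustrative, not the repo's:)

available_attentions = ["mem_efficient", "math", "sdpa"]

try:
	import xformers.ops  # noqa: F401 -- optional dependency, may not be installed
	available_attentions.append("xformers")
except ImportError as e:
	print("xformers unavailable, falling back:", e)

backend = "xformers" if "xformers" in available_attentions else available_attentions[-1]
print("attention backend:", backend)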
@@ -51,13 +61,14 @@ class QKVAttentionLegacy(nn.Module):
		super().__init__()
		self.n_heads = n_heads

	def forward(self, qkv, mask=None, rel_pos=None):
	def forward(self, qkv, mask=None, rel_pos=None, mode="xformers"):
		"""
		Apply QKV attention.

		:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
		:return: an [N x (H * C) x T] tensor after attention.
		"""

		bs, width, length = qkv.shape
		assert width % (3 * self.n_heads) == 0
		ch = width // (3 * self.n_heads)
@@ -73,11 +84,11 @@ class QKVAttentionLegacy(nn.Module):
			# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
			mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
			weight = weight * mask

		a = torch.einsum("bts,bcs->bct", weight, v)

		return a.reshape(bs, -1, length)


class AttentionBlock(nn.Module):
	"""
	An attention block that allows spatial positions to attend to each other.
@@ -9,6 +9,8 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from tqdm.auto import tqdm

from torch import autocast

from .arch_utils import normalization, AttentionBlock
@@ -493,6 +495,16 @@ class GaussianDiffusion:
		)
		return out

	def sample_loop(self, *args, **kwargs):
		# YUCK
		sampler = kwargs.pop("sampler").lower() if "sampler" in kwargs else "ddim"
		if sampler == 'p':
			return self.p_sample_loop(*args, **kwargs)
		if sampler == 'ddim':
			return self.ddim_sample_loop(*args, **kwargs)

		raise RuntimeError(f"Sampler not implemented: {sampler}")

	def p_sample(
		self,
		model,
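(Not part of the diff above — a minimal standalone sketch of the sampler-dispatch pattern that `sample_loop` introduces, using toy class and method names rather than the repo's actual ones:)

class TinySampler:
	def p_sample_loop(self, shape, **kwargs):
		return f"p-sampled {shape}"

	def ddim_sample_loop(self, shape, **kwargs):
		return f"ddim-sampled {shape}"

	def sample_loop(self, *args, sampler="ddim", **kwargs):
		# route to whichever loop the caller asked for, defaulting to DDIM
		sampler = sampler.lower()
		if sampler == "p":
			return self.p_sample_loop(*args, **kwargs)
		if sampler == "ddim":
			return self.ddim_sample_loop(*args, **kwargs)
		raise RuntimeError(f"Sampler not implemented: {sampler}")

print(TinySampler().sample_loop((1, 100, 80), sampler="ddim"))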
@@ -780,9 +792,6 @@ class GaussianDiffusion:
		indices = list(range(self.num_timesteps))[::-1]

		if progress:
			# Lazy import so that we don't depend on tqdm.
			from tqdm.auto import tqdm

		indices = tqdm(indices, disable=not progress)

		for i in indices:
@@ -11,6 +11,7 @@ from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
from .arch_utils import AttentionBlock

from transformers import LogitsWarper
from transformers import GPT2Config, GPT2Model

AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
@@ -81,15 +82,16 @@ class ResBlock(nn.Module):
	def forward(self, x):
		return F.relu(self.net(x) + x)


class GPT2InferenceModel(GPT2PreTrainedModel):
	def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
	def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=True):
		super().__init__(config)
		self.transformer = gpt
		self.text_pos_embedding = text_pos_emb
		self.embeddings = embeddings
		self.lm_head = nn.Sequential(norm, linear)

		self.kv_cache = kv_cache

		# Model parallel
		self.model_parallel = False
		self.device_map = None
@@ -123,8 +125,11 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
		self.cached_mel_emb = mel_emb

	def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):

		token_type_ids = kwargs.get("token_type_ids", None)

		if not self.kv_cache:
			past = None

		# only last token for inputs_ids if past is defined in kwargs
		if past:
			input_ids = input_ids[:, -1].unsqueeze(-1)
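(Not part of the diff above — an illustrative, self-contained sketch of what the `kv_cache` switch controls during generation. With a KV cache, each step only feeds the newest token and reuses cached keys/values; disabling it drops the cache so the full prefix is recomputed every step. The names below are toy stand-ins, not the repo's API:)

import torch

def prepare_inputs(input_ids, past=None, kv_cache=True):
	if not kv_cache:
		past = None                      # discarding the cache forces full recomputation
	if past is not None:
		input_ids = input_ids[:, -1:]    # with a cache, only the newest token is needed
	return {"input_ids": input_ids, "past_key_values": past}

tokens = torch.tensor([[5, 6, 7, 8]])
print(prepare_inputs(tokens, past="cached-kv", kv_cache=True)["input_ids"].shape)   # torch.Size([1, 1])
print(prepare_inputs(tokens, past="cached-kv", kv_cache=False)["input_ids"].shape)  # torch.Size([1, 4])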
@@ -278,38 +283,6 @@ class LearnedPositionEmbeddings(nn.Module):
	def get_fixed_embedding(self, ind, dev):
		return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)


def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, attention_implementation="eager"):
	"""
	GPT-2 implemented by the HuggingFace library.
	"""
	from transformers import GPT2Config, GPT2Model
	gpt_config = GPT2Config(
		vocab_size=256, # Unused.
		n_positions=max_mel_seq_len+max_text_seq_len,
		n_ctx=max_mel_seq_len+max_text_seq_len,
		n_embd=model_dim,
		n_layer=layers,
		n_head=heads,
		use_cache=not checkpointing,
		attention_implementation=attention_implementation
	)
	gpt = GPT2Model(gpt_config)

	if checkpointing:
		gpt.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
			use_reentrant=False
		))

	# Override the built in positional embeddings
	del gpt.wpe
	gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
	# Built-in token embeddings are unused.
	del gpt.wte
	return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
		None, None


class MelEncoder(nn.Module):
	def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
		super().__init__()
@@ -341,6 +314,7 @@ class UnifiedVoice(nn.Module):
		model_dim=1024, # 512
		heads=16, # 8
		max_text_tokens=402, # 120
		max_prompt_tokens=2, # XTTS2 uses 70
		max_mel_tokens=604, # 250
		max_conditioning_inputs=2, # 1
		mel_length_compression=1024,
@@ -392,17 +366,48 @@ class UnifiedVoice(nn.Module):
		self.heads = heads
		self.max_mel_tokens = max_mel_tokens
		self.max_text_tokens = max_text_tokens
		self.max_prompt_tokens = max_prompt_tokens
		self.model_dim = model_dim
		self.max_conditioning_inputs = max_conditioning_inputs
		self.mel_length_compression = mel_length_compression
		self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
		self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)

		if use_mel_codes_as_input:
			self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
		else:
			self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
		self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
			build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing, attention_implementation)

		max_mel_seq_len = self.max_mel_tokens+2+self.max_conditioning_inputs
		max_text_seq_len = self.max_text_tokens+2

		gpt_config = GPT2Config(
			vocab_size=256, # Unused.
			n_positions=max_mel_seq_len+max_text_seq_len,
			n_ctx=max_mel_seq_len+max_text_seq_len,
			n_embd=model_dim,
			n_layer=layers,
			n_head=heads,
			use_cache=not checkpointing,
			attention_implementation=attention_implementation
		)
		self.gpt = GPT2Model(gpt_config)

		if checkpointing:
			self.gpt.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
				use_reentrant=False
			))

		del self.gpt.wpe
		self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) # Override the built in positional embeddings
		del self.gpt.wte
		self.gpt.wte = None # Built-in token embeddings are unused.

		self.mel_pos_embedding = LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
		self.text_pos_embedding = LearnedPositionEmbeddings(max_text_seq_len, model_dim)
		self.mel_layer_pos_embedding = None
		self.text_layer_pos_embedding = None

		if train_solo_embeddings:
			self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
			self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
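(Not part of the diff above — a minimal runnable sketch, assuming `transformers` and `torch` are installed, of the pattern the inlined block uses: build a bare GPT-2 trunk, disable its built-in position/token embeddings, and drive it with externally computed embeddings. The tiny config values are illustrative only:)

import functools
import torch
from transformers import GPT2Config, GPT2Model

def null_position_embeddings(ids, dim):
	# stand-in matching the repo's helper of the same name: contributes no positional signal
	return torch.zeros(ids.shape[0], ids.shape[1], dim, device=ids.device)

config = GPT2Config(vocab_size=256, n_positions=64, n_ctx=64, n_embd=32, n_layer=2, n_head=2)
gpt = GPT2Model(config)

del gpt.wpe
gpt.wpe = functools.partial(null_position_embeddings, dim=32)  # positions are supplied externally
del gpt.wte
gpt.wte = None                                                 # token embeddings are supplied externally

out = gpt(inputs_embeds=torch.randn(1, 8, 32))                 # feed precomputed embeddings directly
print(out.last_hidden_state.shape)                             # torch.Size([1, 8, 32])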
@@ -421,6 +426,42 @@ class UnifiedVoice(nn.Module):
		for module in embeddings:
			module.weight.data.normal_(mean=0.0, std=.02)

	def post_init_gpt2_config(self, kv_cache = True, use_deepspeed = False):
		seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens
		self.inference_model = GPT2InferenceModel(
			GPT2Config(
				vocab_size=self.max_mel_tokens,
				n_positions=seq_length,
				n_ctx=seq_length,
				n_embd=self.model_dim,
				n_layer=self.layers,
				n_head=self.heads,
				gradient_checkpointing=False,
				use_cache=True,
				attn_implementation=self.attention_implementation,
			),
			self.gpt,
			self.mel_pos_embedding,
			self.mel_embedding,
			self.final_norm,
			self.mel_head,
			kv_cache=kv_cache
		)

		# technically should already be done on the framework side, but my old fork had this here anyways
		if use_deepspeed:
			import deepspeed
			self.ds_engine = deepspeed.init_inference(
				model=self.inference_model,
				mp_size=1,
				replace_with_kernel_inject=True,
				# dtype=torch.float32
			)
			self.inference_model = self.ds_engine.module

		self.inference_model.eval()
		self.gpt.wte = self.mel_embedding

	def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
		inp = F.pad(input, (1,0), value=start_token)
		tar = F.pad(input, (0,1), value=stop_token)
@@ -547,23 +588,11 @@ class UnifiedVoice(nn.Module):
		return loss_text.mean(), loss_mel.mean(), mel_logits

	def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
						 max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
						 max_generate_length=None, typical_sampling=False, typical_mass=.9, kv_cache=True, **hf_generate_kwargs):
		seq_length = self.max_mel_tokens + self.max_text_tokens + 2
		if not hasattr(self, 'inference_model'):
			# TODO: Decouple gpt_config from this inference model.
			gpt_config = GPT2Config(
				vocab_size=self.max_mel_tokens,
				n_positions=seq_length,
				n_ctx=seq_length,
				n_embd=self.model_dim,
				n_layer=self.layers,
				n_head=self.heads,
				gradient_checkpointing=False,
				use_cache=True,
				attn_implementation=self.attention_implementation,
			)
			self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
			self.gpt.wte = self.mel_embedding
			self.post_init_gpt2_config(kv_cache = kv_cache)

		text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
		text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
@@ -96,6 +96,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
	parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
	parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
	parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
	parser.add_argument("--diffusion-sampler", type=str, default=kwargs["diffusion-sampler"])
	"""
	parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
	parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
@@ -125,6 +126,8 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
		#repetition_penalty_decay=args.repetition_penalty_decay,
		length_penalty=args.length_penalty,
		beam_width=args.beam_width,

		diffusion_sampler=args.diffusion_sampler,
	)

	wav = wav.squeeze(0).cpu().numpy()