backported old fork features (kv_cache (which, looking back, seems like a spook), DDIM sampling, etc.)
commit 99be487482 (parent 268ba17485)
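This commit threads two of the backported options through the stack: a `diffusion_sampler` argument (exposed on the CLI as `--diffusion-sampler`, default `ddim`) that picks between the original DDPM loop and DDIM sampling, and a `kv_cache` toggle on the GPT-2 inference wrapper used by the AR model. A rough usage sketch follows; the import path, constructor, and `references=` handling are assumptions and not part of this diff, while `text=`, `beam_width=1`, and `diffusion_sampler="ddim"` mirror the parameters visible in the hunks below.

```python
# Illustrative sketch only. The module path and TTS() setup are assumptions;
# diffusion_sampler ("ddim" or "p") is the knob introduced by this commit.
from tortoise_tts import TTS  # hypothetical import path

tts = TTS(config="./config.yaml")          # hypothetical setup
wav = tts.inference(
    text="Hello world.",
    references=["./voices/sample.wav"],    # hypothetical reference handling
    diffusion_sampler="ddim",              # new: "ddim" (default) or "p" for the plain DDPM loop
    beam_width=1,
)
```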
@@ -37,10 +37,15 @@ For training a LoRA, uncomment the `loras` block in your training YAML.

 - [X] Reimplement original inferencing through TorToiSe (as done with `api.py`)
 - [ ] Reimplement candidate selection with the CLVP
+- [ ] Reimplement redaction with the Wav2Vec2
 - [X] Implement training support (without DLAS)
 - [X] Feature parity with the VALL-E training setup with preparing a dataset ahead of time
 - [ ] Automagic offloading to CPU for unused models (for training and inferencing)
 - [X] Automagic handling of the original weights into compatible weights
+- [ ] Reimplement added features from my original fork:
+  - [ ] "Better" conditioning latents calculating
+  - [x] Use of KV-cache for the AR
+  - [x] Re-enable DDIM sampler
 - [ ] Extend the original inference routine with additional features:
   - [ ] non-float32 / mixed precision for the entire stack
   - [x] BitsAndBytes support
@@ -48,10 +53,13 @@ For training a LoRA, uncomment the `loras` block in your training YAML.
   - [x] LoRAs
   - [x] Web UI
   - [ ] Feature parity with [ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning)
+    - Although I feel a lot of its features are the wrong way to go about it.
 - [ ] Additional samplers for the autoregressive model
 - [ ] Additional samplers for the diffusion model
 - [ ] BigVGAN in place of the original vocoder
 - [ ] XFormers / flash_attention_2 for the autoregressive model
+  - Beyond HF's internal implementation of handling alternative attention
+  - Both the AR and diffusion models also do their own attention...
 - [ ] Some vector embedding store to find the "best" utterance to pick
 - [ ] Documentation
@@ -22,6 +22,8 @@ def main():
     parser.add_argument("--length-penalty", type=float, default=0.0)
     parser.add_argument("--beam-width", type=int, default=0)
+
+    parser.add_argument("--diffusion-sampler", type=str, default="ddim")

     parser.add_argument("--yaml", type=Path, default=None)
     parser.add_argument("--device", type=str, default=None)
     parser.add_argument("--amp", action="store_true")
@@ -56,6 +58,8 @@ def main():
         #repetition_penalty_decay=args.repetition_penalty_decay,
         length_penalty=args.length_penalty,
         beam_width=args.beam_width,
+
+        diffusion_sampler=args.diffusion_sampler
     )
     """
         language=args.language,
@@ -114,6 +114,9 @@ class TTS():
         beam_width=1,
         #mirostat_tau=0,
         #mirostat_eta=0.1,
+
+        diffusion_sampler="ddim",
+
         out_path=None
     ):
         lines = text.split("\n")
@@ -222,9 +225,10 @@ class TTS():
         precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False)

         noise = torch.randn(output_shape, device=latents.device) * diffusion_temp
-        mel = diffuser.p_sample_loop(
+        mel = diffuser.sample_loop(
             diffusion,
             output_shape,
+            sampler=diffusion_sampler,
             noise=noise,
             model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
             progress=True
@@ -42,6 +42,16 @@ def normalization(channels):
     return GroupNorm32(groups, channels)


+AVAILABLE_ATTENTIONS = ["mem_efficient", "math", "sdpa"]
+
+try:
+    from xformers.ops import LowerTriangularMask
+    from xformers.ops.fmha import memory_efficient_attention
+
+    AVAILABLE_ATTENTIONS.append("xformers")
+except Exception as e:
+    print("Error while importing `xformers`", e)
+
 class QKVAttentionLegacy(nn.Module):
     """
     A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
@@ -51,13 +61,14 @@ class QKVAttentionLegacy(nn.Module):
         super().__init__()
         self.n_heads = n_heads

-    def forward(self, qkv, mask=None, rel_pos=None):
+    def forward(self, qkv, mask=None, rel_pos=None, mode="xformers"):
         """
         Apply QKV attention.

         :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
         :return: an [N x (H * C) x T] tensor after attention.
         """

         bs, width, length = qkv.shape
         assert width % (3 * self.n_heads) == 0
         ch = width // (3 * self.n_heads)
@@ -73,11 +84,11 @@ class QKVAttentionLegacy(nn.Module):
             # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
             mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
             weight = weight * mask

         a = torch.einsum("bts,bcs->bct", weight, v)

         return a.reshape(bs, -1, length)


 class AttentionBlock(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other.
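`AVAILABLE_ATTENTIONS` is now populated by feature detection: the base list always holds `"mem_efficient"`, `"math"`, and `"sdpa"`, and `"xformers"` is appended only if the import succeeds, while `QKVAttentionLegacy.forward()` gains a `mode` argument. A minimal sketch of how a caller might pick a backend from that list; the helper and its preference order are assumptions, not part of this diff.

```python
# Sketch: fall back gracefully when xformers is not installed.
# AVAILABLE_ATTENTIONS and the mode= parameter come from the hunks above;
# this helper and its preference order are assumptions.
def pick_attention_mode(preferred: str = "xformers") -> str:
    if preferred in AVAILABLE_ATTENTIONS:
        return preferred
    return "sdpa"  # always present in the base list

# attn = QKVAttentionLegacy(n_heads)
# out = attn(qkv, mask=None, rel_pos=None, mode=pick_attention_mode())
```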
@@ -9,6 +9,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np

+from tqdm.auto import tqdm
+
 from torch import autocast

 from .arch_utils import normalization, AttentionBlock
@@ -493,6 +495,16 @@ class GaussianDiffusion:
         )
         return out

+    def sample_loop(self, *args, **kwargs):
+        # YUCK
+        sampler = kwargs.pop("sampler").lower() if "sampler" in kwargs else "ddim"
+        if sampler == 'p':
+            return self.p_sample_loop(*args, **kwargs)
+        if sampler == 'ddim':
+            return self.ddim_sample_loop(*args, **kwargs)
+
+        raise RuntimeError(f"Sampler not implemented: {sampler}")
+
     def p_sample(
         self,
         model,
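The new `sample_loop` wrapper is a thin dispatcher: it pops the `sampler` keyword (lower-cased), routes `"p"` to `p_sample_loop` and `"ddim"` to `ddim_sample_loop`, defaults to `"ddim"` when the keyword is absent, and raises for anything else. The call shapes below are placeholders; only the routing behaviour is taken from the hunk above.

```python
# `model`, `shape`, `noise` are placeholders; the routing is what the new method does.
mel = diffuser.sample_loop(model, shape, sampler="ddim", noise=noise)  # -> ddim_sample_loop(...)
mel = diffuser.sample_loop(model, shape, sampler="P", noise=noise)     # lower-cased -> p_sample_loop(...)
mel = diffuser.sample_loop(model, shape, noise=noise)                  # no kwarg -> defaults to "ddim"
# any other value raises RuntimeError(f"Sampler not implemented: {sampler}")
```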
@@ -780,9 +792,6 @@ class GaussianDiffusion:
         indices = list(range(self.num_timesteps))[::-1]

         if progress:
-            # Lazy import so that we don't depend on tqdm.
-            from tqdm.auto import tqdm
-
             indices = tqdm(indices, disable=not progress)

         for i in indices:
@@ -11,6 +11,7 @@ from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
 from .arch_utils import AttentionBlock

 from transformers import LogitsWarper
+from transformers import GPT2Config, GPT2Model

 AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
@@ -81,15 +82,16 @@ class ResBlock(nn.Module):
     def forward(self, x):
         return F.relu(self.net(x) + x)


 class GPT2InferenceModel(GPT2PreTrainedModel):
-    def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
+    def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=True):
         super().__init__(config)
         self.transformer = gpt
         self.text_pos_embedding = text_pos_emb
         self.embeddings = embeddings
         self.lm_head = nn.Sequential(norm, linear)
+        self.kv_cache = kv_cache

         # Model parallel
         self.model_parallel = False
         self.device_map = None
@@ -123,8 +125,11 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         self.cached_mel_emb = mel_emb

     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):

         token_type_ids = kwargs.get("token_type_ids", None)
+
+        if not self.kv_cache:
+            past = None

         # only last token for inputs_ids if past is defined in kwargs
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)
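The toggle is a one-liner in `prepare_inputs_for_generation`: when `kv_cache` is off, the cached `past` key/values are discarded, so `generate()` re-feeds and re-attends over the whole prefix every step instead of only the newest token. A hedged sketch of constructing the wrapper without the cache; in practice the arguments are wired up by `post_init_gpt2_config` further down.

```python
# Hypothetical construction for illustration; argument names follow the __init__ above.
inference_model = GPT2InferenceModel(
    config, gpt, text_pos_emb, embeddings, norm, linear,
    kv_cache=False,  # drop `past` each step, so attention is recomputed over the full prefix
)
```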
@@ -278,38 +283,6 @@ class LearnedPositionEmbeddings(nn.Module):
     def get_fixed_embedding(self, ind, dev):
         return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)


-def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, attention_implementation="eager"):
-    """
-    GPT-2 implemented by the HuggingFace library.
-    """
-    from transformers import GPT2Config, GPT2Model
-    gpt_config = GPT2Config(
-        vocab_size=256, # Unused.
-        n_positions=max_mel_seq_len+max_text_seq_len,
-        n_ctx=max_mel_seq_len+max_text_seq_len,
-        n_embd=model_dim,
-        n_layer=layers,
-        n_head=heads,
-        use_cache=not checkpointing,
-        attention_implementation=attention_implementation
-    )
-    gpt = GPT2Model(gpt_config)
-
-    if checkpointing:
-        gpt.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
-            use_reentrant=False
-        ))
-
-    # Override the built in positional embeddings
-    del gpt.wpe
-    gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
-    # Built-in token embeddings are unused.
-    del gpt.wte
-    return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
-        None, None
-
-
 class MelEncoder(nn.Module):
     def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
         super().__init__()
@@ -341,6 +314,7 @@ class UnifiedVoice(nn.Module):
                  model_dim=1024, # 512
                  heads=16, # 8
                  max_text_tokens=402, # 120
+                 max_prompt_tokens=2, # XTTS2 uses 70
                  max_mel_tokens=604, # 250
                  max_conditioning_inputs=2, # 1
                  mel_length_compression=1024,
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
self.max_mel_tokens = max_mel_tokens
|
self.max_mel_tokens = max_mel_tokens
|
||||||
self.max_text_tokens = max_text_tokens
|
self.max_text_tokens = max_text_tokens
|
||||||
|
self.max_prompt_tokens = max_prompt_tokens
|
||||||
self.model_dim = model_dim
|
self.model_dim = model_dim
|
||||||
self.max_conditioning_inputs = max_conditioning_inputs
|
self.max_conditioning_inputs = max_conditioning_inputs
|
||||||
self.mel_length_compression = mel_length_compression
|
self.mel_length_compression = mel_length_compression
|
||||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||||
self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
|
self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
|
||||||
|
|
||||||
if use_mel_codes_as_input:
|
if use_mel_codes_as_input:
|
||||||
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
|
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
|
||||||
else:
|
else:
|
||||||
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
|
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
|
||||||
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
|
|
||||||
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing, attention_implementation)
|
max_mel_seq_len = self.max_mel_tokens+2+self.max_conditioning_inputs
|
||||||
|
max_text_seq_len = self.max_text_tokens+2
|
||||||
|
|
||||||
|
gpt_config = GPT2Config(
|
||||||
|
vocab_size=256, # Unused.
|
||||||
|
n_positions=max_mel_seq_len+max_text_seq_len,
|
||||||
|
n_ctx=max_mel_seq_len+max_text_seq_len,
|
||||||
|
n_embd=model_dim,
|
||||||
|
n_layer=layers,
|
||||||
|
n_head=heads,
|
||||||
|
use_cache=not checkpointing,
|
||||||
|
attention_implementation=attention_implementation
|
||||||
|
)
|
||||||
|
self.gpt = GPT2Model(gpt_config)
|
||||||
|
|
||||||
|
if checkpointing:
|
||||||
|
self.gpt.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
|
||||||
|
use_reentrant=False
|
||||||
|
))
|
||||||
|
|
||||||
|
del self.gpt.wpe
|
||||||
|
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) # Override the built in positional embeddings
|
||||||
|
del self.gpt.wte
|
||||||
|
self.gpt.wte = None # Built-in token embeddings are unused.
|
||||||
|
|
||||||
|
self.mel_pos_embedding = LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
|
||||||
|
self.text_pos_embedding = LearnedPositionEmbeddings(max_text_seq_len, model_dim)
|
||||||
|
self.mel_layer_pos_embedding = None
|
||||||
|
self.text_layer_pos_embedding = None
|
||||||
|
|
||||||
if train_solo_embeddings:
|
if train_solo_embeddings:
|
||||||
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
||||||
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
||||||
|
@@ -421,6 +426,42 @@ class UnifiedVoice(nn.Module):
         for module in embeddings:
             module.weight.data.normal_(mean=0.0, std=.02)

+    def post_init_gpt2_config(self, kv_cache = True, use_deepspeed = False):
+        seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens
+        self.inference_model = GPT2InferenceModel(
+            GPT2Config(
+                vocab_size=self.max_mel_tokens,
+                n_positions=seq_length,
+                n_ctx=seq_length,
+                n_embd=self.model_dim,
+                n_layer=self.layers,
+                n_head=self.heads,
+                gradient_checkpointing=False,
+                use_cache=True,
+                attn_implementation=self.attention_implementation,
+            ),
+            self.gpt,
+            self.mel_pos_embedding,
+            self.mel_embedding,
+            self.final_norm,
+            self.mel_head,
+            kv_cache=True
+        )
+
+        # technically should already be done on the framework side, but my old fork had this here anyways
+        if use_deepspeed:
+            import deepspeed
+            self.ds_engine = deepspeed.init_inference(
+                model=self.inference_model,
+                mp_size=1,
+                replace_with_kernel_inject=True,
+                # dtype=torch.float32
+            )
+            self.inference_model = self.ds_engine.module
+
+        self.inference_model.eval()
+        self.gpt.wte = self.mel_embedding
+
     def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
         inp = F.pad(input, (1,0), value=start_token)
         tar = F.pad(input, (0,1), value=stop_token)
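`post_init_gpt2_config` now owns building the `GPT2InferenceModel` (plus the optional DeepSpeed wrapping) that `inference_speech` previously constructed inline; `inference_speech` calls it lazily on first use, but it can also be called up front. A sketch of explicit setup, assuming `model` is an already-constructed `UnifiedVoice` and the placeholder tensors are prepared elsewhere:

```python
# `conditioning_latent` and `text_tokens` are placeholders; the method and keyword
# names mirror the hunks above.
model.post_init_gpt2_config(kv_cache=True, use_deepspeed=False)
codes = model.inference_speech(
    conditioning_latent,
    text_tokens,
    kv_cache=True,            # only consulted on the first, lazy init
    max_generate_length=500,
)
```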
@@ -547,23 +588,11 @@ class UnifiedVoice(nn.Module):
         return loss_text.mean(), loss_mel.mean(), mel_logits

     def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
-                         max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
+                         max_generate_length=None, typical_sampling=False, typical_mass=.9, kv_cache=True, **hf_generate_kwargs):
         seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         if not hasattr(self, 'inference_model'):
             # TODO: Decouple gpt_config from this inference model.
-            gpt_config = GPT2Config(
-                vocab_size=self.max_mel_tokens,
-                n_positions=seq_length,
-                n_ctx=seq_length,
-                n_embd=self.model_dim,
-                n_layer=self.layers,
-                n_head=self.heads,
-                gradient_checkpointing=False,
-                use_cache=True,
-                attn_implementation=self.attention_implementation,
-            )
-            self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
-            self.gpt.wte = self.mel_embedding
+            self.post_init_gpt2_config(kv_cache = kv_cache)

         text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
@@ -96,6 +96,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
     parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
     parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
+    parser.add_argument("--diffusion-sampler", type=str, default=kwargs["diffusion-sampler"])
     """
     parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
     parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
@@ -125,6 +126,8 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         #repetition_penalty_decay=args.repetition_penalty_decay,
         length_penalty=args.length_penalty,
         beam_width=args.beam_width,
+
+        diffusion_sampler=args.diffusion_sampler,
     )

     wav = wav.squeeze(0).cpu().numpy()