vall-e/vall_e/models/base.py

942 lines
31 KiB
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
import math
import torch
import torch.nn.functional as F
import traceback
import numpy as np
import re
2023-08-02 21:53:35 +00:00
from typing import Literal, overload, Optional, Tuple
2023-08-02 21:53:35 +00:00
from functools import partial
from einops import rearrange
from torch import Tensor, einsum, nn
from torch.nn import Embedding
2023-08-02 21:53:35 +00:00
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
2024-06-06 01:30:43 +00:00
from .arch import *
from ..utils import wrapper as ml
from ..samplers import reptition_penalize, length_penalize, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
2023-08-02 21:53:35 +00:00
def _create_mask(l, device):
"""1 is valid region and 0 is invalid."""
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
return (seq < stop).float() # (b t)
def _join(x: tuple[Tensor], sep: Tensor):
"""
Args:
x: (k t d)
sep: (d)
"""
ret = x[0]
for i in range(1, len(x)):
ret = torch.cat((ret, sep[None], x[i]), dim=0)
return ret
def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
"""
Args:
x_list: [(t d)]
Returns:
x: (? ? ?)
m: (? ? ?), same as x
"""
l = list(map(len, x_list))
x = rearrange(pad_sequence(x_list), pattern)
m = _create_mask(l, x_list[0].device)
m = m.t().unsqueeze(-1) # (t b 1)
m = rearrange(m, pattern)
m = m.to(x)
return x, m
# automagically parses a batch-list and returns it as a list
"""
2023-08-02 21:53:35 +00:00
class Embedding(nn.Embedding):
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
if len(x_list) == 0:
return []
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
"""
2023-08-02 21:53:35 +00:00
# Deprecated implementation
class MultiEmbedding(nn.Module):
def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
2023-09-16 05:26:13 +00:00
super().__init__()
2023-09-08 20:36:26 +00:00
self.monolithic = monolithic
2023-08-02 21:53:35 +00:00
self.max_n_levels = max_n_levels
self.n_tokens = n_tokens
2023-09-08 20:36:26 +00:00
self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
2023-08-02 21:53:35 +00:00
2023-09-07 22:08:38 +00:00
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
# I imagine this is an oversight in the NAR.
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
2023-08-02 21:53:35 +00:00
if len(x_list) == 0:
return []
2023-09-08 20:36:26 +00:00
# this "strategy" will reserve the weight[0] for te AR and weight[1:] for the NAR
# the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
if self.monolithic:
2023-09-08 20:36:26 +00:00
w = self.weight[:1] if quant_levels is None else self.weight[1:]
else:
w = self.weight
2023-08-02 21:53:35 +00:00
padded_x_list = []
2023-08-02 21:53:35 +00:00
2023-09-08 20:36:26 +00:00
for i, xi in enumerate(x_list):
2023-08-02 21:53:35 +00:00
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
wi = w.shape[0] - xi.shape[1]
xi = F.pad(xi, (0, 0, 0, wi)) # t l k
2023-08-02 21:53:35 +00:00
padded_x_list.append(xi.to(w))
x = torch.cat(padded_x_list) # n l k
x = einsum("l k d, n l k -> n d", w, x)
2023-08-02 21:53:35 +00:00
x_list = x.split([*map(len, x_list)])
return x_list
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
class AudioEmbedding(nn.Module):
def __init__(
self,
l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
token_dim: int, # dimensionality of the embedding
levels: int | None = None, # number of RVQ-bins (I don't remember the specifics)
sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
):
2023-09-07 14:14:03 +00:00
super().__init__()
# array of embeddings
# proms are [0, prom_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
#
self.sums = sums
2023-09-07 14:14:03 +00:00
def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
# prom
if quant_levels is None and xi.shape[-1] > 1:
if self.sums:
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
else:
k = 0 # only use the most significant RVQ bin level for the input prom
x = self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1)
# AR resp
elif quant_levels is None or quant_levels == 0:
2024-05-31 01:50:45 +00:00
x = self.embeddings[0]( xi if len(xi.shape) == 1 else xi[:, 0] )
# NAR resp
else:
if self.sums:
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
else:
k = xi.shape[-1] - 1 # only use the previous RVQ bin level for the current resp embedding
x = self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1)
return x
2023-08-02 21:53:35 +00:00
class Base(nn.Module):
@property
def causal(self) -> bool:
raise NotImplementedError
@property
def arch_type(self) -> str:
raise NotImplementedError
@property
def norm_type(self):
raise NotImplementedError
@property
def n_prom_levels(self) -> int:
raise NotImplementedError
@property
def n_resp_levels(self) -> int:
raise NotImplementedError
@property
def n_max_levels(self) -> int:
raise NotImplementedError
2023-08-19 01:58:07 +00:00
@property
def n_langs(self) -> int:
raise NotImplementedError
2023-08-19 01:58:07 +00:00
@property
def n_tasks(self) -> int:
raise NotImplementedError
2023-08-02 21:53:35 +00:00
@property
def n_tones(self) -> int:
raise NotImplementedError
@property
def recurrent_chunk_size(self) -> int:
raise NotImplementedError
@property
def rotary_embedding_base(self) -> float:
return 10000
@property
def interleave(self) -> bool:
return False
@property
def monolithic(self) -> bool:
return False
2023-09-07 14:14:03 +00:00
@property
def version(self) -> int:
return 1
2023-09-07 14:14:03 +00:00
@property
def stop_token(self):
if not self.causal:
raise ValueError("Not using stop token!")
return self.n_audio_tokens
@property
def ignore_index(self):
return -100
def loss_factor(self, k):
2024-06-04 02:28:49 +00:00
if self.hyper_config is None:
return 1.0
2024-06-04 02:28:49 +00:00
return self.hyper_config.loss_factors[k] if k in self.hyper_config.loss_factors else 1.0
2023-08-02 21:53:35 +00:00
def __init__(
self,
n_text_tokens: int = 256,
n_audio_tokens: int = 1024,
2023-08-02 21:53:35 +00:00
d_model: int = 512,
n_heads: int = 8,
n_layers: int = 12,
p_dropout: float = 0.1,
n_experts: int = 1,
l_padding: int = 0,
training = True,
config = None,
2023-08-02 21:53:35 +00:00
):
super().__init__()
self.training = training
2024-06-04 02:28:49 +00:00
self.hyper_config = config
self.gradient_checkpointing = self.hyper_config.gradient_checkpointing if self.hyper_config is not None else True
self.n_text_tokens = n_text_tokens
self.n_audio_tokens = n_audio_tokens
2023-08-02 21:53:35 +00:00
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.n_experts = n_experts
self.l_padding = l_padding
2023-08-02 21:53:35 +00:00
# +1 to include the stop token
n_prom_tokens = n_audio_tokens
n_resp_tokens = n_audio_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
2023-08-02 21:53:35 +00:00
self.text_emb = Embedding(n_text_tokens, d_model)
self.langs_emb = None
self.tones_emb = None
self.tasks_emb = None
self.rvq_level_emb = None
if self.version == 1: # legacy
n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
else:
# [1024] * 8
self.proms_emb = AudioEmbedding(
[n_prom_tokens] * self.n_prom_levels, d_model,
levels=self.n_prom_levels if self.version > 3 else None,
2024-06-04 02:28:49 +00:00
sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True,
)
# [1024 + STOP] + [1024] * 8
self.resps_emb = AudioEmbedding(
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
levels=self.n_resp_levels if self.version > 3 else None,
2024-06-04 02:28:49 +00:00
sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True
)
# useless since I actually removed using these with the input processing overhaul...
if self.version >= 3:
self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None
self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None
# never actually got added... I kept forgetting to classify all my audio for speaker's tone
if self.version >= 4:
self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None
2023-08-02 21:53:35 +00:00
# mamba requires this if a model does both AR and NAR tasks
# this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
# this ***might*** let me also unify the proms_emb and resps_embedding
if self.version >= 5:
self.rvq_level_emb = Embedding(self.n_resp_levels, d_model)
# this would be nicer to be a stop token or live inside an embedding
2023-08-02 21:53:35 +00:00
self.sep = nn.Parameter(torch.randn(d_model))
# ick, there has to be a better way
2024-06-04 02:28:49 +00:00
hf_attention = self.hyper_config.attention if self.hyper_config is not None else None
2024-06-04 02:28:49 +00:00
if self.hyper_config.attention == "auto":
if "flash" in AVAILABLE_ATTENTIONS:
2024-06-04 02:28:49 +00:00
self.hyper_config.attention = "flash"
elif "xformers" in AVAILABLE_ATTENTIONS:
2024-06-04 02:28:49 +00:00
self.hyper_config.attention = "xformers"
else:
2024-06-04 02:28:49 +00:00
self.hyper_config.attention = "mem_efficient"
2024-06-04 02:28:49 +00:00
if self.hyper_config.attention in ["xformers", "mem_efficient", "math", "flash"]:
hf_attention = None
2024-06-04 02:28:49 +00:00
if self.hyper_config.attention not in AVAILABLE_ATTENTIONS:
raise ValueError(f"Requesting attention `{self.hyper_config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
2023-08-02 21:53:35 +00:00
if self.arch_type == "transformer":
self.sin_emb = SinusoidalEmbedding(d_model)
self.blocks = nn.ModuleList([TransformerBlock(
d_model=d_model,
n_heads=n_heads,
p_dropout=p_dropout if training else 0.0,
2023-08-19 01:58:07 +00:00
causal=self.causal,
2023-08-02 21:53:35 +00:00
norm_type=self.norm_type,
n_levels=self.n_resp_levels,
) for _ in range(n_layers) ])
elif self.arch_type in ["mistral", "mixtral"]:
if n_experts <= 1:
self.model = MistralModel(MistralConfig(
vocab_size=n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
2024-06-04 02:28:49 +00:00
num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=hf_attention,
2024-06-04 02:28:49 +00:00
#gradient_checkpointing=self.gradient_checkpointing,
))
else:
self.model = MixtralModel(MixtralConfig(
vocab_size =n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
2024-06-04 02:28:49 +00:00
num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads,
sliding_window=75 * 12, # 12 second context window
output_router_logits=training,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
attn_implementation=hf_attention,
2024-06-04 02:28:49 +00:00
#gradient_checkpointing=self.gradient_checkpointing,
))
2024-06-04 02:28:49 +00:00
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
2024-05-11 21:47:19 +00:00
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
2024-05-12 13:22:39 +00:00
#if training:
# self.model.training = True
elif self.arch_type == "llama":
if n_experts <= 1:
self.model = LlamaModel(LlamaConfig(
vocab_size=n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=n_heads,
sliding_window=75 * 12, # 12 second context window
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=hf_attention,
2024-06-04 02:28:49 +00:00
#gradient_checkpointing=self.gradient_checkpointing,
))
else:
self.model = MixtralModel(MixtralConfig(
vocab_size =n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=n_heads,
sliding_window=75 * 12, # 12 second context window
output_router_logits=training,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
attn_implementation=hf_attention,
2024-06-04 02:28:49 +00:00
#gradient_checkpointing=self.gradient_checkpointing,
))
2024-05-10 04:15:52 +00:00
2024-06-04 02:28:49 +00:00
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
2024-05-11 21:47:19 +00:00
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
2024-05-10 04:15:52 +00:00
2024-05-12 13:22:39 +00:00
#if training:
# self.model.training = True
2023-08-02 21:53:35 +00:00
elif self.arch_type == "retnet":
kwargs = dict(
vocab_size=n_resp_tokens,
2023-08-02 21:53:35 +00:00
decoder_embed_dim=d_model,
decoder_value_embed_dim =d_model * 2,
2023-08-02 21:53:35 +00:00
decoder_retention_heads=n_heads,
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout if training else 0.0,
2024-06-04 02:28:49 +00:00
checkpoint_activations=self.gradient_checkpointing,
activation_fn="gelu",
use_layernorm=self.version < 3,
use_biases=self.version < 3,
use_glu=self.version >= 3,
2023-08-02 21:53:35 +00:00
chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
2023-08-02 21:53:35 +00:00
no_output_layer=True,
decoder_normalize_before=True,
rotary_embedding_base=self.rotary_embedding_base, # 10000
)
if n_experts > 1:
kwargs.update(dict(
use_xmoe=True,
moe_freq=1,
moe_expert_count=n_experts,
moe_gating_use_fp32=False,
))
self.model = RetNetDecoder(RetNetConfig(**kwargs))
elif self.arch_type == "retnet-hf":
kwargs = dict(
vocab_size=n_resp_tokens,
decoder_embed_dim=d_model,
decoder_value_embed_dim =d_model * 2,
decoder_retention_heads=n_heads,
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout if training else 0.0,
2024-06-04 02:28:49 +00:00
checkpoint_activations=self.gradient_checkpointing,
activation_fn="gelu",
use_glu=False, # self.version >= 3,
recurrent_chunk_size=self.recurrent_chunk_size if self.causal else 0,
decoder_normalize_before=True,
deepnorm=False,
subln=True,
)
self.model = RetNetDecoder_HF(RetNetConfig_HF(**kwargs))
2024-05-12 13:22:39 +00:00
2024-06-04 02:28:49 +00:00
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
2024-05-12 13:22:39 +00:00
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
elif self.arch_type == "bitnet":
self.model = BitNetTransformer(
num_tokens=n_resp_tokens,
dim=d_model,
depth=n_layers,
heads=n_heads,
ff_mult=4,
2024-06-04 02:28:49 +00:00
gradient_checkpointing=self.gradient_checkpointing,
)
elif self.arch_type in ["mamba","mamba2"]:
self.model = MambaMixelModel(
vocab_size=n_resp_tokens,
d_model=d_model,
n_layer=n_layers*2,
d_intermediate=0,
ssm_cfg={"layer": "Mamba2", "chunk_size":64} if self.arch_type == "mamba2" else {},
rms_norm=True,
fused_add_norm=True,
residual_in_fp32=True,
#attn_layer_idx=attn_layer_idx,
#attn_cfg=attn_cfg,
#initializer_cfg=initializer_cfg,
)
self.model.gradient_checkpointing = self.gradient_checkpointing
else:
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
2023-08-04 01:26:36 +00:00
2024-06-04 02:28:49 +00:00
if self.hyper_config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
2024-06-06 01:30:43 +00:00
self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.hyper_config.attention )
2023-08-02 21:53:35 +00:00
self.classifier = nn.Linear(d_model, n_resp_tokens)
self.accuracy_metric = MulticlassAccuracy(
n_resp_tokens,
top_k=10,
average="micro",
multidim_average="global",
ignore_index=self.ignore_index,
)
self.precision_metric = MulticlassPrecision(
n_resp_tokens,
top_k=10,
average="micro",
multidim_average="global",
ignore_index=self.ignore_index,
)
def _forward(
2023-08-02 21:53:35 +00:00
self,
inputs,
mask = None,
state = None,
2023-08-02 21:53:35 +00:00
):
x = inputs
m = mask.squeeze(-1).int()
aux_loss = None
"""
# Broken
if state is not None and (self.arch_type == "retnet" or self.arch_type == "retnet-hf"):
# prefill
if len(state) == 0:
prefill_size = x.shape[1]
# run the initial prompt to fill the KV cache
if self.arch_type == "retnet":
for n in range(prefill_size):
xi = x[:, n, :].unsqueeze(1)
self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
elif self.arch_type == "retnet-hf":
state = None
for n in range(prefill_size):
xi = x[:, n, :].unsqueeze(1)
kwargs = dict(
attention_mask=m,
inputs_embeds=xi,
past_key_values=state,
use_cache=True,
forward_impl='recurrent',
# return_dict=True,
)
out = self.model(**kwargs)
state = out.past_key_values
# grab last token(s)
x = x[:, -1, :].unsqueeze(1)
"""
# HF transformer derived model
if self.arch_type in ["llama", "mistral", "mixtral"]:
kwargs = dict(
attention_mask=m,
inputs_embeds=x,
past_key_values=state,
use_cache=True,
# return_dict=True,
)
if self.n_experts > 1 and targ_list is not None:
kwargs["output_router_logits"] = True
t = self.model(**kwargs)
x = t[0]
if state is not None:
state = t[1]
if self.n_experts > 1 and targ_list is not None:
router_logits = t[-1]
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
elif self.arch_type == "transformer":
2023-09-12 21:04:45 +00:00
# ensures we specify a quant_level for the transformer implementation's AdaLN
2023-09-07 14:14:03 +00:00
l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
l = l.to(device)
2023-09-12 21:04:45 +00:00
# inject position information
x = self.sin_emb.add_pe(x)
# pass our inputs through the transformer
2023-08-02 21:53:35 +00:00
for block in self.blocks:
x = block(x, m, l)
2023-08-02 21:53:35 +00:00
elif self.arch_type == "retnet":
2023-09-12 21:04:45 +00:00
# pass our inputs through the RetNet
x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
if _ is not None and "l_aux" in _ and self.n_experts > 1:
aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
elif self.arch_type == "retnet-hf":
first = state is None or len(state) == 0
kwargs = dict(
attention_mask=m,
inputs_embeds=x if first else x[:, -1, :].unsqueeze(1),
past_key_values=None if first else state,
use_cache=True,
forward_impl='parallel' if first else 'recurrent',
return_dict=True,
)
out = self.model(**kwargs)
x = out.last_hidden_state
if state is not None:
state = out.past_key_values
elif self.arch_type in ["mamba","mamba2"]:
x = self.model( hidden_states=x )
elif self.arch_type == "bitnet":
x = self.model(x)
2023-09-12 21:04:45 +00:00
# output projection layer with masking
x = self.classifier(x) * mask
return x, state, aux_loss
def inputs(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
targ_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None
):
device = text_list[0].device
batch_size = len(text_list)
inputs = [ [] for _ in range(batch_size) ]
for i in range(batch_size):
quant_level = quant_levels[i] if quant_levels is not None else 0
if text_list is not None:
inputs[i].append( ( "text", text_list[i] ) )
inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
if proms_list is not None:
inputs[i].append( ( "prom", proms_list[i] ) )
if resps_list is not None:
inputs[i].append( ( "resp", resps_list[i] ) )
if targ_list is not None:
inputs[i].append( ( "targ", targ_list[i] ) )
return inputs
def inputs_to_embeddings(
self,
inputs: list,
quant_levels: Tensor | None = None
):
x_list = []
for batch_index, batch_input in enumerate(inputs):
batch = []
quant_level = quant_levels[batch_index] if quant_levels is not None else 0
for name, input in batch_input:
embedding = None
if name == "text":
embedding = self.text_emb( input )
2024-06-05 15:30:04 +00:00
elif name == "quant_level" and self.rvq_level_emb is not None:
embedding = self.rvq_level_emb( input )
2024-06-05 15:30:04 +00:00
elif name == "lang" and self.langs_emb is not None:
embedding = self.langs_emb( input )
elif name == "prom":
embedding = self.proms_emb( input )
2024-06-05 15:30:04 +00:00
elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input )
elif name == "resp":
embedding = self.resps_emb( input, quant_level )
else:
continue
batch.append(embedding)
x_list.append( _join( batch, self.sep ) )
return x_list
def calc_loss(
self,
inputs: list,
logits,
quant_levels: Tensor | None = None,
):
# old, "naive" way, no loss factoring
2024-06-04 02:28:49 +00:00
if not self.hyper_config.loss_factors:
target_list = []
for batch_index, batch in enumerate(inputs):
target = []
for name, input in batch:
if name == "prom":
target.append( torch.full_like(input[..., 0], self.ignore_index) )
elif name in ["text", "quant_level", "lang", "tone", "targ"]:
target.append( input )
target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
# modify only for the AR so it can properly behave like a transformer
for i in range(len(target_list)):
if quant_levels is not None and quant_levels[i] > 0:
continue
logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
target_list[i] = target_list[i][..., 1:] # predicts token n + 1
# see comments for the split-loss calc cross_entropy call
if False:
target = torch.cat( target_list )
inputs = torch.cat( logits )
self.loss = dict(
# "nll" was in the original implementation and should actually just be called something else
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
)
self.stats = dict(
acc = self.accuracy_metric( inputs, target ),
# precision = self.precision_metric( inputs, target ),
)
else:
self.loss = dict(
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / len(batch)
)
self.stats = dict(
acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / len(batch)
)
return
"""
# considerations:
# * split losses does not maintain the entire sequence
# * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
# + the other way at least should keep it intact this way
# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
"""
self.loss = dict()
self.stats = dict(acc = dict())
info = {}
for i, batch in enumerate( inputs ):
quant_level = quant_levels[i] if quant_levels is not None else None
it = 0
for name, input in batch:
# do not use resp
if name == "resp":
continue
# rename to resp
if name == "targ":
name = "resp"
# select prom level
elif name == "prom" and quant_level is not None:
input = input[:, quant_level]
seq_len = input.shape[0]
logit = logits[i][it:it+seq_len]
it += seq_len + 1 # +1 to incorporate the separator
# for the AR, shift sequence so that it predicts the next token
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
if quant_level is None or quant_level == 0:
logit = logit[..., :-1, :] # get all but the final logit
input = input[..., 1:] # shift sequence to the right by one
if name not in info:
info[name] = {
"targets": [],
"logits": [],
}
# modeling_llama.py has some comment about requiring .contiguous() but I feel it's a spook since that incurs a memory allocation
info[name]["targets"].append( input.long() )
info[name]["logits"].append( logit )
for name, batch in info.items():
loss_factor = self.loss_factor(name)
2024-06-05 04:48:51 +00:00
if name not in ["text", "prom", "resp"]:
continue
if loss_factor == 0.0:
continue
# "faster" if cross_entropy has speedups for processing an entire batch, but torch.cat allocates new tensors
# to-do: set this to a var
if False:
targets = torch.cat( batch["targets"] ).long()
inputs = torch.cat( batch["logits"] )
self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
2024-06-05 04:48:51 +00:00
self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
# probably consumes less memory due to not having to allocate memory
# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
else:
self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / len(batch)
self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / len(batch)
# accuracy sometimes breaks for mamba
# to-do: compute loss per individual batch to scale per RVQ level
"""
rvq_loss_factor = self.loss_factor("quant")
if isinstance( rvq_loss_factor, list ):
...
"""
def forward(
self,
inputs: list,
quant_levels: Tensor | None = None,
state: dict | list | None = None,
):
x_list = self.inputs_to_embeddings( inputs, quant_levels )
x, m = list_to_tensor(x_list)
# yes, there's a better way.
training = False
for batch_index, batch in enumerate(inputs):
for name, input in batch:
if name == "targ":
training = True
device = x.device
batch_size = len(x_list)
# pad our input and mask, but retain the original length by doing it after
if self.l_padding and x.shape[1] % self.l_padding != 0:
# pad input
shape = list(x.shape)
shape[1] = self.l_padding - shape[1] % self.l_padding
padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
x = torch.cat([x, padding], dim=1)
# pad mask
shape[2] = 1
padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
m = torch.cat([m, padding], dim=1)
x, state, aux_loss = self._forward(
inputs=x,
mask=m,
state=state,
)
2023-08-02 21:53:35 +00:00
# Remove padding
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
2023-08-02 21:53:35 +00:00
# compute loss if the target is given
if training:
self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
# include any additional losses (for example: MoE router)
if aux_loss is not None:
self.loss["aux_loss"] = aux_loss
return (logits, state) if state is not None else logits
def sample(
self,
logits: list[Tensor],
resps_list: list[Tensor],
quant_levels: Tensor | None = None,
temperature: float = 1.0,
min_temperature: float = -1.0,
top_k: int = -100,
top_p: float = 1.0,
repetition_penalty: float = 1.0,
repetition_penalty_decay: float = 0.0,
length_penalty: float = 0.0,
beam_width: int = 0,
mirostat: list[dict] | None = None,
):
if min_temperature < 0:
min_temperature = temperature
# (NAR) return the entire generated response
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
if quant_levels is not None:
logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
# (AR chunkwise) return the last chunkwise piece
elif self.causal and self.recurrent_chunk_size > 0:
logits = [ logit[-l:] for logit, l in zip(logits, self.recurrent_chunk_size) ]
# (AR) return just the last code
# Recurrent decoding relies on the last token in the logits, because each token predicts the next token in the sequence (obviously)
2023-08-02 21:53:35 +00:00
else:
logits = [ logit[-1:] for logit in logits ]
devices = [ logit.device for logit in logits ]
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
# perform repetition penalizing
logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
# (AR) perform length penalizing
if quant_levels is None and self.causal:
logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
# perform top_k/top_p filtering of our logits
if top_k > 0 or top_p < 1.0:
logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
# trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
# epsilon float comparison because I don't trust Python
if abs(temperature - min_temperature) >= 0.001:
logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ]
else:
logits = [ logit / temperature for logit in logits ]
# do mirostat sampling
# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
if mirostat is not None:
# mirostat sampling
return [ mirostat_sample(logit, state=state) for logit, state in zip(logits, mirostat) ]
# do beam search (naive implementation)
# picks the top-k across all batches, and re-batches those resultant tokens
# returns the logit scores as well to be P-concatted with the previous scores
# to-do: not naively implement beam searching
if beam_width > 1:
candidates = top_k_logits_list( logits, beam_width )
res = [ torch.tensor(token, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
return res, scores
# and sample
return [ Categorical(logits=logit).sample() for logit in logits ]