# vall-e/vall_e/models/base_v2.py
import math
import torch
import torch.nn.functional as F
import random
import numpy as np
import re
from time import perf_counter
from collections import namedtuple
from typing import Literal, overload, Optional, Tuple
from functools import partial
from einops import rearrange
from torch import Tensor, einsum, nn
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
from .arch import *
from ..utils import ml, clamp, mean, logit_normalization
from ..samplers import *
# yuck, kind of needed
from ..data import get_task_symmap
import logging
_logger = logging.getLogger(__name__)
# these seem more elegant than a dict
Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states'])
Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
LossStats = namedtuple('LossStats', ['loss', 'stats'])
"""
from ..utils.pattern import DelayedPatternProvider, VALLEPattern
"""
summed_embeddings_task = [ "stt" ]
special_tasks = [ "len", "stt", "phn", "text" ]
non_tokened_names = ["task", "dropout_mask", "classifier_level"]
task_outputs = {
"tts": "resp",
"ns": "resp",
"sr": "resp",
"stt": "phn",
"len": "len",
"phn": "phn",
"text": "text",
}
def _dropout_mask( input, p ):
return (torch.rand(input.shape[0], device=input.device) < p)
def _create_mask(l, device):
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
return (seq < stop).float() # (b t)
def _join(x: tuple[Tensor], sep: Tensor):
ret = x[0]
for i in range(1, len(x)):
ret = torch.cat((ret, sep[None], x[i]), dim=0)
return ret
def list_to_tensor(x_list: list[Tensor]):
l = list(map(len, x_list))
x = pad_sequence(x_list, batch_first=True)
m = _create_mask(l, x_list[0].device)
m = m.to(x).int()
return x, m
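# a small shape note (inferred from usage below): given a list of b tensors with lengths l_i,
# list_to_tensor returns a right-padded tensor of shape (b, max(l_i), ...) plus an int mask of
# shape (b, max(l_i)) with 1s over valid positions and 0s over padding.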
def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
x = x.clone().detach()
levels = x.shape[-1]
for level in range( levels ):
lhs = dropout_token if not swapped else x[..., level]
rhs = x[..., level] if not swapped else dropout_token
x[..., level] = torch.where( dropout_mask, lhs, rhs )
return x
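# a minimal sketch of how the two helpers above combine for masked (NAR) training,
# based on their usage further below (names here are illustrative):
#   mask   = _dropout_mask( codes, p )                                   # (t,) bool, True = masked timestep
#   input  = _dropout_codes( codes, mask, mask_token )                   # masked positions become the mask token
#   target = _dropout_codes( codes, mask, ignore_index, swapped=True )   # unmasked positions get ignored by the loss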
# aims to properly encode token sequences into an embedding
# despite being named for FSQ codecs, this works for RVQ codecs
class FiniteAudioEncoder(nn.Module):
def __init__(
self,
n_tokens: int,
n_levels: int,
token_dim: int,
use_ln: bool = True, # whether to perform a post-embedding pre-norm or not (I'm not sure if this is redundant)
use_ffn: bool = True, # whether to employ a residual feed forward network or not
d_model: int = None,
d_ffn: int = 2, # feed forward expansion value
):
super().__init__()
if not d_model:
d_model = token_dim
self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
if use_ffn:
self.proj = nn.Sequential(
nn.Linear(token_dim, token_dim * d_ffn),
nn.GELU(),
nn.Linear(token_dim * d_ffn, d_model),
)
elif token_dim != d_model:
self.proj = nn.Linear(token_dim, d_model)
else:
self.proj = nn.Identity()
self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
# explicit initialization
for emb in self.embs:
torch.nn.init.normal_(emb.weight, mean=0.0, std=0.02)
self.use_ffn = use_ffn
if use_ffn:
nn.init.xavier_uniform_(self.proj[0].weight)
nn.init.xavier_uniform_(self.proj[2].weight)
nn.init.zeros_(self.proj[0].bias)
nn.init.zeros_(self.proj[2].bias)
elif token_dim != d_model:
nn.init.xavier_uniform_(self.proj.weight)
nn.init.zeros_(self.proj.bias)
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
# empty
if xi.shape[0] == 0:
dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
return torch.zeros((0, dim), device=xi.device, dtype=xi.dtype)
if dropout_mask is not None:
xi = _dropout_codes( xi, dropout_mask, dropout_token )
x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
x = x + self.pos_embedding
x = self.norm(x)
if self.use_ffn:
x = x + self.proj( x )
else:
x = self.proj( x )
weights = F.softmax(self.level_weights, dim=0).view(1, -1, 1)
x = (x * weights).sum(dim=1)
return x
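# shape walkthrough for FiniteAudioEncoder.forward (inferred from the code above):
#   xi: (t, n_levels) int codes -> per-level embeddings stacked to (t, n_levels, token_dim)
#   -> + positional embedding -> norm -> residual FFN (or plain projection)
#   -> softmax-weighted sum over the level axis -> (t, d_model)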
# aims to decode the last hidden state into independent codebooks
# uses an MLP instead of attention since it's not residual, I guess (the problem with caving and consulting claude-3-5-sonnet is that it will blindly agree with whatever you suggest)
# optional per-level LN, might be beneficial
class FiniteAudioDecoder(nn.Module):
def __init__(
self,
d_model: int,
vocab_size: int,
n_levels: int,
d_ffn: int = 2, # feed forward expansion value
use_ln: bool = True, # perform layer normalization here
use_ffn: bool = True, # use a feed forward network post-norm pre-classifier
		shared_levels: bool = False, # whether to have one set of weights shared across all codebook levels, or separate weights for each level
):
super().__init__()
self.n_levels = n_levels
self.shared_levels = shared_levels
if use_ffn:
if not shared_levels:
self.head = nn.ModuleList([nn.Sequential(
# ln
(nn.LayerNorm(d_model) if use_ln else nn.Identity()),
# ffn
nn.Linear(d_model, d_ffn * d_model),
nn.GELU(),
nn.Linear(d_ffn * d_model, d_model),
# head
nn.Linear(d_model, vocab_size)
) for _ in range(n_levels)])
else:
self.head = nn.Sequential(
# ffn
nn.Linear(d_model, d_ffn * d_model),
nn.GELU(),
nn.Linear(d_ffn * d_model, d_model),
# head
nn.Linear(d_model, vocab_size * n_levels)
)
else:
if not shared_levels:
self.head = nn.ModuleList([nn.Sequential(
# ln
(nn.LayerNorm(d_model) if use_ln else nn.Identity()),
# head
nn.Linear(d_model, vocab_size)
) for _ in range(n_levels)])
else:
self.head = nn.Sequential(
# head
nn.Linear(d_model, vocab_size * n_levels)
)
def forward(self, x: Tensor) -> Tensor:
batch_size, seq_len, _ = x.shape
if not self.shared_levels:
x = torch.stack([head(x) for head in self.head], dim=1)
else:
x = self.head(x)
x = x.view(batch_size, seq_len, self.n_levels, -1)
x = x.transpose(1, 2)
return x
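# shape walkthrough for FiniteAudioDecoder.forward (inferred from the code above):
#   x: (b, t, d_model) -> per-level heads stacked on dim=1, or one shared head reshaped and
#   transposed, either way yielding (b, n_levels, t, vocab_size) logits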
# the Residual variant doesn't seem to work well
# the Finite variant unironically works well for residual codecs
AudioEncoder = FiniteAudioEncoder
AudioDecoder = FiniteAudioDecoder
# handles simple output projections into logits for other tasks
class AuxDecoder(nn.Module):
def __init__(
self,
d_model,
vocab_size,
name = None,
):
super().__init__()
self.name = name
self.head = nn.Linear( d_model, vocab_size )
def forward(self, x: Tensor ) -> Tensor:
x = self.head( x )
return x
class Base_V2(nn.Module):
def loss_factor(self, k):
if self.config is None:
return 1.0
return self.config.loss_factor(k)
def _prune(self, l: Tensor, stop = None):
if stop is None:
stop = self.stop_token
indices = (l == stop).nonzero()
if len(indices) == 0:
return l
return l[: indices.min().item()]
def __init__(
self,
n_phn_tokens: int = 256,
n_audio_tokens: int = 1024,
n_text_tokens: int = 8575,
d_model: int = 512,
d_ffn: int = 4,
n_heads: int = 8,
n_layers: int = 12,
p_dropout: float = 0.1,
n_experts: int = 1,
l_padding: int = 0,
training = True,
attention = None,
config = None,
):
super().__init__()
if not attention:
attention = config.attention if config is not None else "auto"
n_resp_levels = config.resp_levels if config is not None else 8
attention_backend = attention
unified_position_ids = config.experimental.unified_position_ids if config is not None else True
noncausal_masks = config.experimental.noncausal_masks if config is not None else False
max_position_embeddings = config.experimental.max_position_embeddings if config is not None else (75 * 60 * 5)
masking_ratio = config.experimental.masking_ratio if config is not None else False
ignore_inputs_for_loss = config.experimental.ignore_inputs_for_loss if config is not None else False
resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
len_parallel_training = config.experimental.len_parallel_training if config is not None else False
len_use_logits = config.experimental.len_use_logits if config is not None else True
predict_causally = config.experimental.predict_causally if config is not None else False
monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
len_loss_factor = config.experimental.len_loss_factor if config is not None else 0.00001
logit_normalization = config.experimental.logit_normalization if config is not None else 0
per_level_normalization = config.experimental.per_level_normalization if config is not None else True
use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
use_sliding_attention_mask = config.experimental.use_sliding_attention_mask if config is not None else True
parallel_attention_mask_dropout = config.experimental.parallel_attention_mask_dropout if config is not None else 0.0
n_vocab = 256
n_tasks = config.tasks if config is not None else 8
n_langs = config.langs if config is not None else 2
n_tones = config.tones if config is not None else 1
if audio_level_loss_factors == "auto":
audio_level_loss_factors = "normal" if n_audio_tokens == 1000 else "decreasing"
if audio_level_loss_factors == "decreasing":
audio_level_loss_factors = [1.0 / (i + 1) for i in range(n_resp_levels)]
elif audio_level_loss_factors == "normal":
if n_resp_levels == 8:
audio_level_loss_factors = [0.5, 0.625, 0.75, 0.875, 0.875, 0.75, 0.625, 0.5]
else:
center = n_resp_levels // 2
audio_level_loss_factors = [1.0 - abs(i - center) / n_resp_levels for i in range(n_resp_levels)]
			# to-do: proper curriculum
# prioritizes midrange, maybe good for epoch 0?
# [0.5, 0.625, 0.75, 0.875, 0.875, 0.75, 0.625, 0.5]
# deprioritizes midrange, good for epoch 1?
# [0.875, 0.75, 0.625, 0.5, 0.5, 0.625, 0.75, 0.875]
elif audio_level_loss_factors == "equal":
audio_level_loss_factors = [1.0 for _ in range(n_resp_levels)]
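		# e.g. with n_resp_levels == 8, "decreasing" resolves to [1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.125],
		# weighting the coarser (earlier) codebook levels more heavily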
if attention_backend == "auto":
attention_backend = "sdpa"
hf_attention = attention_backend
HF_ATTENTIONS = ["eager", "sdpa", "flash_attention_2"]
if attention_backend not in HF_ATTENTIONS:
hf_attention = None
if attention_backend not in AVAILABLE_ATTENTIONS:
			raise ValueError(f"Requesting attention `{attention_backend}` but it is not available. Currently available: {AVAILABLE_ATTENTIONS}")
self.training = training
self.teaching = False
self.config = config
self.n_phn_tokens = n_phn_tokens
self.n_audio_tokens = n_audio_tokens
self.n_text_tokens = n_text_tokens
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.n_experts = n_experts
self.l_padding = l_padding
self.ignore_index = -100
self.n_resp_levels = self.config.resp_levels if self.config else n_resp_levels
self.n_max_levels = self.config.max_levels if self.config else n_resp_levels
self.capabilities = self.config.capabilities if self.config else ["ar", "nar"]
self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
self.use_streamlined_calc_loss = True
self.stop_token = self.n_audio_tokens
self.mask_token = self.stop_token + 1
self.causal = True
self.version = self.config.version if self.config is not None else 5
self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
self.arch_type = self.config.arch_type if self.config is not None else "llama"
# check if requested arch is unavailable
if self.arch_type in ERROR_ARCHES:
raise ERROR_ARCHES[self.arch_type]
# crunge
if self.config is not None and config.teacher:
self.teaching = True
self.training = False
self.predict_causally = predict_causally
self.resp_parallel_training = resp_parallel_training
self.len_parallel_training = len_parallel_training
self.len_use_logits = len_use_logits
self.unified_position_ids = unified_position_ids
self.inject_timestep_embedding = False # results in bad output
self.masking_ratio = masking_ratio
self.ignore_inputs_for_loss = ignore_inputs_for_loss
self.noncausal_masks = noncausal_masks
self.audio_level_loss_factors = audio_level_loss_factors
self.len_loss_factor = len_loss_factor
self.logit_normalization = False # this actually kills the model's demasking capabilities
self.use_segmented_attention_mask = use_segmented_attention_mask
self.use_sliding_attention_mask = use_sliding_attention_mask
self.parallel_attention_mask_dropout = parallel_attention_mask_dropout
self.sep = nn.Parameter(torch.randn(d_model))
self.phn_emb = ml.Embedding(n_phn_tokens, d_model)
self.text_emb = ml.Embedding(n_text_tokens, d_model)
self.langs_emb = ml.Embedding(n_langs, d_model) if n_langs > 0 else None
self.tasks_emb = ml.Embedding(n_tasks, d_model) if n_tasks > 0 else None
self.tones_emb = ml.Embedding(n_tones, d_model) if n_tones > 0 else None
self.len_emb = None # ml.Embedding(11, d_model)
self.audio_emb = None
self.proms_emb = None
self.resps_emb = None
if monolithic_audio_encoder:
self.audio_emb = AudioEncoder(
n_tokens=n_audio_tokens + 2, # stop + masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
)
else:
self.proms_emb = AudioEncoder(
n_tokens=n_audio_tokens,
n_levels=self.n_resp_levels,
token_dim=d_model,
)
self.resps_emb = AudioEncoder(
n_tokens=n_audio_tokens + 2, # stop + masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
)
self.audio_decoder = AudioDecoder(
d_model,
(n_audio_tokens + 1),
self.n_resp_levels,
use_ln=per_level_normalization,
)
self.len_decoder = AuxDecoder( d_model, 1 if not len_use_logits else (10 * 5) )
self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
self.text_decoder = AuxDecoder( d_model, n_text_tokens )
# override any requested padding size
if attention_backend == "flash_attn_v100":
self.l_padding = 32
elif attention_backend == "fused_attn":
self.l_padding = 128
if self.arch_type in ["none"]:
self.model = None
elif self.arch_type in ["llama"]:
self.model_config = LlamaConfig(
vocab_size=n_vocab,
hidden_size=d_model,
max_position_embeddings=max_position_embeddings,
intermediate_size=d_model*d_ffn,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=n_heads,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
#gradient_checkpointing=self.gradient_checkpointing,
# extra parameters
output_norm = not per_level_normalization, # moves the LN out to the decoder
attn_mode = attention_backend,
causal = self.causal,
)
self.model = LlamaModel(self.model_config)
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
else:
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
if hasattr( self.model, "embeddings" ):
del self.model.embeddings
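		# note: the backbone only ever receives precomputed input embeddings (see inputs_embeds in _forward),
		# so its own token embedding table is deleted above and n_vocab is effectively a placeholder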
def _forward(
self,
inputs,
mask = None,
is_causal = None,
position_ids = None,
state = None,
output_attentions = False,
output_hidden_states = False,
):
x = inputs
m = mask #.squeeze(-1).int()
aux_loss = None
attentions = None
hidden_states = None
if self.arch_type in ["none"] or self.model is None:
...
# HF transformer derived model
elif self.arch_type in ["llama"]:
kwargs = dict(
inputs_embeds=x,
attention_mask=m,
past_key_values=state,
position_ids=position_ids,
use_cache=False, # not self.training,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
is_causal=is_causal,
)
if self.n_experts > 1 and self.training:
kwargs["output_router_logits"] = True
output = self.model(**kwargs)
x = output["last_hidden_state"]
# to-do: figure out why KV caching doesn't work
#if not self.training:
if state is not None:
state = output["past_key_values"]
if output_attentions:
attentions = output["attentions"]
if output_hidden_states:
hidden_states = output["hidden_states"]
if self.n_experts > 1 and self.training:
router_logits = output["router_logits"]
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok, m )
else:
hidden_states = self.model(x)
# process it into a format that I like
if output_hidden_states:
			# hidden_states actually holds n_layers + 1 entries, as hidden_states[0] is the embedding output
			hidden_states = hidden_states[1:]
			# apply normalization to these states (to-do: check if this matters)
			# but skip the last state, as it already is normalized
			hidden_states = [ x if i == self.n_layers - 1 else self.model.norm( state ) for i, state in enumerate( hidden_states ) ]
return Logits(x, state, inputs, aux_loss, attentions, hidden_states)
# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
def inputs(
self,
phns_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
task_list: list[str] | None = None,
time_list: list[Tensor] | None = None,
quant_levels: int | list[int] | Tensor | None = None
):
if phns_list and phns_list[0] is not None:
device = phns_list[0].device
batch_size = len(phns_list)
elif text_list and text_list[0] is not None:
device = text_list[0].device
batch_size = len(text_list)
elif proms_list and proms_list[0] is not None:
device = proms_list[0].device
batch_size = len(proms_list)
elif resps_list and resps_list[0] is not None:
device = resps_list[0].device
batch_size = len(resps_list)
inputs = [ [] for _ in range(batch_size) ]
for i in range(batch_size):
quant_level = quant_levels[i] if quant_levels is not None else 0
task_type = task_list[i] if task_list is not None else "tts"
timestep = time_list[i] if time_list is not None else None
classifier_level = None
# insert task type as a string
inputs[i].append( ( "task", task_type ) )
# to-do: maybe not split the below blocks up
			# might be beneficial in the event I need to use a different sequence, such as for STT tasks
# Base-line TTS task
# Sequence: <phn><sep><rvq lvl><sep><prom><sep><resp>
# prom /may/ include <task> tokens inside to help guide things, per SpeechX
if task_type in get_task_symmap() and task_type not in special_tasks:
# insert the phn prompt
if phns_list is not None and phns_list[i] is not None:
inputs[i].append( ( "phn", phns_list[i] ) )
elif text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
# insert lang token if we're trained for it
if lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
# insert input audio prompt
if proms_list is not None and proms_list[i] is not None:
inputs[i].append( ( "prom", proms_list[i] ) )
# insert tone token if we're trained for it
if tone_list is not None and tone_list[i] is not None:
inputs[i].append( ( "tone", tone_list[i] ) )
# insert timestep token
if timestep is not None:
p = self.masking_ratio
# store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
if self.training and p > 0:
dropout_mask = _dropout_mask( resps_list[i], p )
inputs[i].append( ("dropout_mask", dropout_mask ) )
# insert the current output response
if resps_list is not None and resps_list[i] is not None:
inputs[i].append( ( "resp", resps_list[i] ) )
classifier_level = f"{'N' if timestep is not None else ''}AR:{quant_level}:{quant_level}"
inputs[i].append( ("classifier_level", classifier_level) )
# Audio length prediction task
# Sequence: <phn><sep><rvq lvl><prom><sep><len>
elif task_type == "len":
# insert the phn prompt
if phns_list is not None and phns_list[i] is not None:
inputs[i].append( ( "phn", phns_list[i] ) )
elif text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
# insert lang token if we're trained for it
if lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
# insert input audio prompt
if proms_list is not None and proms_list[i] is not None:
inputs[i].append( ( "prom", proms_list[i] ) )
# insert tone token if we're trained for it
if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
inputs[i].append( ( "tone", tone_list[i] ) )
inputs[i].append( ("classifier_level", "len") )
# Speech-to-Text prediction task
# Sequence: <resp><sep><rvq lvl><sep><phn>
elif task_type == "stt":
# insert the input response
if resps_list is not None and resps_list[i] is not None:
inputs[i].append( ( "resp", resps_list[i] ) )
# insert lang token if we're trained for it
if lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
# insert the output phn prompt
if phns_list is not None and phns_list[i] is not None:
inputs[i].append( ( "phn", phns_list[i] ) )
inputs[i].append( ("classifier_level", "phn") )
# Text phonemizing task
# Sequence: <text><sep><lang><sep><phonemes>
elif task_type == "phn":
# insert the phn prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
# insert lang token if we're trained for it
if lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
# insert the phn prompt
if phns_list is not None and phns_list[i] is not None:
inputs[i].append( ( "phn", phns_list[i] ) )
inputs[i].append( ("classifier_level", "phn") )
# Text de-phonemizing task
# Sequence: <text><sep><lang><sep><phonemes>
elif task_type == "text":
# insert the phn prompt
if phns_list is not None and phns_list[i] is not None:
inputs[i].append( ( "phn", phns_list[i] ) )
# insert lang token if we're trained for it
if lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
# insert the phn prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
inputs[i].append( ("classifier_level", "text") )
else:
raise Exception(f'Unrecognized task: {task_type}')
return inputs
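	# e.g. a single TTS training entry (with a masking timestep) comes out roughly as:
	#   [ ("task", "tts"), ("phn", <tensor>), ("lang", <tensor>), ("prom", <tensor>),
	#     ("dropout_mask", <bool tensor>), ("resp", <tensor>), ("classifier_level", "NAR:0:0") ]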
def inputs_to_embeddings(
self,
inputs: list,
quant_levels: int | list[int] | Tensor | None = None
):
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_embedding( input, quant_level ):
if isinstance(input, str):
return self.tasks_emb( torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16) )
if self.audio_emb is not None:
return self.audio_emb( input )
return self.proms_emb( input )
x_list = []
for batch_index, batch_input in enumerate(inputs):
batch = []
quant_level = quant_levels[batch_index] if quant_levels is not None else 0
task_type = "tts"
input_prom = None
classifier_level = None
dropout_mask = None
timestep = None
# pre-iterate
for name, input in batch_input:
if name == "classifier_level":
classifier_level = input
elif name == "dropout_mask":
dropout_mask = input
elif name == "timestep":
timestep = input
for name, input in batch_input:
				# technically can provide a map for input_name => embedding, but some embeddings require additional processing
embedding = None
# is already an embedding
if name == "task":
# noop
# *maybe* inject a token for specifying task type
task_type = input
continue
elif name == "phn":
embedding = self.phn_emb( input )
device = embedding.device
elif name == "text" and self.text_emb is not None:
embedding = self.text_emb( input )
device = embedding.device
elif name == "quant_level" and self.rvq_l_emb is not None:
embedding = self.rvq_l_emb( input )
elif name == "lang" and self.langs_emb is not None:
embedding = self.langs_emb( input )
elif name == "prom":
proms = [ input ] if isinstance(input, torch.Tensor) else input
embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input )
elif name == "resp":
if self.audio_emb is not None:
embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
else:
embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
elif name == "timestep" and self.time_emb is not None:
embedding = self.time_emb( input )
elif name == "len" and self.len_emb is not None:
embedding = self.len_emb( input )
else:
# should probably raise an exception so things aren't processed silently
continue
batch.append(embedding)
# needed, cringe
if task_type == "len":
#batch[-1] = torch.cat( [ batch[-1], self.sep[None], self.sep[None] ] )
batch[-1] = torch.cat( [ batch[-1], self.sep[None] ] )
x_list.append( _join( batch, self.sep ) )
return x_list
# get an attribute from a given input list
def get_input(
self,
inputs,
name,
at=None,
):
find_all = at is None
res = [] if at is None else None
for batch_index, batch_input in enumerate(inputs):
if not find_all and batch_index != at:
continue
for n, input in batch_input:
if n != name:
continue
if not find_all:
return input
res.append( input )
return res
# creates position ids from a given input list
# if not unified_position_ids, then each input segment will have its own sequence
def inputs_to_position_ids(
self,
inputs: list,
mask: Tensor,
):
device = mask.device
# shamelessly grabbed from modeling_llama.py
ids = mask.long().cumsum(-1) - 1
ids.masked_fill_( mask == 0, 1 )
# there's a better way
if not self.unified_position_ids:
x_list = []
def get_input_token_length( name, input, task ):
# task token
if isinstance(input, str):
return 1
# list of tokens
if not isinstance(input, torch.Tensor):
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] )
# ending input will not have a separator later
return input.shape[0]
for batch_index, batch_input in enumerate(inputs):
# pre-iterate
task = "tts"
for name, input in batch_input:
if name == "task":
task = input
batch = torch.cat( [
torch.tensor([*range(get_input_token_length(name, input, task) + (1 if name != task_outputs.get(task, name) else 0))], device=device, dtype=torch.int32)
for name, input in batch_input if name not in non_tokened_names
] )
delta = ids[batch_index].shape[0] - batch.shape[0]
if delta > 0:
batch = torch.cat( [ batch, torch.tensor([1] * delta, device=device, dtype=torch.int32) ] )
x_list.append( batch )
ids = torch.stack( x_list )
return ids.to(device=device, dtype=torch.int32)
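	# e.g. with non-unified position ids, a TTS entry with segments phn (4 tokens), prom (3 tokens), resp (5 tokens)
	# yields ids [0..4][0..3][0..4]: non-output segments get +1 for their trailing separator, while the task's
	# output segment (here "resp") does not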
def calc_loss(
self,
inputs: list,
logits,
logits_aux = None,
quant_levels: list[int] | None = None,
compute_hard_loss = True,
compute_acc = True,
):
loss = {}
stats = {}
device = logits[0].device
batch_size = len(logits)
classifier_levels = self.get_input( inputs, "classifier_level" )
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
if isinstance(input, str):
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
return input
# handles "tokenizing" an integer
def tokenize_duration( seq_lens, device, dtype=torch.int64 ):
return torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in seq_lens], device=device, dtype=dtype)
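		# e.g. tokenize_duration([1234], device) -> tensor([[0, 1, 2, 3, 4]]): the duration as zero-padded decimal digits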
k_lo, k_hi = 1, 20
level_loss_factors = self.audio_level_loss_factors
# this could be one array of tuples but can't be assed
loss_targets = []
loss_logits = []
loss_factors = []
loss_names = []
resp_durations = []
for batch_index, batch in enumerate(inputs):
quant_level = quant_levels[batch_index]
causal = True
task_type = "tts"
dropout_mask = None
classifier_level = None
for name, input in batch:
if name == "task":
task_type = input
elif name == "dropout_mask":
dropout_mask = input
elif name == "classifier_level":
classifier_level = input
# autoregressive, causal
if classifier_level.startswith("AR:"):
causal = True
# nonautoregressive, parallel
elif classifier_level.startswith("NAR:"):
causal = False
it = 0
for name, input in batch:
token = None
ignored = False
# non-tokened tasks
if name in non_tokened_names:
continue
# prom can either be a tensor itself or a list of tensors and strings
if name == "prom":
# expand to list if not a list
proms = [ input ] if isinstance(input, torch.Tensor) else input
# iterate over the list to inject their tokens
token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
if logits[batch_index].dim() < 3 and token.dim() >= 2:
token = token[..., 0]
elif name == "resp":
token = input
# mask found, apply it
if dropout_mask is not None:
token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
# not a special input, inject as-is
else:
token = input
if not isinstance(token, torch.Tensor):
continue
if token.is_floating_point():
ignored = True
# grab range of our logits for later
seq_len = token.shape[0]
start, end = it, it+seq_len
it += seq_len + 1 # +1 to incorporate the separator
# deduce if a name for a task is an input or output
if name != task_outputs.get(task_type, name):
continue
sequence = token
if token.dim() == 1:
loss_factor = self.loss_factor(name)
if loss_factor == 0.0:
continue
logit = logits[batch_index][start:end]
"""
if self.logit_normalization:
logit = logit_normalization( logit, self.logit_normalization )
"""
if causal or self.predict_causally:
l = self.causal_size
loss_targets.append( token[l:].long() ) # shift the target so that token n...
loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1
else:
loss_targets.append( token.long() )
loss_logits.append( logit )
loss_factors.append( loss_factor )
loss_names.append( name )
else:
if name == "resp":
resp_durations.append( token.shape[0] )
for level in range( self.n_resp_levels ):
if not self.resp_parallel_training and not classifier_level.endswith(f':{level}:{level}'):
continue
logit = logits[batch_index][level][start:end]
"""
if self.logit_normalization:
logit = logit_normalization( logit, self.logit_normalization )
"""
if causal or self.predict_causally:
l = self.causal_size
loss_targets.append( token[l:, level].long() ) # shift the target so that token n...
loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1
else:
loss_targets.append( token[:, level].long() )
loss_logits.append( logit )
loss_factors.append( level_loss_factors[level] )
loss_names.append( name )
break
loss_target = torch.cat( loss_targets )
loss_logit = torch.cat( loss_logits )
nll = None
acc_k_lo = None
acc_k_hi = None
if compute_hard_loss:
nll = 0
nlls = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=self.ignore_index )
# not my best code
it = 0
weights = 0
bsz = len( loss_targets )
for seq, loss_factor in zip( loss_targets, loss_factors ):
seq_len = seq.shape[0]
start = it
it += seq_len
end = it
nll += nlls[start:end].mean() * loss_factor
weights += loss_factor
# normalize by batch
nll /= bsz
# re-scale by summed weights
nll /= (weights / bsz)
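			# net effect: nll == sum_i( mean_nll_i * w_i ) / sum_i( w_i ), i.e. a loss-factor-weighted mean over sequences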
# no this isn't redundant I swear, it'll propagate properly
if compute_acc:
n_vocab = loss_logit.shape[-1]
if n_vocab >= k_lo:
accuracy_metric = MulticlassAccuracy(
n_vocab,
top_k = 1,
average="micro",
multidim_average="global",
ignore_index = -100
).to(loss_logit.device)
acc_k_lo = accuracy_metric( loss_logit, loss_target )
if n_vocab >= k_hi:
accuracy_metric = MulticlassAccuracy(
n_vocab,
top_k = 20,
average="micro",
multidim_average="global",
ignore_index = -100
).to(loss_logit.device)
acc_k_hi = accuracy_metric( loss_logit, loss_target )
# to-do: re-add reporting split losses
		if nll is not None:
			loss["nll"] = nll
		if acc_k_lo is not None:
			stats[f"acc[k={k_lo}]"] = acc_k_lo.mean()
		if acc_k_hi is not None:
			stats[f"acc[k={k_hi}]"] = acc_k_hi.mean()
# check if len logits are provided
if logits_aux is not None:
len_factor = self.len_loss_factor # 0.001 # to-do: user adjustable (it's really small because mse_loss causes wildly bigly losses)
aux_loss_logit = torch.cat( logits_aux )
if self.len_use_logits:
				aux_loss_target = tokenize_duration( resp_durations, device=aux_loss_logit.device ).flatten()
loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
else:
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
return LossStats(loss, stats)
def forward(
self,
inputs: list,
quant_levels: list[int] | None = None,
state: dict | list | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
):
# derive quant levels from inputs if not provided
if quant_levels is None:
quant_levels = [ x.item() for x in self.get_input( inputs, "quant_level" ) ]
# inputs don't have quant levels added, pure AR
if len(quant_levels) != len(inputs):
quant_levels = [ 0 for _ in range(len(inputs)) ]
x_list = self.inputs_to_embeddings( inputs, quant_levels )
x, mask = list_to_tensor(x_list)
training = self.training
teaching = self.teaching
device = x.device
batch_size = len(x_list)
# pad our input and mask, but retain the original length by doing it after
if self.l_padding and x.shape[1] % self.l_padding != 0:
# pad input
shape = list(x.shape)
shape[1] = self.l_padding - shape[1] % self.l_padding
padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
x = torch.cat([x, padding], dim=1)
# pad mask
shape[2] = 1
padding = torch.zeros(shape[:2], dtype=x.dtype, device=x.device)
mask = torch.cat([mask, padding], dim=1)
# needs to be done here as we still have our raw inputs
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
classifier_levels = self.get_input( inputs, name="classifier_level" )
causal_levels = [ "phn", "text" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ]
# right now limit to new versions because I need to retrain the model for noncausal masks...
is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
if self.parallel_attention_mask_dropout > 0:
is_causal = [ True if random.random() < self.parallel_attention_mask_dropout else m for m in is_causal ]
# create special masks
# to-do, create it if mixed (although I expect this model to be purely non-causal)
text_window = 32 if self.use_sliding_attention_mask else 0
audio_window = self.audio_frames_per_second // 2 if self.use_sliding_attention_mask else 0
aux_lens = []
aux_windows = []
# fill aux lens
for batch_index, batch_input in enumerate( inputs ):
lens = [2, 0, 0]
windows = [text_window, audio_window, audio_window]
for name, input in batch_input:
if name in ["phn", "text"]:
lens[0] = input.shape[0] + 1
elif name == "lang":
lens[0] += 2
elif name == "prom":
lens[1] = input.shape[0] + 1
elif name == "tone":
lens[1] += 2
elif name == "resp":
lens[2] = input.shape[0]
aux_lens.append( lens )
aux_windows.append( windows )
if self.use_segmented_attention_mask and not any(is_causal):
mask = self.model._update_segmented_mask( mask, x, aux_lens, window_sizes=aux_windows )
output = self._forward(
inputs=x,
mask=mask,
state=state,
is_causal=is_causal,
position_ids=position_ids,
output_attentions = output_attentions,
)
hidden_states = output.hidden_states
logits = self.audio_decoder( output.logits )
"""
logits = [ logit for logit in output.logits ]
logits_aux = None
grouped_logits = {}
for batch_index in range( batch_size ):
classifier_level = classifier_levels[batch_index]
if classifier_level.startswith("AR:") or classifier_level.startswith("NAR:"):
classifier_level = "audio"
if classifier_level not in ["audio", "phn", "text", "len"]:
continue
if classifier_level not in grouped_logits:
grouped_logits[classifier_level] = []
grouped_logits[classifier_level].append(batch_index)
for classifier_level, decoders_indices in grouped_logits.items():
if classifier_level == "audio":
head = self.audio_decoder
elif classifier_level == "phn":
head = self.phn_decoder
elif classifier_level == "text":
head = self.text_decoder
elif classifier_level == "len":
head = self.len_decoder
decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
decoders_logits = head( decoders_logits )
for batch_index, logit in zip( decoders_indices, decoders_logits ):
logits[batch_index] = logit
"""
# Remove padding
logits = [ logit[..., :l, :] for logit, l in zip(logits, map(len, x_list)) ]
if not training:
loss = None
stats = None
self.loss = None
self.stats = None
# this can all technically be grabbed outside of this forward and manually invoke len_decoder on the last hidden states
tasks = self.get_input( inputs, name="task" )
# grab duration if no resp is provided or len task is requested
if tasks[0] == "len":
# do duration prediction
logits_aux = self.len_decoder( output.logits )
# it's more accurate this way
logits_aux = [ logit[..., -1, :] for logit, aux_len in zip(logits_aux, aux_lens) ]
# reshape logits
if self.len_use_logits:
# get tokens
logits_aux = [ logit.view(5, 10).argmax(dim=-1) for logit in logits_aux ]
# stitch
logits_aux = [ int("".join([ str(t.item()) for t in logit ])) / self.audio_frames_per_second for logit in logits_aux ]
logits = logits_aux
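				# e.g. with len_use_logits: the (5, 10) digit logits argmax to [0, 1, 2, 3, 4] -> "01234" -> 1234 frames,
				# which divided by audio_frames_per_second gives the predicted duration in seconds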
# compute loss if the target is given
else:
# do duration prediction
if self.len_parallel_training:
logits_aux = self.len_decoder( output.logits )
# only keep the input
logits_aux = [ logit[..., aux_len[0] + aux_len[1], :] for logit, aux_len in zip(logits_aux, aux_lens) ]
# reshape logits
if self.len_use_logits:
logits_aux = [ logit.view(5, 10) for logit in logits_aux ]
else:
logits_aux = None
loss, stats = self.calc_loss( inputs=inputs, logits=logits, logits_aux=logits_aux, quant_levels=quant_levels )
# include any additional losses (for example: MoE router)
if output.loss is not None:
loss["aux_loss"] = output.loss
self.loss = loss
self.stats = stats
# rewrap, because we're modifying the logits here
return Logits(logits, output.state, inputs, loss, output.attentions, hidden_states)
def sample(
self,
logits: Tensor, # logit scores
prev_list: Tensor | None = None,
len_list: Tensor | None = None,
**sampling_kwargs,
):
# yikes
temperature = sampling_kwargs.get("temperature", 1.0)
min_temperature = sampling_kwargs.get("min_temperature", -1.0)
top_k = sampling_kwargs.get("top_k", -100)
top_p = sampling_kwargs.get("top_p", 1.0)
min_p = sampling_kwargs.get("min_p", 0.0)
# repetition penalty parameters
repetition_penalty = sampling_kwargs.get("repetition_penalty", 1.0)
repetition_penalty_decay = sampling_kwargs.get("repetition_penalty_decay", 0.0)
# length penalty parameters
length_penalty = sampling_kwargs.get("length_penalty", 0.0)
# beam sampling parameters
beam_width = sampling_kwargs.get("beam_width", 0)
# mirostat sampling parameters
mirostat = sampling_kwargs.get("mirostat", None)
# DRY sampling parameters
dry_multiplier = sampling_kwargs.get("dry_multiplier", 0.0)
dry_base = sampling_kwargs.get("dry_base", 1.75)
dry_allowed_length = sampling_kwargs.get("dry_allowed_length", 2)
#
top_no = sampling_kwargs.get("top_no", 1.0)
#
attentions = sampling_kwargs.get("attentions", None)
batch_size = len( logits )
device = logits[0].device
if min_temperature < 0:
min_temperature = temperature
scores = None
entropy = None
causal = False
if prev_list is not None:
seq_lens = [ prev.shape[0] for prev in prev_list ]
elif len_list is not None:
seq_lens = len_list
elif self.causal:
seq_lens = [ self.causal_size for _ in range( batch_size) ]
causal = True
logits = [ logit[..., -l:, :] for l, logit in zip(seq_lens, logits) ]
# perform min_p filtering of our logits
"""
if min_p > 0.0:
logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
# perform top_k/top_p filtering of our logits
if top_k > 0 or top_p < 1.0:
logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
# do top-no logit processing
if top_no > 0.0:
logits = [ top_no_logits_processing(logit) for logit in logits ]
"""
probabilities = [ F.softmax(logit, dim=-1) for logit in logits ]
scores = [ torch.max(prob, -1)[0] for prob in probabilities ]
if temperature <= 0.0:
res = [ prob.argmax(dim=-1) for prob in probabilities]
else:
res = [ Categorical(logits=logit / temperature).sample() for logit in logits ]
# we only need the scores for NAR demasking, but AR breaks and I cannot be assed to handle it right now
scores = [
torch.gather(prob, 2, tokens.unsqueeze(-1)).squeeze(-1)
for prob, tokens in zip(probabilities, res)
] if not causal else []
return Sampled(res, logits, scores, entropy)