2024-06-09 01:30:15 +00:00
"""
Core model for handling all VALL-E tasks.
This should handle all the "low" level things, such as:
* parsing inputs to sequences
* converting sequences to embeddings
* forward pass
* processing loss and returning logits
Additional functionality (preparing inputs, generating full audio) should be delegated to classes that inherit from the base model.
"""
2023-08-02 21:53:35 +00:00
import math
import torch
import torch . nn . functional as F
2024-07-25 00:35:17 +00:00
import random
2023-09-13 02:28:07 +00:00
import numpy as np
2024-04-05 00:11:49 +00:00
import re
2023-08-02 21:53:35 +00:00
2024-10-12 02:18:26 +00:00
from time import perf_counter
from collections import namedtuple
2024-05-04 18:07:45 +00:00
from typing import Literal , overload , Optional , Tuple
2023-08-02 21:53:35 +00:00
from functools import partial
from einops import rearrange
from torch import Tensor , einsum , nn
2024-04-17 02:04:48 +00:00
from torch . nn import Embedding
2023-08-02 21:53:35 +00:00
from torch . distributions import Categorical
from torch . nn . utils . rnn import pad_sequence
from torch . utils . checkpoint import checkpoint
from torchmetrics . classification import BinaryAccuracy , MulticlassAccuracy , MulticlassPrecision
2024-06-06 01:30:43 +00:00
from . arch import *
2024-11-10 04:57:34 +00:00
from . . utils import wrapper as ml , clamp
2024-07-30 00:15:07 +00:00
from . . samplers import *
2024-06-30 00:46:11 +00:00
from . . emb . qnt import encode_as_embedding
2024-07-18 21:16:14 +00:00
# yuck, kind of needed
from . . data import get_task_symmap
2024-11-02 16:49:05 +00:00
# these seem more elegant than a dict
# (namedtuple accepts the field names as a single whitespace-separated string)
Logits = namedtuple("Logits", "logits state aux_loss attentions hidden_states exited_layer")
Sampled = namedtuple("Sampled", "ids logits scores entropy")
LossStats = namedtuple("LossStats", "loss stats")
2024-10-12 02:18:26 +00:00
2024-07-18 19:18:34 +00:00
"""
from . . utils . pattern import DelayedPatternProvider , VALLEPattern
"""
2024-11-08 03:19:14 +00:00
def _dropout_mask ( input , p = None ) :
# cosine scheduling
if p is None :
t = random . random ( )
p = math . cos ( t * math . pi * 0.5 )
2024-11-10 00:04:59 +00:00
seq = [ random . random ( ) < p for _ in range ( input . shape [ 0 ] ) ]
mask = torch . tensor ( seq , dtype = torch . bool , device = input . device )
return mask
2024-11-07 17:32:11 +00:00
2023-08-02 21:53:35 +00:00
def _create_mask ( l , device ) :
""" 1 is valid region and 0 is invalid. """
seq = torch . arange ( max ( l ) , device = device ) . unsqueeze ( 0 ) # (1 t)
stop = torch . tensor ( l , device = device ) . unsqueeze ( 1 ) # (b 1)
return ( seq < stop ) . float ( ) # (b t)
def _join ( x : tuple [ Tensor ] , sep : Tensor ) :
"""
Args :
x : ( k t d )
sep : ( d )
"""
ret = x [ 0 ]
for i in range ( 1 , len ( x ) ) :
ret = torch . cat ( ( ret , sep [ None ] , x [ i ] ) , dim = 0 )
return ret
def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
    """
    Pad a list of variable-length sequences into one batch tensor plus a validity mask.

    Args:
        x_list: [(t d)] -- sequences padded along their first dimension (must be non-empty).
        pattern: einops pattern applied to the padded (t b ...) tensor produced by ``pad_sequence``.

    Returns:
        x: the padded batch rearranged by ``pattern`` ((b t d) with the default pattern).
        m: (b t) int32 mask on ``x``'s device; 1 marks real positions, 0 marks padding.
           The mask is always (b t); it is not rearranged by ``pattern``.
    """
    lengths = [len(seq) for seq in x_list]
    x = rearrange(pad_sequence(x_list), pattern)
    m = _create_mask(lengths, x_list[0].device)
    # match x's device (and dtype), then settle on int32
    m = m.to(x).int()
    return x, m
2024-08-05 00:56:21 +00:00
def _interleave_sequence_reshape ( input : list [ torch . Tensor ] , dim = - 1 ) :
shape = ( input [ 0 ] . shape [ 0 ] * len ( input ) , input [ 0 ] . shape [ dim ] )
return torch . concat ( [ i . t ( ) for i in input ] ) . t ( ) . reshape ( shape )
def _interleave_sequence_flatten ( input : list [ torch . Tensor ] ) :
return torch . concat ( [ i . t ( ) for i in input ] ) . t ( ) . flatten ( )
2023-09-09 01:30:54 +00:00
# automagically parses a batch-list and returns it as a list
2024-04-17 02:04:48 +00:00
"""
2023-08-02 21:53:35 +00:00
class Embedding ( nn . Embedding ) :
def forward ( self , x_list : list [ Tensor ] ) - > list [ Tensor ] :
if len ( x_list ) == 0 :
return [ ]
return super ( ) . forward ( torch . cat ( x_list ) ) . split ( [ * map ( len , x_list ) ] )
2024-04-17 02:04:48 +00:00
"""
2023-08-02 21:53:35 +00:00
2024-04-29 23:24:05 +00:00
# Deprecated implementation
2023-09-13 18:19:11 +00:00
class MultiEmbedding ( nn . Module ) :
2023-09-08 06:03:24 +00:00
def __init__ ( self , max_n_levels , n_tokens , token_dim , monolithic = False ) :
2023-09-16 05:26:13 +00:00
super ( ) . __init__ ( )
2023-09-08 20:36:26 +00:00
self . monolithic = monolithic
2023-08-02 21:53:35 +00:00
self . max_n_levels = max_n_levels
self . n_tokens = n_tokens
2023-09-08 20:36:26 +00:00
self . weight = nn . Parameter ( torch . randn ( max_n_levels , n_tokens , token_dim ) )
2023-08-02 21:53:35 +00:00
2024-08-04 01:23:36 +00:00
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resps_emb)
2023-09-07 22:08:38 +00:00
# I imagine this is an oversight in the NAR.
2024-06-08 01:46:22 +00:00
def forward ( self , x_list : list [ Tensor ] , quant_level : int | list [ int ] | Tensor | None = None ) - > list [ Tensor ] :
2023-08-02 21:53:35 +00:00
if len ( x_list ) == 0 :
return [ ]
2023-09-08 20:36:26 +00:00
# this "strategy" will reserve the weight[0] for te AR and weight[1:] for the NAR
2024-08-04 01:23:36 +00:00
# the NAR cannot share RVQ-bin level 0 with the AR for the resps_emb
2023-09-08 06:03:24 +00:00
if self . monolithic :
2024-06-08 01:46:22 +00:00
w = self . weight [ : 1 ] if quant_level is None or quant_level == 0 else self . weight [ 1 : ]
2023-09-08 06:03:24 +00:00
else :
w = self . weight
2023-08-02 21:53:35 +00:00
2023-09-11 19:13:42 +00:00
padded_x_list = [ ]
2023-08-02 21:53:35 +00:00
2023-09-08 20:36:26 +00:00
for i , xi in enumerate ( x_list ) :
2023-08-02 21:53:35 +00:00
xi = F . one_hot ( xi . to ( torch . int64 ) , num_classes = self . n_tokens ) # t l' k
2023-09-11 19:13:42 +00:00
wi = w . shape [ 0 ] - xi . shape [ 1 ]
xi = F . pad ( xi , ( 0 , 0 , 0 , wi ) ) # t l k
2023-08-02 21:53:35 +00:00
padded_x_list . append ( xi . to ( w ) )
2023-09-11 19:13:42 +00:00
x = torch . cat ( padded_x_list ) # n l k
x = einsum ( " l k d, n l k -> n d " , w , x )
2023-08-02 21:53:35 +00:00
x_list = x . split ( [ * map ( len , x_list ) ] )
2023-09-07 21:48:02 +00:00
2023-09-11 19:13:42 +00:00
return x_list
2023-09-07 21:48:02 +00:00
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
2024-06-09 01:30:15 +00:00
# _Old, to preserve compat with previous models.
2024-06-06 23:52:41 +00:00
class AudioEmbedding_Old ( nn . Module ) :
2024-04-29 23:24:05 +00:00
def __init__ (
self ,
l_tokens : int , # list of number of tokens (needed because AR resps includes stop token)
token_dim : int , # dimensionality of the embedding
levels : int | None = None , # number of RVQ-bins (I don't remember the specifics)
) :
2023-09-07 14:14:03 +00:00
super ( ) . __init__ ( )
2024-04-29 23:24:05 +00:00
# array of embeddings
2024-07-16 00:59:48 +00:00
# proms are [0, resp_levels]
2024-04-29 23:24:05 +00:00
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
2023-09-21 00:10:59 +00:00
self . embeddings = nn . ModuleList ( [ nn . Embedding ( n_tokens , token_dim ) for n_tokens in l_tokens ] )
2024-04-29 23:24:05 +00:00
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
2024-08-04 03:10:21 +00:00
self . weight = nn . ParameterList ( [ nn . Parameter ( torch . tensor ( [ 1 ] ) ) for i in range ( levels ) ] ) if levels is not None else None
2024-06-06 23:52:41 +00:00
2024-06-09 01:30:15 +00:00
def forward ( self , xi : Tensor , quant_level : Tensor | None = None ) - > Tensor :
2024-04-17 02:04:48 +00:00
# prom
2024-06-09 01:30:15 +00:00
if quant_level is None and xi . shape [ - 1 ] > 1 :
2024-06-07 00:41:26 +00:00
x = sum ( [ self . embeddings [ k ] ( xi [ : , k ] ) * ( self . weight [ k ] if self . weight is not None else 1 ) for k in range ( xi . shape [ - 1 ] ) ] )
2024-06-09 01:30:15 +00:00
# prom / AR resp
elif quant_level is None or quant_level == 0 :
2024-06-08 20:42:02 +00:00
x = self . embeddings [ 0 ] ( xi if xi . dim ( ) == 1 else xi [ : , 0 ] )
2024-04-17 02:04:48 +00:00
# NAR resp
else :
2024-06-07 00:41:26 +00:00
x = sum ( [ self . embeddings [ k + 1 ] ( xi [ : , k ] ) * ( self . weight [ k + 1 ] if self . weight is not None else 1 ) for k in range ( xi . shape [ - 1 ] ) ] )
2024-04-17 02:04:48 +00:00
return x
2023-09-07 21:48:02 +00:00
2024-06-09 01:30:15 +00:00
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
# Mostly to handle some oversights and errors during testing
2024-06-06 23:52:41 +00:00
class AudioEmbedding ( nn . Module ) :
def __init__ (
self ,
2024-06-12 03:28:59 +00:00
l_tokens : list [ int ] , # list of number of tokens (needed because AR resps includes stop token)
2024-06-06 23:52:41 +00:00
token_dim : int , # dimensionality of the embedding
2024-06-30 02:46:35 +00:00
sums : bool = True , # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
external_mode : str | None = None , # "exclusive" | "inclusive", whether to include the original audio backend's embeddings
2024-08-05 00:56:21 +00:00
capabilities : list [ str ] | None = None , # helper shit
2024-06-06 23:52:41 +00:00
) :
super ( ) . __init__ ( )
# array of embeddings
2024-07-16 00:59:48 +00:00
# proms are [0, resp_levels]
2024-06-06 23:52:41 +00:00
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
2024-07-30 00:15:07 +00:00
# + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
2024-06-06 23:52:41 +00:00
self . embeddings = nn . ModuleList ( [ nn . Embedding ( n_tokens , token_dim ) for n_tokens in l_tokens ] )
2024-06-30 02:46:35 +00:00
# further experimentation is needed to see if this actually is useful
2024-06-06 23:52:41 +00:00
self . sums = sums
2024-06-30 02:46:35 +00:00
self . external_mode = external_mode
2024-08-05 00:56:21 +00:00
self . capabilities = capabilities
2024-06-30 02:46:35 +00:00
# set initial weights to zero
if self . external_mode == " inclusive " :
for i , embedding in enumerate ( self . embeddings ) :
embedding . weight = torch . nn . Parameter ( torch . zeros ( embedding . weight . shape ) )
2024-08-04 01:23:36 +00:00
def external_embeddings ( self , input : Tensor , quant_level : int | None = None ) - > Tensor :
if quant_level is None :
quant_level = 0 if input . dim ( ) == 1 else input . shape [ - 1 ] - 1
2024-06-30 02:46:35 +00:00
# for AR, trim any stop tokens
has_stop_token = False
2024-06-30 03:14:35 +00:00
# this block apparently doesn't work
"""
2024-06-30 02:46:35 +00:00
if quant_level == 0 :
stop_token = self . embeddings [ 0 ] . weight . shape [ 0 ] - 1
stop_token_indices = ( input == stop_token ) . nonzero ( )
has_stop_token = len ( stop_token_indices ) > 0
if has_stop_token :
input = input [ : stop_token_indices . min ( ) . item ( ) ]
2024-06-30 03:14:35 +00:00
"""
has_stop_token = False
if quant_level == 0 :
stop_token = self . embeddings [ 0 ] . weight . shape [ 0 ] - 1
has_stop_token = input [ - 1 ] == stop_token
if has_stop_token :
input = input [ : - 1 ]
2024-06-30 02:46:35 +00:00
# get external embedding
2024-06-30 04:42:30 +00:00
embedding = encode_as_embedding ( input , quant_level , sums = self . sums ) . to ( device = input . device , dtype = self . embeddings [ quant_level ] . weight . dtype )
2024-06-30 02:46:35 +00:00
# resize if necessary (in case the external embeddings do not match our model dim)
embedding = ml . resize_weight ( embedding , self . embeddings [ quant_level ] . weight . shape [ - 1 ] , dim = - 1 , random = False )
# reintroduce stop token
if has_stop_token :
2024-08-04 03:10:21 +00:00
stop_token = self . internal_forward ( torch . tensor ( [ stop_token ] ) . to ( device = input . device , dtype = torch . int16 ) , 0 )
2024-06-30 02:46:35 +00:00
embedding = torch . concat ( [ embedding , stop_token ] )
return embedding
2024-09-08 03:13:49 +00:00
def internal_forward ( self , xi : Tensor , offset : int | None = None , quant_level : int | None = None , sums = None ) - > Tensor :
2024-08-05 00:56:21 +00:00
if offset is None :
# prom
if self . capabilities is None :
offset = 0
elif " nar " not in self . capabilities :
offset = 0
elif quant_level > 0 :
offset = 1
2024-09-08 03:13:49 +00:00
if sums is None :
sums = self . sums
2024-08-04 01:23:36 +00:00
if quant_level is None :
quant_level = 0 if xi . dim ( ) == 1 else xi . shape [ - 1 ] - 1
2024-06-07 00:41:26 +00:00
if self . sums and quant_level > 0 :
2024-06-06 23:52:41 +00:00
x = sum ( [ self . embeddings [ k + offset ] ( xi [ : , k ] ) for k in range ( quant_level ) ] )
else :
2024-06-08 20:42:02 +00:00
k = quant_level
2024-06-07 00:41:26 +00:00
x = self . embeddings [ k + offset ] ( xi if xi . dim ( ) == 1 else xi [ : , k ] )
2024-06-06 23:52:41 +00:00
return x
2024-09-08 03:13:49 +00:00
def forward ( self , xi : Tensor , offset : int | None = None , quant_level : int | None = None , sums = None ) - > Tensor :
x = self . internal_forward ( xi , offset = offset , quant_level = quant_level , sums = sums ) if self . external_mode != " exclusive " or xi . shape [ 0 ] == 0 else None
2024-06-30 00:46:11 +00:00
2024-06-30 02:46:35 +00:00
if self . external_mode and xi . shape [ 0 ] > 0 :
2024-08-04 01:23:36 +00:00
external_embeddings = self . external_embeddings ( xi , quant_level = quant_level )
2024-06-30 02:46:35 +00:00
if self . external_mode == " exclusive " :
return external_embeddings
x + = external_embeddings
2024-06-30 00:46:11 +00:00
2024-06-30 02:46:35 +00:00
return x
2024-06-30 00:46:11 +00:00
2024-11-09 04:46:26 +00:00
# time-step embedding
# for the NAR-len, since it probably most likely requires encoding the timestep
class TimeEmbedding(nn.Module):
    """Sinusoidal encoding of ``t`` followed by a d_model -> 4*d_model -> d_model SiLU MLP."""

    def __init__(
        self,
        d_model
    ):
        super().__init__()
        hidden = d_model * 4
        self.emb = SinusoidalEmbedding(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.SiLU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, t):
        # NOTE(review): accepted shape/range of t is whatever SinusoidalEmbedding expects -- confirm there
        return self.mlp(self.emb(t))
2024-06-12 03:28:59 +00:00
# per-level classification
# it might actually be "better" in the long run to only have one output head like a traditional LM, and just de-stitch it here instead of doing modulus math and whatever like the HF/experimental impl
class Classifiers(nn.Module):
    """
    Per-level output heads: ``proj[l]`` maps ``token_dim`` features to ``l_tokens[l]`` logits.

    ``forward`` applies the head chosen by ``levels[i]`` to batch item ``i``. Narrower outputs are
    right-padded with -inf up to the widest head so the batch can be stacked.
    """

    def __init__(
        self,
        l_tokens: list[int],  # list of number of tokens (needed because AR resps includes stop token)
        token_dim: int,  # dimensionality of the embedding
    ):
        super().__init__()
        self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens])

    def forward(self, xi: Tensor, levels: list[int]) -> Tensor:
        dtype, device = xi.dtype, xi.device
        logits = [self.proj[level](x) for x, level in zip(xi, levels)]

        # pad if needed
        # to-do: validate that this causes ZERO issues
        widest = max(out.shape[-1] for out in logits)
        padded = []
        for out in logits:
            missing = widest - out.shape[-1]
            if missing:
                filler = torch.full((out.shape[0], missing), -float("inf"), device=device, dtype=dtype)
                out = torch.cat([out, filler], dim=-1)
            padded.append(out)
        return torch.stack(padded)
2024-06-12 03:28:59 +00:00
class Metrics(nn.Module):
    """
    Per-level torchmetrics accuracy/precision, with one metric module per entry of ``l_tokens``.

    Calling the instance returns ``dict(acc=...)``: accuracy averaged over the batch items.
    """

    def __init__(
        self,
        l_tokens: int | list[int],
        top_k=10,
        average="micro",
        multidim_average="global",
        ignore_index=-100,
    ):
        super().__init__()
        # a bare int means a single level with that many classes (iterating an int would raise TypeError)
        if isinstance(l_tokens, int):
            l_tokens = [l_tokens]
        metric_kwargs = dict(
            top_k=top_k,
            average=average,
            multidim_average=multidim_average,
            ignore_index=ignore_index,
        )
        self.accuracy = nn.ModuleList([MulticlassAccuracy(n_tokens, **metric_kwargs) for n_tokens in l_tokens])
        self.precision = nn.ModuleList([MulticlassPrecision(n_tokens, **metric_kwargs) for n_tokens in l_tokens])

    def calc_accuracy(self, inputs, targets, quant_levels):
        """
        Mean accuracy over batch items.

        Each item's logits are trimmed to its level's class count, dropping padded columns.
        """
        total = sum(
            self.accuracy[level](logits[:, : self.accuracy[level].num_classes], target)
            for target, logits, level in zip(targets, inputs, quant_levels)
        )
        return total / len(inputs)

    def calc_precision(self, inputs, targets, quant_levels):
        """
        Mean precision over batch items.

        Each item's logits are trimmed to its level's class count, dropping padded columns.
        """
        total = sum(
            self.precision[level](logits[:, : self.precision[level].num_classes], target)
            for target, logits, level in zip(targets, inputs, quant_levels)
        )
        return total / len(inputs)

    def __call__(self, *args, **kwargs):
        return dict(
            acc=self.calc_accuracy(*args, **kwargs),
        )
2023-08-02 21:53:35 +00:00
class Base ( nn . Module ) :
2024-05-19 16:23:56 +00:00
def loss_factor(self, k):
    """
    Return the weight for loss term ``k``.

    Falls back to 1.0 when there is no config or the config has no entry for ``k``.
    """
    config = self.config
    if config is not None and k in config.loss_factors:
        return config.loss_factors[k]
    return 1.0
2024-05-19 16:23:56 +00:00
2024-08-04 00:51:00 +00:00
def _prune ( self , l : Tensor , stop = None ) :
if stop is None :
stop = self . stop_token
indices = ( l == stop ) . nonzero ( )
if len ( indices ) == 0 :
return l
return l [ : indices . min ( ) . item ( ) ]
2024-07-18 19:18:34 +00:00
# these probably need to live in an interleaved model, as pattern-ing is targeted for a sole AR model
"""
def codes_to_pattern ( self , codes ) :
# expand if not batched
if codes . dim ( ) == 2 :
codes = codes . unsqueeze ( 0 )
# [batch, timestep, rvq level] (B, T, K) => [batch, rvq level, timestep] (B, K, T)
codes = codes . permute ( 0 , 2 , 1 )
B , K , T = codes . shape
# map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
pattern = self . pattern_provider . get_pattern ( T )
sequence_codes , sequence_indexes , sequence_mask = pattern . build_pattern_sequence (
codes . contiguous ( ) , self . stop_token , keep_only_valid_steps = False ,
)
# (B, K, T) => (B, T, K)
return sequence_codes . permute ( 0 , 2 , 1 )
def logits_from_pattern ( self , logits , pattern ) :
logits = logits . permute ( 0 , 3 , 1 , 2 ) # [B, card, K, S]
logits , logits_indexes , logits_mask = pattern . revert_pattern_logits (
logits , float ( ' nan ' ) , keep_only_valid_steps = False
)
logits = logits . permute ( 0 , 2 , 3 , 1 ) # [B, K, T, card]
logits_mask = logits_mask [ None , : , : ] . expand ( B , - 1 , - 1 ) # [K, T] -> [B, K, T]
return logits , logits_mask
"""
2023-08-02 21:53:35 +00:00
def __init__ (
self ,
2024-06-06 00:50:06 +00:00
n_text_tokens : int = 256 ,
n_audio_tokens : int = 1024 ,
2023-08-02 21:53:35 +00:00
d_model : int = 512 ,
n_heads : int = 8 ,
n_layers : int = 12 ,
p_dropout : float = 0.1 ,
2023-09-07 00:33:39 +00:00
2024-04-09 01:14:51 +00:00
n_experts : int = 1 ,
l_padding : int = 0 ,
2023-12-21 00:45:58 +00:00
2024-02-01 03:48:36 +00:00
training = True ,
2024-08-27 00:33:51 +00:00
attention = None ,
2023-09-04 03:46:08 +00:00
config = None ,
2023-08-02 21:53:35 +00:00
) :
super ( ) . __init__ ( )
2024-02-01 03:48:36 +00:00
self . training = training
2024-06-06 14:48:43 +00:00
self . config = config
2023-09-04 03:46:08 +00:00
2024-06-06 00:50:06 +00:00
self . n_text_tokens = n_text_tokens
self . n_audio_tokens = n_audio_tokens
2023-08-02 21:53:35 +00:00
self . d_model = d_model
self . n_heads = n_heads
self . n_layers = n_layers
2023-12-21 00:45:58 +00:00
self . n_experts = n_experts
2024-04-09 01:14:51 +00:00
self . l_padding = l_padding
2023-08-02 21:53:35 +00:00
2024-08-04 00:51:00 +00:00
self . ignore_index = - 100
self . n_resp_levels = self . config . resp_levels if self . config else n_resp_levels
self . n_max_levels = self . config . max_levels if self . config else n_resp_levels
self . capabilities = self . config . capabilities if self . config else [ " ar " , " nar " ]
self . gradient_checkpointing = self . config . gradient_checkpointing if self . config is not None else True
self . stop_token = self . n_audio_tokens # id 1024
self . causal = " ar " in self . capabilities or " len " in self . capabilities
self . version = self . config . version if self . config is not None else 5
self . causal_size = self . config . experimental . causal_size if self . config is not None else ( 1 if " ar " in self . capabilities else 0 )
2024-06-30 16:00:12 +00:00
2024-08-04 00:51:00 +00:00
self . arch_type = self . config . arch_type if self . config is not None else " llama "
2024-06-12 04:59:28 +00:00
2024-06-15 00:42:17 +00:00
# check if requested arch is unavailable
if self . arch_type in ERROR_ARCHES :
raise ERROR_ARCHES [ self . arch_type ]
2024-08-04 00:51:00 +00:00
2024-08-27 00:33:51 +00:00
if not attention :
attention = self . config . attention if self . config is not None else " auto "
attention_backend = attention
2024-08-03 13:40:39 +00:00
audio_embedding_sums = self . config . experimental . audio_embedding_sums if self . config is not None else False
split_classifiers = self . config . experimental . split_classifiers if self . config is not None else False
tie_classifier_to_embedding = self . config . experimental . tie_classifier_to_embedding if self . config is not None else False
audio_embedding_mode = self . config . experimental . audio_embedding_mode if self . config is not None else " "
unified_position_ids = self . config . experimental . unified_position_ids if self . config is not None else True
2024-08-05 00:56:21 +00:00
interleave = self . config . experimental . interleave if self . config is not None else False
2024-10-31 18:24:48 +00:00
2024-10-31 01:05:45 +00:00
layerskip = self . config . experimental . layerskip if self . config is not None else False
2024-10-31 18:24:48 +00:00
layerskip_r = self . config . experimental . layerskip_r if self . config is not None else 2
layerskip_p_max = self . config . experimental . layerskip_p_max if self . config is not None else 0.1
layerskip_e_scale = self . config . experimental . layerskip_e_scale if self . config is not None else 0.1
2024-06-15 00:42:17 +00:00
2024-08-04 00:51:00 +00:00
n_tasks = self . config . tasks if self . config is not None else 8
n_langs = self . config . langs if self . config is not None else 2
n_tones = self . config . tones if self . config is not None else 1
2024-08-05 00:56:21 +00:00
# pure AR
if " nar " not in self . capabilities :
n_resp_tokens = n_audio_tokens + 1
l_tokens = [ n_resp_tokens ] * self . n_resp_levels
2024-11-07 17:32:11 +00:00
# AR+NAR model / NAR-len model
else :
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self . causal_size > 0 else 0 )
l_tokens = [ n_resp_tokens ] + [ n_resp_tokens - 1 ] * ( self . n_resp_levels - 1 )
2024-08-03 03:25:49 +00:00
2024-07-17 00:52:41 +00:00
self . unified_position_ids = unified_position_ids
2024-08-05 00:56:21 +00:00
self . interleave = interleave
2024-10-31 01:05:45 +00:00
self . layerskip = layerskip
2024-11-02 16:49:05 +00:00
self . special_tasks = [ " len " , " stt " ]
2024-11-10 00:12:54 +00:00
self . inject_timestep_embedding = False # results in bad output
2024-06-12 03:28:59 +00:00
2024-06-06 00:50:06 +00:00
self . text_emb = Embedding ( n_text_tokens , d_model )
2023-10-22 14:01:47 +00:00
self . langs_emb = None
2024-04-16 00:54:32 +00:00
self . tones_emb = None
2023-10-22 14:01:47 +00:00
self . tasks_emb = None
2024-06-08 20:42:02 +00:00
self . rvq_l_emb = None
self . len_emb = None
2024-07-25 00:35:17 +00:00
# it would be nicer for these to be a token or live inside an embedding
self . sep = nn . Parameter ( torch . randn ( d_model ) )
2024-11-07 15:10:18 +00:00
self . dropout_token = nn . Parameter ( torch . randn ( d_model ) )
2023-09-07 00:33:39 +00:00
2023-09-11 19:13:42 +00:00
if self . version == 1 : # legacy
2024-08-04 00:51:00 +00:00
n_audio_tokens + = ( n_tasks - 1 ) # old models have the task tokens in the prom
2024-07-16 00:59:48 +00:00
self . proms_emb = MultiEmbedding ( self . n_resp_levels , n_audio_tokens , d_model )
2023-09-11 19:13:42 +00:00
self . resps_emb = MultiEmbedding ( self . n_resp_levels , n_resp_tokens , d_model , monolithic = self . monolithic )
2024-06-06 23:52:41 +00:00
elif self . version < 5 :
2023-09-21 00:10:59 +00:00
# [1024] * 8
2024-06-06 23:52:41 +00:00
self . proms_emb = AudioEmbedding_Old (
2024-07-16 00:59:48 +00:00
[ n_audio_tokens ] * self . n_resp_levels , d_model ,
levels = self . n_resp_levels if self . version > 3 else None ,
2024-04-29 23:24:05 +00:00
)
2024-06-06 00:50:06 +00:00
# [1024 + STOP] + [1024] * 8
2024-06-06 23:52:41 +00:00
self . resps_emb = AudioEmbedding_Old (
2024-06-12 04:59:28 +00:00
l_tokens , d_model ,
2024-04-29 23:24:05 +00:00
levels = self . n_resp_levels if self . version > 3 else None ,
2024-06-06 23:52:41 +00:00
)
else :
self . proms_emb = AudioEmbedding (
2024-07-16 00:59:48 +00:00
[ n_audio_tokens ] * self . n_resp_levels , d_model ,
2024-06-12 03:28:59 +00:00
sums = audio_embedding_sums ,
2024-06-30 16:00:12 +00:00
external_mode = audio_embedding_mode ,
2024-08-05 00:56:21 +00:00
capabilities = None ,
2024-06-06 23:52:41 +00:00
)
self . resps_emb = AudioEmbedding (
2024-06-12 04:59:28 +00:00
l_tokens , d_model ,
2024-06-12 03:28:59 +00:00
sums = audio_embedding_sums ,
2024-06-30 16:00:12 +00:00
external_mode = audio_embedding_mode ,
2024-08-05 00:56:21 +00:00
capabilities = self . capabilities ,
2024-04-29 23:24:05 +00:00
)
2023-09-21 00:10:59 +00:00
2023-10-12 01:38:40 +00:00
if self . version > = 3 :
2024-08-04 00:51:00 +00:00
self . langs_emb = Embedding ( n_langs , d_model ) if n_langs > 0 else None
self . tasks_emb = Embedding ( n_tasks , d_model ) if n_tasks > 0 else None
2024-06-06 00:50:06 +00:00
# never actually got added... I kept forgetting to classify all my audio for speaker's tone
2024-04-16 00:54:32 +00:00
if self . version > = 4 :
2024-08-04 00:51:00 +00:00
self . tones_emb = Embedding ( n_tones , d_model ) if n_tones > 0 else None
2023-08-02 21:53:35 +00:00
2024-06-06 00:50:06 +00:00
# mamba requires this if a model does both AR and NAR tasks
# this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
# this ***might*** let me also unify the proms_emb and resps_embedding
2024-06-05 04:23:31 +00:00
if self . version > = 5 :
2024-11-07 15:10:18 +00:00
# "len" RVQ level-0 gets an additional token
2024-11-10 18:19:48 +00:00
self . rvq_l_emb = Embedding ( self . n_resp_levels , d_model )
2024-06-08 20:42:02 +00:00
# experimental NAR-only mode
2024-11-10 18:19:48 +00:00
self . len_emb = Embedding ( 11 , d_model )
self . time_emb = TimeEmbedding ( d_model )
2024-06-05 04:23:31 +00:00
2024-08-04 00:51:00 +00:00
if attention_backend == " auto " :
2024-11-05 00:00:33 +00:00
attention_backend = " sdpa "
"""
2024-08-09 00:38:55 +00:00
if AVAILABLE_ATTENTIONS :
attention_backend = AVAILABLE_ATTENTIONS [ 0 ]
2024-08-07 01:42:39 +00:00
else :
2024-10-12 16:27:55 +00:00
attention_backend = " default "
2024-11-05 00:00:33 +00:00
"""
2024-07-30 00:53:37 +00:00
2024-08-04 00:51:00 +00:00
hf_attention = attention_backend
2024-08-27 00:13:34 +00:00
HF_ATTENTIONS = [ " eager " , " sdpa " , " flash_attention_2 " ]
2024-05-11 22:14:05 +00:00
2024-08-27 00:13:34 +00:00
if attention_backend not in HF_ATTENTIONS :
2024-05-11 22:14:05 +00:00
hf_attention = None
2024-08-04 00:51:00 +00:00
if attention_backend not in AVAILABLE_ATTENTIONS :
raise ValueError ( f " Requesting attention ` { attention_backend } ` but is not available. Currently available: { AVAILABLE_ATTENTIONS } " )
2024-05-10 01:28:20 +00:00
2024-11-07 15:10:18 +00:00
# override any requested padding size
2024-08-27 00:13:34 +00:00
if attention_backend == " flash_attn_v100 " :
self . l_padding = 32
2024-10-13 16:02:24 +00:00
elif attention_backend == " fused_attn " :
2024-08-27 00:13:34 +00:00
self . l_padding = 128
2023-08-02 21:53:35 +00:00
if self . arch_type == " transformer " :
self . sin_emb = SinusoidalEmbedding ( d_model )
self . blocks = nn . ModuleList ( [ TransformerBlock (
d_model = d_model ,
n_heads = n_heads ,
2024-02-01 03:48:36 +00:00
p_dropout = p_dropout if training else 0.0 ,
2023-08-19 01:58:07 +00:00
causal = self . causal ,
2024-06-30 16:00:12 +00:00
norm_type = " ln " , # adaln
2023-08-02 21:53:35 +00:00
n_levels = self . n_resp_levels ,
) for _ in range ( n_layers ) ] )
2024-05-10 01:28:20 +00:00
elif self . arch_type in [ " mistral " , " mixtral " ] :
2024-02-01 03:48:36 +00:00
if n_experts < = 1 :
self . model = MistralModel ( MistralConfig (
vocab_size = n_resp_tokens ,
hidden_size = d_model ,
2024-06-09 22:11:38 +00:00
max_position_embeddings = 75 * 60 * 5 , # max-length of 60 seconds
2024-02-01 03:48:36 +00:00
intermediate_size = d_model * 4 ,
num_hidden_layers = n_layers ,
num_attention_heads = n_heads ,
attention_dropout = p_dropout if training else 0.0 ,
2024-06-30 15:37:33 +00:00
num_key_value_heads = self . config . experimental . kv_heads if self . config is not None and self . config . experimental . kv_heads > 0 else n_heads ,
2024-02-01 03:48:36 +00:00
hidden_act = " gelu " ,
is_encoder_decoder = False ,
is_decoder = True ,
2024-05-11 22:14:05 +00:00
attn_implementation = hf_attention ,
2024-06-04 02:28:49 +00:00
#gradient_checkpointing=self.gradient_checkpointing,
2024-02-01 03:48:36 +00:00
) )
else :
self . model = MixtralModel ( MixtralConfig (
vocab_size = n_resp_tokens ,
hidden_size = d_model ,
2024-06-09 22:11:38 +00:00
max_position_embeddings = 75 * 60 * 5 , # max-length of 60 seconds
2024-02-01 03:48:36 +00:00
intermediate_size = d_model * 4 ,
num_hidden_layers = n_layers ,
num_attention_heads = n_heads ,
attention_dropout = p_dropout if training else 0.0 ,
2024-06-30 15:37:33 +00:00
num_key_value_heads = self . config . experimental . kv_heads if self . config is not None and self . config . experimental . kv_heads > 0 else n_heads ,
2024-02-01 03:48:36 +00:00
sliding_window = 75 * 12 , # 12 second context window
output_router_logits = training ,
hidden_act = " gelu " ,
is_encoder_decoder = False ,
is_decoder = True ,
num_local_experts = n_experts ,
num_experts_per_tok = min ( 2 , n_experts ) ,
2024-05-11 22:14:05 +00:00
attn_implementation = hf_attention ,
2024-06-04 02:28:49 +00:00
#gradient_checkpointing=self.gradient_checkpointing,
2024-02-01 03:48:36 +00:00
) )
2024-08-27 00:13:34 +00:00
if attention_backend not in HF_ATTENTIONS :
2024-08-05 03:03:22 +00:00
self . model = ml . replace_attention ( self . model , klass = MixtralAttention_Adapted , target = MixtralAttention , mode = attention_backend )
2024-05-10 04:25:44 +00:00
2024-06-04 02:28:49 +00:00
if self . gradient_checkpointing and not self . model . gradient_checkpointing :
2024-05-11 21:47:19 +00:00
self . model . gradient_checkpointing_enable ( gradient_checkpointing_kwargs = dict (
use_reentrant = False
) )
2023-12-23 01:27:36 +00:00
elif self . arch_type == " llama " :
2024-11-10 18:19:48 +00:00
LlamaClass = LlamaModel_Adapted # if (self.layerskip or "len" in self.capabilities) else LlamaModel
2024-11-10 00:04:59 +00:00
2023-12-23 01:27:36 +00:00
if n_experts < = 1 :
2024-10-31 01:05:45 +00:00
self . model = LlamaClass ( LlamaConfig (
2023-12-23 01:27:36 +00:00
vocab_size = n_resp_tokens ,
hidden_size = d_model ,
2024-06-09 22:11:38 +00:00
max_position_embeddings = 75 * 60 * 5 , # max-length of 60 seconds
2023-12-23 01:27:36 +00:00
intermediate_size = d_model * 4 ,
num_hidden_layers = n_layers ,
num_attention_heads = n_heads ,
2024-02-01 03:48:36 +00:00
attention_dropout = p_dropout if training else 0.0 ,
2023-12-23 01:27:36 +00:00
num_key_value_heads = n_heads ,
2024-02-01 03:48:36 +00:00
sliding_window = 75 * 12 , # 12 second context window
2023-12-23 01:27:36 +00:00
hidden_act = " gelu " ,
is_encoder_decoder = False ,
is_decoder = True ,
2024-05-11 22:14:05 +00:00
attn_implementation = hf_attention ,
2024-06-04 02:28:49 +00:00
#gradient_checkpointing=self.gradient_checkpointing,
2023-12-23 01:27:36 +00:00
) )
2024-11-10 00:04:59 +00:00
# replace with desired attention
2024-08-27 00:13:34 +00:00
if attention_backend not in HF_ATTENTIONS :
2024-08-05 03:03:22 +00:00
self . model = ml . replace_attention ( self . model , klass = LlamaAttention_Adapted , target = LlamaAttention , mode = attention_backend )
2023-12-23 01:27:36 +00:00
else :
self . model = MixtralModel ( MixtralConfig (
vocab_size = n_resp_tokens ,
hidden_size = d_model ,
2024-06-09 22:11:38 +00:00
max_position_embeddings = 75 * 60 * 5 , # max-length of 60 seconds
2023-12-23 01:27:36 +00:00
intermediate_size = d_model * 4 ,
num_hidden_layers = n_layers ,
num_attention_heads = n_heads ,
2024-02-01 03:48:36 +00:00
attention_dropout = p_dropout if training else 0.0 ,
2023-12-23 01:27:36 +00:00
num_key_value_heads = n_heads ,
2024-02-01 03:48:36 +00:00
sliding_window = 75 * 12 , # 12 second context window
output_router_logits = training ,
2023-12-23 01:27:36 +00:00
hidden_act = " gelu " ,
is_encoder_decoder = False ,
is_decoder = True ,
num_local_experts = n_experts ,
num_experts_per_tok = min ( 2 , n_experts ) ,
2024-05-11 22:14:05 +00:00
attn_implementation = hf_attention ,
2024-06-04 02:28:49 +00:00
#gradient_checkpointing=self.gradient_checkpointing,
2023-12-23 01:27:36 +00:00
) )
2024-08-27 00:13:34 +00:00
if attention_backend not in HF_ATTENTIONS :
2024-08-05 03:03:22 +00:00
self . model = ml . replace_attention ( self . model , klass = MixtralAttention_Adapted , target = MixtralAttention , mode = attention_backend )
2024-05-10 04:15:52 +00:00
2024-10-31 18:24:48 +00:00
if self . layerskip :
self . model . layer_dropout_p = layerskip_p_max
self . model . early_exit_scale = layerskip_e_scale
self . model . early_exit_r = layerskip_r
2024-06-04 02:28:49 +00:00
if self . gradient_checkpointing and not self . model . gradient_checkpointing :
2024-05-11 21:47:19 +00:00
self . model . gradient_checkpointing_enable ( gradient_checkpointing_kwargs = dict (
use_reentrant = False
) )
2023-08-02 21:53:35 +00:00
elif self . arch_type == " retnet " :
2023-12-26 03:20:32 +00:00
kwargs = dict (
2023-12-23 01:27:36 +00:00
vocab_size = n_resp_tokens ,
2023-08-02 21:53:35 +00:00
decoder_embed_dim = d_model ,
2023-10-05 21:39:46 +00:00
decoder_value_embed_dim = d_model * 2 ,
2023-08-02 21:53:35 +00:00
decoder_retention_heads = n_heads ,
decoder_ffn_embed_dim = d_model * 4 ,
decoder_layers = n_layers ,
2024-02-01 03:48:36 +00:00
dropout = p_dropout if training else 0.0 ,
2024-06-04 02:28:49 +00:00
checkpoint_activations = self . gradient_checkpointing ,
2023-10-05 21:39:46 +00:00
activation_fn = " gelu " ,
2024-05-10 01:28:20 +00:00
use_layernorm = self . version < 3 ,
use_biases = self . version < 3 ,
use_glu = self . version > = 3 ,
2023-08-02 21:53:35 +00:00
2024-06-07 01:51:31 +00:00
chunkwise_recurrent = self . causal and self . causal_size > 0 ,
recurrent_chunkwise_size = self . causal_size if self . causal else 0 ,
2023-08-02 21:53:35 +00:00
no_output_layer = True ,
decoder_normalize_before = True ,
2023-09-21 00:10:59 +00:00
2024-06-07 01:51:31 +00:00
rotary_embedding_base = 10000
2023-12-26 03:20:32 +00:00
)
if n_experts > 1 :
kwargs . update ( dict (
use_xmoe = True ,
moe_freq = 1 ,
moe_expert_count = n_experts ,
moe_gating_use_fp32 = False ,
) )
2023-12-21 00:45:58 +00:00
2023-12-26 03:20:32 +00:00
self . model = RetNetDecoder ( RetNetConfig ( * * kwargs ) )
2024-04-09 01:14:51 +00:00
elif self . arch_type == " retnet-hf " :
kwargs = dict (
vocab_size = n_resp_tokens ,
decoder_embed_dim = d_model ,
decoder_value_embed_dim = d_model * 2 ,
decoder_retention_heads = n_heads ,
decoder_ffn_embed_dim = d_model * 4 ,
decoder_layers = n_layers ,
dropout = p_dropout if training else 0.0 ,
2024-06-04 02:28:49 +00:00
checkpoint_activations = self . gradient_checkpointing ,
2024-04-09 01:14:51 +00:00
activation_fn = " gelu " ,
use_glu = False , # self.version >= 3,
2024-06-07 01:51:31 +00:00
recurrent_chunk_size = self . causal_size if self . causal else 0 ,
2024-04-09 01:14:51 +00:00
decoder_normalize_before = True ,
deepnorm = False ,
subln = True ,
)
self . model = RetNetDecoder_HF ( RetNetConfig_HF ( * * kwargs ) )
2024-05-12 13:22:39 +00:00
2024-06-04 02:28:49 +00:00
if self . gradient_checkpointing and not self . model . gradient_checkpointing :
2024-05-12 13:22:39 +00:00
self . model . gradient_checkpointing_enable ( gradient_checkpointing_kwargs = dict (
use_reentrant = False
) )
2024-03-01 02:29:17 +00:00
elif self . arch_type == " bitnet " :
self . model = BitNetTransformer (
num_tokens = n_resp_tokens ,
dim = d_model ,
depth = n_layers ,
heads = n_heads ,
ff_mult = 4 ,
2024-06-04 02:28:49 +00:00
gradient_checkpointing = self . gradient_checkpointing ,
2024-03-01 02:29:17 +00:00
)
2024-06-05 03:41:22 +00:00
elif self . arch_type in [ " mamba " , " mamba2 " ] :
self . model = MambaMixelModel (
vocab_size = n_resp_tokens ,
d_model = d_model ,
2024-08-10 02:15:01 +00:00
n_layer = n_layers * 2 ,
d_intermediate = 0 , #d_model*2,
ssm_cfg = { " layer " : " Mamba2 " , " use_mem_eff_path " : True } if self . arch_type == " mamba2 " else { } ,
2024-06-05 03:41:22 +00:00
rms_norm = True ,
fused_add_norm = True ,
2024-08-10 02:15:01 +00:00
residual_in_fp32 = True ,
2024-06-05 03:41:22 +00:00
#attn_layer_idx=attn_layer_idx,
#attn_cfg=attn_cfg,
#initializer_cfg=initializer_cfg,
)
self . model . gradient_checkpointing = self . gradient_checkpointing
2024-06-15 00:42:17 +00:00
elif self . arch_type in [ " mamba2-hf " ] :
self . model = Mamba2Model_HF ( Mamba2Config_HF (
vocab_size = n_resp_tokens ,
hidden_size = d_model ,
max_position_embeddings = 75 * 60 * 5 , # max-length of 60 seconds
expand = 4 ,
num_hidden_layers = n_layers ,
is_encoder_decoder = False ,
is_decoder = True ,
use_triton_kernels = False , # the entire reason is to NOT use triton (because V100s hate it)
2024-08-10 02:15:01 +00:00
residual_in_fp32 = True , # breaks for AMP inference
2024-06-15 00:42:17 +00:00
) )
if self . gradient_checkpointing and not self . model . gradient_checkpointing :
self . model . gradient_checkpointing_enable ( gradient_checkpointing_kwargs = dict (
use_reentrant = False
) )
2024-06-12 03:28:59 +00:00
elif self . arch_type == " mmfreelm " :
self . model = HGRNBitModel ( HGRNBitConfig (
vocab_size = n_resp_tokens ,
hidden_size = d_model ,
max_position_embeddings = 75 * 60 * 5 , # max-length of 60 seconds
intermediate_size = d_model * 4 ,
num_hidden_layers = n_layers ,
num_heads = n_heads ,
#hidden_act="gelu",
#is_encoder_decoder=False,
#is_decoder=True,
attn_mode = hf_attention ,
#gradient_checkpointing=self.gradient_checkpointing,
) )
if self . gradient_checkpointing and not self . model . gradient_checkpointing :
self . model . gradient_checkpointing_enable ( gradient_checkpointing_kwargs = dict (
use_reentrant = False
) )
2024-03-01 02:29:17 +00:00
else :
raise RuntimeError ( f ' Unknown arch specified: { self . arch_type } ' )
2023-08-04 01:26:36 +00:00
2024-06-15 00:42:17 +00:00
if hasattr ( self . model , " embeddings " ) :
del self . model . embeddings
2024-05-10 01:28:20 +00:00
2024-06-12 03:28:59 +00:00
if not split_classifiers :
self . classifier = nn . Linear ( d_model , n_resp_tokens )
self . classifiers = None
self . accuracy_metric = MulticlassAccuracy (
n_resp_tokens ,
top_k = 10 ,
average = " micro " ,
multidim_average = " global " ,
ignore_index = self . ignore_index ,
)
2023-08-02 21:53:35 +00:00
2024-06-12 03:28:59 +00:00
self . precision_metric = MulticlassPrecision (
n_resp_tokens ,
top_k = 10 ,
average = " micro " ,
multidim_average = " global " ,
ignore_index = self . ignore_index ,
)
self . metrics = None
else :
self . classifier = None
2024-09-06 01:43:20 +00:00
self . classifiers = Classifiers ( l_tokens + [ n_text_tokens ] , d_model )
2024-06-12 03:28:59 +00:00
self . accuracy_metric = None
self . precision_metric = None
2024-09-06 01:43:20 +00:00
self . metrics = Metrics ( l_tokens + [ n_text_tokens ] )
2023-08-02 21:53:35 +00:00
2024-07-19 20:33:31 +00:00
"""
if tie_classifier_to_embedding :
for i , proj in enumerate ( self . classifiers . proj ) :
self . classifiers . proj [ i ] . weight = self . resps_emb . embeddings [ i ] . weight
"""
2023-08-02 21:53:35 +00:00
2024-04-16 00:54:32 +00:00
def _forward(
    self,
    inputs,
    mask=None,
    position_ids=None,

    state=None,

    layer_skip_lambda=None,
    timesteps=None,

    output_attentions=False,
    output_hidden_states=False,

    quant_levels=None,
):
    """
    Run the configured backbone over already-embedded inputs.

    Args:
        inputs: input embeddings handed to the backbone (as `inputs_embeds` / hidden states).
        mask: padding/attention mask; forwarded to the "transformer", "retnet-hf" and "mmfreelm" backbones.
            It is currently not forwarded to the HF llama/mistral/mixtral backbones.
        position_ids: explicit position ids for the HF-derived backbones.
        state: recurrent / KV state; the updated state is only returned when one was passed in.
        layer_skip_lambda: forwarded to the HF backbone when LayerSkip is enabled.
        timesteps: forwarded to the HF backbone when the model has the "len" capability.
        output_attentions: request per-layer attentions (HF backbones only).
        output_hidden_states: request per-layer hidden states (HF backbones only).
        quant_levels: per-sample RVQ levels for the "transformer" backbone's AdaLN; defaults to all zeros.

    Returns:
        Logits(x, state, aux_loss, attentions, hidden_states, None), where `x` is the backbone's
        final hidden state and `hidden_states` (when requested) holds one entry per layer.
    """
    x = inputs
    m = mask

    aux_loss = None
    attentions = None
    hidden_states = None

    # HF transformer derived model
    if self.arch_type in ["llama", "mistral", "mixtral"]:
        kwargs = dict(
            inputs_embeds=x,
            past_key_values=state,
            position_ids=position_ids,
            use_cache=False,  # to-do: figure out why KV caching doesn't work
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        if self.n_experts > 1 and self.training:
            kwargs["output_router_logits"] = True

        if self.layerskip and layer_skip_lambda is not None:
            kwargs["layer_skip_lambda"] = layer_skip_lambda

        if "len" in self.capabilities and timesteps is not None:
            kwargs["timesteps"] = timesteps

        output = self.model(**kwargs)
        x = output["last_hidden_state"]

        if state is not None:
            state = output["past_key_values"]

        if output_attentions:
            attentions = output["attentions"]

        if output_hidden_states:
            hidden_states = output["hidden_states"]

        if self.n_experts > 1 and self.training:
            # the base (non-CausalLM) HF MoE model exposes the gate logits as `router_logits`;
            # it has no `aux_loss` field, so reading that key raised a KeyError
            router_logits = output["router_logits"]
            aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func(router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok)
    elif self.arch_type == "transformer":
        # ensures we specify a quant_level for the transformer implementation's AdaLN
        # (batch size and device come from the inputs; they were previously undefined names here)
        if quant_levels is None:
            l = torch.zeros((x.shape[0],), dtype=torch.int32)
        else:
            l = torch.as_tensor(quant_levels)
        l = l.to(x.device)
        # inject position information
        x = self.sin_emb.add_pe(x)
        # pass our inputs through the transformer
        for block in self.blocks:
            x = block(x, m, l)
    elif self.arch_type == "retnet":
        # pass our inputs through the RetNet
        x, extra = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
        if extra is not None and "l_aux" in extra and self.n_experts > 1:
            aux_loss = torch.sum(torch.stack([t for t in extra["l_aux"] if t is not None])) * 0.001
    elif self.arch_type == "retnet-hf":
        first = state is None or len(state) == 0
        kwargs = dict(
            attention_mask=m,
            # after the first step only the newest position is fed, in recurrent mode
            inputs_embeds=x if first else x[:, -1, :].unsqueeze(1),
            past_key_values=None if first else state,
            use_cache=True,
            forward_impl='parallel' if first else 'recurrent',
            return_dict=True,
        )
        out = self.model(**kwargs)
        x = out.last_hidden_state
        if state is not None:
            state = out.past_key_values
    elif self.arch_type in ["mamba", "mamba2"]:
        x = self.model(hidden_states=x)
    elif self.arch_type == "mamba2-hf":
        kwargs = dict(
            inputs_embeds=x,
            cache_params=state,
            return_dict=True,
        )
        out = self.model(**kwargs)
        x = out.last_hidden_state
        if state is not None:
            state = out.cache_params
    elif self.arch_type == "bitnet":
        x = self.model(x)
    elif self.arch_type == "mmfreelm":
        x = self.model(
            attention_mask=m,
            inputs_embeds=x,
        )
        x = x[0]

    # process it into a format that I like
    # (only HF backbones populate hidden_states; others leave it None)
    if output_hidden_states and hidden_states is not None:
        # hidden_states is actually layers + 1, as hidden_states[0] == embedding; drop it so index i == layer i
        layer_states = list(hidden_states[1:])
        # apply normalization to these states (to-do: check if this matters)
        # but skip the last state, as it already is normalized
        hidden_states = [
            x if i == self.n_layers - 1 else self.model.norm(layer_state)
            for i, layer_state in enumerate(layer_states)
        ]

    return Logits(x, state, aux_loss, attentions, hidden_states, None)
2024-04-16 00:54:32 +00:00
2024-07-18 19:18:34 +00:00
# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
2024-04-17 02:04:48 +00:00
def inputs(
    self,
    text_list: list[Tensor],
    proms_list: list[Tensor],
    resps_list: list[Tensor],

    lang_list: list[Tensor] | None = None,
    tone_list: list[Tensor] | None = None,
    len_list: list[Tensor] | None = None,
    task_list: list[str] | None = None,
    time_list: list[Tensor] | None = None,

    quant_levels: int | list[int] | Tensor | None = None
):
    """
    Parse the separate per-sample lists into an ordered list of (name, value) tuples per sample.

    The tuple order guides input sequence creation:
      * tasks whose `<task>` token is in the task symmap and that are not special tasks:
        text, lang, quant_level, prom, tone, resp (+ a trailing "dropout_mask" when training with a timestep)
      * "len": text, lang, quant_level, prom, tone, len
      * "stt": resp, lang, quant_level, text
    Optional entries are only emitted when the data is provided and the model is configured for them.

    Args:
        text_list / proms_list / resps_list: per-sample values (the lists or their entries may be None).
        lang_list / tone_list / len_list / time_list: optional per-sample values.
        task_list: per-sample task names; defaults to "tts".
        quant_levels: per-sample RVQ level (a single int applies to every sample); defaults to 0.

    Returns:
        list (one per sample) of lists of (name, value) tuples, each starting with ("task", task_type).

    Raises:
        Exception: for the "len" task without a length embedding, or for an unrecognized task.
    """
    def first_device(*lists):
        # device of the first tensor found; prom entries may themselves be lists (task tokens mixed in)
        for values in lists:
            if values is None:
                continue
            for value in values:
                for candidate in (value if isinstance(value, (list, tuple)) else (value,)):
                    if isinstance(candidate, torch.Tensor):
                        return candidate.device
        return None

    def present(values, index):
        return values is not None and values[index] is not None

    # text_list is not guaranteed to be provided (the checks below allow None), so size the batch from
    # the first list that actually exists instead of unconditionally indexing text_list[0]
    sized_list = next((values for values in (text_list, proms_list, resps_list) if values is not None), [])
    batch_size = len(sized_list)
    device = first_device(text_list, proms_list, resps_list)

    inputs = [[] for _ in range(batch_size)]
    for i in range(batch_size):
        if quant_levels is None:
            quant_level = 0
        elif isinstance(quant_levels, int):
            quant_level = quant_levels
        else:
            quant_level = quant_levels[i]
        task_type = task_list[i] if task_list is not None else "tts"
        timestep = time_list[i] if time_list is not None else None

        # insert task type as a string
        inputs[i].append(("task", task_type))

        # Base-line TTS task
        # prom /may/ include <task> tokens inside to help guide things, per SpeechX
        if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks:
            # insert the text prompt
            if present(text_list, i):
                inputs[i].append(("text", text_list[i]))
            # insert lang token if we're trained for it
            if "lang" in self.capabilities and present(lang_list, i):
                inputs[i].append(("lang", lang_list[i]))
            # insert RVQ level guidance token if the model is versioned for it
            if self.rvq_l_emb is not None and not self.interleave:
                inputs[i].append(("quant_level", torch.tensor([quant_level], device=device, dtype=torch.int16)))
            # insert input audio prompt
            if present(proms_list, i):
                inputs[i].append(("prom", proms_list[i]))
            # insert tone token if we're trained for it
            if "tone" in self.capabilities and present(tone_list, i):
                inputs[i].append(("tone", tone_list[i]))
            # a timestep token is deliberately not inserted: providing it did not seem to matter
            # (presumably the model attends more to the amount of masked tokens in the sequence)

            # insert the current output response
            if present(resps_list, i):
                inputs[i].append(("resp", resps_list[i]))

            # store dropout mask (if training, as this gets used later to mask the input embeddings if provided);
            # the mask is derived from the response, so it is skipped when no response was given
            if timestep is not None and self.training and present(resps_list, i):
                # a paper said to use a fixed masking ratio for training
                # (instead of the cosine schedule p = cos(timestep * pi / 2))
                p = 0.8
                dropout_mask = _dropout_mask(resps_list[i], p)
                inputs[i].append(("dropout_mask", dropout_mask))
        # Audio length prediction task
        elif task_type == "len":
            # throw an error so we don't silently train without this
            if self.len_emb is None:
                raise Exception(f"Requesting task `{task_type}` but corresponding embedding is not defined.")

            # insert the text prompt
            if present(text_list, i):
                inputs[i].append(("text", text_list[i]))
            # insert lang token if we're trained for it
            if "lang" in self.capabilities and present(lang_list, i):
                inputs[i].append(("lang", lang_list[i]))
            # technically will always be level 0 but for the sake of keeing the input formatting coherent...
            if self.rvq_l_emb is not None:
                inputs[i].append(("quant_level", torch.tensor([quant_level], device=device, dtype=torch.int16)))
            # insert input audio prompt
            if present(proms_list, i):
                inputs[i].append(("prom", proms_list[i]))
            # insert tone token if we're trained for it
            if "tone" in self.capabilities and present(tone_list, i):
                inputs[i].append(("tone", tone_list[i]))

            # insert output length tokens (if it exists)
            if present(len_list, i):
                inputs[i].append(("len", len_list[i]))
            # "encode" length to tokens for 0-9 + stop
            elif present(resps_list, i):
                # yes this could be encoded better
                digits = [int(digit) for digit in str(resps_list[i].shape[0])]
                inputs[i].append(("len", torch.tensor([0] + digits + [10], device=device, dtype=torch.int16)))
        # Speech-to-Text prediction task
        elif task_type == "stt":
            # insert the input response
            if present(resps_list, i):
                inputs[i].append(("resp", resps_list[i]))
            # insert lang token if we're trained for it
            if "lang" in self.capabilities and present(lang_list, i):
                inputs[i].append(("lang", lang_list[i]))
            # insert RVQ level guidance token if the model is versioned for it
            if self.rvq_l_emb is not None and not self.interleave:
                inputs[i].append(("quant_level", torch.tensor([quant_level], device=device, dtype=torch.int16)))
            # insert the output text prompt
            if present(text_list, i):
                inputs[i].append(("text", text_list[i]))
        else:
            raise Exception(f'Unrecognized task: {task_type}')
    return inputs
def inputs_to_embeddings(
    self,
    inputs: list,
    quant_levels: int | list[int] | Tensor | None = None
):
    """
    Embed the ordered (name, value) inputs of each batch entry and join them into one sequence per entry.

    Each recognized input name is embedded with its corresponding embedding module; pieces are joined
    with `self.sep` via `_join`. "task" only records the task type, "dropout_mask" is only consumed by the
    "resp" branch, and names without a matching branch (or whose embedding module is None) are skipped.

    Args:
        inputs: list (one per batch entry) of lists of (name, value) tuples.
        quant_levels: target RVQ level per batch entry (a single int applies to every entry); defaults to 0.

    Returns:
        list of embedding tensors, one per batch entry.
    """
    def iter_input_tensors():
        for batch_input in inputs:
            for _, value in batch_input:
                for candidate in (value if isinstance(value, (list, tuple)) else (value,)):
                    if isinstance(candidate, torch.Tensor):
                        yield candidate

    # resolve a device up front: it used to be bound only after a "text" input was embedded, so a prompt
    # containing task tokens with no preceding text raised a NameError inside prompt_input_to_embedding
    device = next((tensor.device for tensor in iter_input_tensors()), None)

    # handles tasks where the prompt has task tokens injected in the middle
    def prompt_input_to_embedding(prom, quant_level):
        if isinstance(prom, str):
            return self.tasks_emb(torch.tensor([get_task_symmap()[f'<{prom}>']], device=device, dtype=torch.int16))

        # get RVQ level 0, or up to targetted RVQ level inference
        if self.version <= 4:
            return self.proms_emb(
                prom if quant_level == 0 else prom[:, :quant_level]
            )

        return self.proms_emb(
            prom if prom.dim() == 1 else prom[:, :1 if quant_level == 0 else quant_level],
            quant_level=0 if quant_level == 0 else quant_level - 1,  # input is one below the target quant level
            offset=0,
        )

    # yuck
    token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
    token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels if self.config else None
    if self.dropout_token is None or not self.training:
        token_dropout_rate = 0.0
    if not token_dropout_rvq_levels:
        token_dropout_rvq_levels = [1, self.resp_levels]

    summed_embeddings_task = ["stt"]
    x_list = []
    for batch_index, batch_input in enumerate(inputs):
        batch = []
        if quant_levels is None:
            quant_level = 0
        elif isinstance(quant_levels, int):
            quant_level = quant_levels
        else:
            quant_level = quant_levels[batch_index]

        task_type = "tts"
        # pre-iterate: the dropout mask may be listed after "resp", but the "resp" branch needs it
        dropout_mask = None
        for name, value in batch_input:
            if name == "dropout_mask":
                dropout_mask = value

        for name, value in batch_input:
            # technically can provide a map for input_name => embedding, but some embedding requires additional processing
            embedding = None

            if name == "task":
                # noop, only records the task type
                # *maybe* inject a token for specifying task type
                task_type = value
                continue
            elif name == "text":
                embedding = self.text_emb(value)
                # keep the device in sync with the embedding output
                device = embedding.device
            elif name == "quant_level" and self.rvq_l_emb is not None:
                embedding = self.rvq_l_emb(value)
            elif name == "lang" and self.langs_emb is not None:
                embedding = self.langs_emb(value)
            elif name == "prom":
                proms = [value] if isinstance(value, torch.Tensor) else value
                # to-do: probably insert separators if task requires it?
                embedding = torch.cat([prompt_input_to_embedding(prom, quant_level) for prom in proms if prom is not None])
            elif name == "tone" and self.tones_emb is not None:
                embedding = self.tones_emb(value)
            elif name == "resp":
                if self.interleave:
                    embeddings = [self.resps_emb(
                        value[:, :l + 1],
                        offset=0,
                        quant_level=l
                    ) for l in range(value.shape[-1])]
                    embedding = _interleave_sequence_reshape(embeddings)
                # if training NAR-len RVQ level 0
                elif dropout_mask is not None:
                    embedding = self.resps_emb(
                        # if masked use masked token, else original token
                        torch.where(dropout_mask, self.stop_token, value if value.dim() == 1 else value[:, 0]),
                        offset=0,
                        quant_level=0,
                    )
                # cheat-y way to handle performing STT across all levels
                elif task_type in summed_embeddings_task:
                    # we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
                    # NOTE(review): range(shape[-1] - 1) leaves out the last level; kept as-is, confirm intent
                    embedding = sum([self.resps_emb(
                        value[:, :l + 1],
                        offset=0 if l == 0 else 1,  # or maybe set to 1
                        quant_level=l,
                        sums=False
                    ) for l in range(value.shape[-1] - 1)])
                else:
                    # get RVQ level 0, or up to targetted RVQ level inference
                    if self.version <= 4:
                        embedding = self.resps_emb(
                            value if quant_level == 0 else value[:, :quant_level],
                            quant_level
                        )
                    else:
                        offset = 1 if "nar" in self.capabilities and quant_level > 0 else 0
                        embedding = self.resps_emb(
                            value if value.dim() == 1 or quant_level == 0 else value[:, :quant_level],
                            offset=offset,
                            quant_level=0 if quant_level == 0 else quant_level - 1,  # input is one below the target quant level
                        )

                    # apply token dropout
                    if token_dropout_rate > 0.0 and (token_dropout_rvq_levels[0] <= quant_level <= token_dropout_rvq_levels[1]):
                        steps = embedding.shape[0] - (1 if quant_level == 0 else 0)  # do not mess with stop token
                        for step in range(steps):
                            if random.random() > token_dropout_rate:
                                continue
                            embedding[step] = self.dropout_token
            elif name == "timestep" and self.time_emb is not None:
                embedding = self.time_emb(value)
            elif name == "len" and self.len_emb is not None:
                embedding = self.len_emb(value)
            else:
                # should probably raise an exception so things aren't processed silently
                continue

            batch.append(embedding)

        x_list.append(_join(batch, self.sep))

    return x_list
2024-11-10 00:04:59 +00:00
# get an attribute from a given input list
def get_input(
    self,
    inputs,
    name,
    at=None,
):
    """
    Look up input values by name in a list of per-sample (name, value) sequences.

    Args:
        inputs: iterable (one entry per sample) of iterables of (name, value) tuples.
        name: the input name to look for (e.g. "task").
        at: index of the sample to search; when None, every sample is searched.

    Returns:
        With `at` given: the first value called `name` in sample `at`, or None when there is none
        (including when no sample has index `at`).
        With `at` None: a flat list of every value called `name` across all samples, in order.
    """
    if at is None:
        return [value for batch_input in inputs for key, value in batch_input if key == name]

    for batch_index, batch_input in enumerate(inputs):
        if batch_index != at:
            continue
        return next((value for key, value in batch_input if key == name), None)
    return None
2024-11-10 00:04:59 +00:00
2024-07-18 19:18:34 +00:00
# creates position ids from a given input list
# if not unified_position_ids, then each input segment will have its own sequence
2024-07-17 00:52:41 +00:00
def inputs_to_position_ids(
    self,
    inputs: list,
    mask: Tensor,
):
    """
    Build int32 position ids for a padded batch.

    With `self.unified_position_ids`, positions count up across each row's unmasked tokens (the
    modeling_llama.py recipe), and masked slots are set to 1. Otherwise every input segment restarts at 0.
    A segment's range covers its trailing separator, except "resp"/"len", which are treated as the final
    segment. "task" entries are skipped and "dropout_mask" entries occupy no positions. Each row is
    right-padded with 1s up to the mask's length.

    NOTE(review): the no-trailing-separator rule is keyed on the name ("resp"/"len"). For STT ordering
    ("resp" first, "text" last), the segment boundaries shift by one, though the total length is unchanged.
    This assumes `_join` only places separators between segments; confirm whether it is intended.
    """
    device = mask.device

    # shamelessly grabbed from modeling_llama.py
    position_ids = mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(mask == 0, 1)

    if self.unified_position_ids:
        return position_ids.to(device=device, dtype=torch.int32)

    def segment_length(name, value):
        if isinstance(value, str):
            # task token
            return 1
        if name in ["dropout_mask"]:
            # a mask, not part of the sequence
            return 0
        if not isinstance(value, torch.Tensor):
            # list of tokens (e.g. a prom with task tokens mixed in), plus its separator
            return sum([item.shape[0] for item in value if isinstance(item, torch.Tensor)]) + 1
        if self.interleave and name == "resp":
            # interleaved model
            return value.shape[0] * value.shape[1]
        # ending input will not have a separator later
        return value.shape[0] + (0 if name in ["resp", "len"] else 1)

    rows = []
    for row_index, batch_input in enumerate(inputs):
        segments = [
            torch.arange(segment_length(name, value), device=device, dtype=torch.int32)
            for name, value in batch_input
            if name != "task"
        ]
        row = torch.cat(segments)
        padding = position_ids[row_index].shape[0] - row.shape[0]
        if padding > 0:
            row = torch.cat([row, torch.full((padding,), 1, device=device, dtype=torch.int32)])
        rows.append(row)

    return torch.stack(rows).to(device=device, dtype=torch.int32)
2024-07-17 00:52:41 +00:00
2024-05-29 00:29:54 +00:00
def calc_loss(
    self,
    inputs: list,
    logits,
    quant_levels: int | list[int] | Tensor | None = None,
):
    """
    Compute the training loss and accuracy stats for a batch.

    Target sequences are rebuilt from ``inputs`` (the same per-sample lists of
    ``(name, value)`` pairs fed to the model) and aligned with ``logits``.
    Causal sequences are shifted by ``self.causal_size`` so token n predicts
    token n + causal_size. Cross-entropy is computed with
    ``self.ignore_index`` masked out.

    Two modes:
      * ``self.config.loss_factors`` falsy: one target per sample (input pieces
        joined with an ignored separator) and a single ``"nll"`` loss
        averaged over the batch.
      * otherwise: one loss per input name ("text", "prom", "resp", "len"),
        each scaled by ``self.loss_factor(name)``. Accuracy goes into
        ``stats["acc"][name]``.

    Args:
        inputs: per-sample lists of ``(name, value)`` pairs.
        logits: per-sample logit tensors. The last dim is treated as the
            class dim (assumes (seq_len, vocab) — TODO confirm).
            The list passed in is not modified.
        quant_levels: per-sample RVQ level, indexed by batch position.

    Returns:
        LossStats(loss, stats).
    """
    # Work on a shallow copy. Causal shifting below reassigns entries, and the
    # caller's list (e.g. the logits forward() returns) must not be truncated.
    logits = list(logits)

    loss = dict(ce=dict())
    stats = dict(acc=dict())

    device = logits[0].device
    batch_size = len(logits)
    # tasks whose "resp" entries contribute no supervised targets
    summed_embeddings_task = ["stt"]
    tasks = [self.get_input(inputs, "task", at=i) for i in range(batch_size)]
    # A single classifier uses the raw levels. The per-level classifiers
    # route special tasks to level -1.
    classifier_quant_levels = quant_levels if self.classifier is not None else [
        -1 if tasks[i] in self.special_tasks else l for i, l in enumerate(quant_levels)
    ]

    # handles tasks where the prompt has task tokens injected in the middle
    def prompt_input_to_token(input, quant_level):
        # a task name embedded in the prompt maps to its task-symbol token id
        if isinstance(input, str):
            return torch.tensor([get_task_symmap()[f'<{input}>']], device=device, dtype=torch.int16)
        # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
        if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
            return torch.full_like(input[..., 0], self.ignore_index)
        return input if input.dim() == 1 else input[:, quant_level]

    # old, "naive" way, no loss factoring
    if not self.config.loss_factors:
        target_list = []
        is_causal = []

        for batch_index, batch in enumerate(inputs):
            quant_level = quant_levels[batch_index]
            target = []
            task_type = "tts"
            causal = False

            # the dropout mask (if any) applies to the resp target, so find it first
            dropout_mask = None
            for name, input in batch:
                if name == "dropout_mask":
                    dropout_mask = input

            for name, input in batch:
                if name == "task":
                    task_type = input
                elif name == "prom":
                    proms = [input] if isinstance(input, torch.Tensor) else input
                    target.append(torch.cat([prompt_input_to_token(input, quant_level) for input in proms if input is not None]))
                elif name == "resp":
                    causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_type in ["len", "stt"])
                    # mask found, apply it
                    if dropout_mask is not None:
                        # if mask use original token, else ignore
                        causal = False
                        target.append(torch.where(dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index))
                    elif self.interleave:
                        target.append(_interleave_sequence_flatten([input[:, l] for l in range(input.shape[-1])]))
                    elif task_type in summed_embeddings_task:
                        target.append(torch.full_like(input[..., 0], self.ignore_index))
                    else:
                        target.append(input if input.dim() == 1 else input[:, quant_level])
                elif name == "timestep":
                    target.append(torch.tensor([self.ignore_index], device=input.device))
                elif name in ["text", "quant_level", "lang", "tone", "len"]:
                    target.append(input)

            is_causal.append(causal)
            # input pieces are joined with an ignored separator token
            target_list.append(_join(target, torch.tensor(self.ignore_index, device=target[-1].device)))

        batch_size = len(target_list)
        # modify only causal sequences so it can properly behave like a transformer
        for i in range(batch_size):
            if is_causal[i]:
                l = self.causal_size
                logits[i] = logits[i][..., :-l, :]  # shift the target so that token n...
                target_list[i] = target_list[i][..., l:]  # predicts token n + 1

        # Diagnostic: a target id >= the number of classes is out of range for cross_entropy.
        for batch_index, target in enumerate(target_list):
            if target.numel() == 0:
                continue
            max_classes = logits[batch_index].shape[-1]
            max_token = torch.max(target).item()
            if max_token >= max_classes:
                task = self.get_input(inputs, "task", at=batch_index)
                print(batch_index, task, target, max_token, max_classes, inputs[batch_index])

        # Per-sample cross_entropy, averaged over the batch. This avoids
        # torch.cat allocating one large tensor.
        # nll being natural log likelihood :)))) (I don't know why this completely escaped me originally with thinking it meant something else)
        loss = dict(
            nll=sum([F.cross_entropy(logit, target, ignore_index=self.ignore_index) for target, logit in zip(target_list, logits)]) / batch_size
        )
        stats = self.metrics(logits, target_list, classifier_quant_levels) if self.metrics is not None else dict(
            acc=sum([self.accuracy_metric(logit, target) for target, logit in zip(target_list, logits)]) / batch_size
        )

        return LossStats(loss, stats)

    # Factored path, one loss per input name. Considerations:
    # * split losses does not maintain the entire sequence
    # * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
    #   + the other way at least should keep it intact this way
    #   + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
    #   + this might just be a spook since the odds the very first token of the AR mattering is slim
    info = {}
    batch_size = len(inputs)
    for i, batch in enumerate(inputs):
        quant_level = quant_levels[i]

        # cursor into logits[i]; each piece is followed by one separator position
        it = 0
        task_name = None
        for name, input in batch:
            if name == "resp":
                if self.interleave:
                    input = _interleave_sequence_flatten([input[:, l] for l in range(input.shape[-1])])
                elif task_name in summed_embeddings_task:
                    input = torch.full_like(input[..., 0], self.ignore_index)
                else:
                    input = input if input.dim() == 1 else input[:, quant_level]
            # select prom level
            elif name == "prom":
                proms = [input] if isinstance(input, torch.Tensor) else input
                input = torch.cat([prompt_input_to_token(input, quant_level) for input in proms if input is not None])
            # meta-input, no corresponding token at the moment
            elif name == "task":
                task_name = input
                continue
            # The mask occupies no sequence positions (0 length for position ids,
            # no target in the naive path), so it must not advance the cursor.
            elif name == "dropout_mask":
                continue

            seq_len = input.shape[0]

            logit = logits[i][it:it + seq_len]
            it += seq_len + 1  # +1 to incorporate the separator

            causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_name in ["len", "stt"])

            # for the AR, shift sequence so that it predicts the next token
            # (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
            if causal and seq_len > 1:
                l = self.causal_size
                logit = logit[..., :-l, :]
                input = input[..., l:]  # shift sequence to the right by one (or causal chunk size)

            if name not in info:
                info[name] = {
                    "targets": [],
                    "logits": [],
                }

            # modeling_llama.py has some comment about requiring .contiguous() but I feel it's a spook since that incurs a memory allocation
            info[name]["targets"].append(input.long())
            info[name]["logits"].append(logit)

    for name, batch in info.items():
        loss_factor = self.loss_factor(name)

        if name not in ["text", "prom", "resp", "len"]:
            continue

        if loss_factor == 0.0:
            continue

        # Per-sample cross_entropy instead of torch.cat-ing the batch. This avoids
        # the extra allocation and allows scaling the loss per piece.
        loss[name] = sum([F.cross_entropy(logit, target, ignore_index=self.ignore_index) * loss_factor for target, logit in zip(batch["targets"], batch["logits"])]) / batch_size
        if self.metrics is not None:
            metrics = self.metrics(batch["logits"], batch["targets"], classifier_quant_levels)
            stats["acc"][name] = metrics["acc"]
        else:
            stats["acc"][name] = sum([self.accuracy_metric(logit, target) for target, logit in zip(batch["targets"], batch["logits"])]) / batch_size

    return LossStats(loss, stats)
2024-05-19 16:23:56 +00:00
2024-04-17 02:04:48 +00:00
def forward(
    self,
    inputs: list,
    quant_levels: int | list[int] | Tensor | None = None,
    state: dict | list | None = None,
    layer_skip_variables: dict | None = None,
    output_attentions: bool = False,
    output_hidden_states: bool = False,
):
    """
    Run the model over a batch of structured inputs.

    Args:
        inputs: per-sample lists of ``(name, value)`` pairs.
        quant_levels: per-sample RVQ levels. When None, they are read from the
            "quant_level" inputs, falling back to all zeros.
        state: passed through to ``self._forward``.
        layer_skip_variables: overrides for the early-exit thresholds
            ("entropy_threshold", "varentropy_threshold", "min_layer",
            "max_layer"). Early exit is only enabled when ``self.layerskip``
            is set and this is truthy.
        output_attentions: passed through to ``self._forward``.
        output_hidden_states: passed through to ``self._forward``. Forced on
            when training with layerskip.

    Returns:
        Logits(logits, state, aux_loss, attentions, hidden_states, exited_layer),
        with padding stripped from the per-sample logits/hidden states.
        In training mode, also sets ``self.loss`` and ``self.stats``.
    """
    # layer the model actually exited at; updated by the early-exit callback below
    exited_layer = self.n_layers

    # Early-exit callback handed to _forward. It is a closure so it can read
    # classifier_quant_levels, which is assigned further down before _forward runs.
    def layer_skip_lambda(layer, logits):
        nonlocal exited_layer
        kwargs = {
            "entropy_threshold": 0.05,
            "varentropy_threshold": 0.05,
            "min_layer": self.n_layers // 2,
            "max_layer": self.n_layers,
        }
        kwargs.update(layer_skip_variables)

        # don't bother on early layers
        if layer < kwargs["min_layer"]:
            return False
        # bail if we want to force early layers
        if kwargs["max_layer"] < layer:
            return True

        # Hidden states aren't normalized. Project the *normalized* state
        # through the output head so the metrics see token logits.
        normed = self.model.norm(logits)
        if self.classifier is not None:
            logits = self.classifier(normed)
        elif self.classifiers is not None:
            logits = self.classifiers(normed, levels=classifier_quant_levels)

        # exit early if "good enough"
        metrics = calculate_entropix_metrics(logits)
        early = metrics["logits_entropy"] <= kwargs["entropy_threshold"] and metrics["logits_varentropy"] <= kwargs["varentropy_threshold"]
        if early:
            exited_layer = layer
        return early

    # derive quant levels from inputs if not provided
    if quant_levels is None:
        quant_levels = self.get_input(inputs, "quant_level")

    x_list = self.inputs_to_embeddings(inputs, quant_levels)
    x, mask = list_to_tensor(x_list)

    training = self.training
    batch_size = len(x_list)

    # pure AR
    if quant_levels is None:
        quant_levels = [0 for _ in range(batch_size)]

    # we only need hidden states if we're training with layerskip
    if self.layerskip and training:
        output_hidden_states = True

    # Pad the input and mask to a multiple of l_padding. The original lengths
    # are kept in x_list, so padding is stripped again after the forward pass.
    if self.l_padding and x.shape[1] % self.l_padding != 0:
        pad_len = self.l_padding - x.shape[1] % self.l_padding

        x_pad_shape = list(x.shape)
        x_pad_shape[1] = pad_len
        x = torch.cat([x, x.new_zeros(x_pad_shape)], dim=1)

        # Pad the mask with its own rank and dtype. Zero marks padding, which
        # inputs_to_position_ids treats as masked.
        mask_pad_shape = list(mask.shape)
        mask_pad_shape[1] = pad_len
        mask = torch.cat([mask, mask.new_zeros(mask_pad_shape)], dim=1)

    # needs to be done here as we still have our raw inputs
    position_ids = self.inputs_to_position_ids(inputs, mask=mask) if not self.unified_position_ids else None

    tasks = [self.get_input(inputs, "task", at=i) for i in range(batch_size)]
    if self.inject_timestep_embedding:
        timesteps = [self.get_input(inputs, "timestep", at=i) for i in range(batch_size)]
        timesteps = [self.time_emb(timestep) if timestep is not None else None for timestep in timesteps]
    else:
        timesteps = []

    # special tasks are routed to classifier level -1
    classifier_quant_levels = [-1 if tasks[i] in self.special_tasks else l for i, l in enumerate(quant_levels)]

    output = self._forward(
        inputs=x,
        mask=mask,
        state=state,
        position_ids=position_ids,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        layer_skip_lambda=layer_skip_lambda if self.layerskip and layer_skip_variables else None,
        timesteps=timesteps,
    )

    logits = output.logits
    hidden_states = output.hidden_states

    # Output projection layer. The very original implementation multiplied by
    # the mask, but the mask only covers padding, which is removed below anyway.
    if self.classifier is not None:
        logits = self.classifier(logits)

        if hidden_states is not None:
            for i, hidden in enumerate(hidden_states):
                hidden_states[i] = self.classifier(hidden)
    # to-do: piece-wise classification, now that there's a head for text
    # although again, one single monolithic head would be preferable instead......
    elif self.classifiers is not None:
        logits = self.classifiers(logits, levels=classifier_quant_levels)

        if hidden_states is not None:
            for i, hidden in enumerate(hidden_states):
                hidden_states[i] = self.classifiers(hidden, levels=classifier_quant_levels)

    # Remove padding
    logits = [hi[:li] for hi, li in zip(logits, map(len, x_list))]

    if hidden_states is not None:
        for i, hidden in enumerate(hidden_states):
            hidden_states[i] = [hi[:li] for hi, li in zip(hidden, map(len, x_list))]

    # compute loss if the target is given
    if training:
        loss, stats = self.calc_loss(inputs=inputs, logits=logits, quant_levels=quant_levels)

        # compute it as an aux-loss
        if self.layerskip:
            early_exit_loss = {}
            if not hasattr(self, "training_steps"):
                self.training_steps = 0

            for i, hidden in enumerate(hidden_states):
                # Keep the per-layer losses separate. Reusing `loss, stats` here
                # would discard the real (final-layer) loss and stats.
                layer_loss, _ = self.calc_loss(inputs=inputs, logits=hidden, quant_levels=quant_levels)
                for k, v in layer_loss.items():
                    # skip non-tensor entries such as calc_loss's empty "ce" placeholder dict
                    if not isinstance(v, torch.Tensor):
                        continue
                    early_exit_loss.setdefault(f'early_exit.{k}', []).append(v)

            for k, v in early_exit_loss.items():
                loss[k] = self.model.early_exit_loss(losses=v, t=self.training_steps)

            # to-do: instead make the cirriculum rely on samples processed instead of steps
            self.training_steps += 1  # batch_size

        # include any additional losses (for example: MoE router)
        if output.aux_loss is not None:
            loss["aux_loss"] = output.aux_loss

        self.loss = loss
        self.stats = stats

    # rewrap, because we're modifying the logits here
    return Logits(logits, output.state, output.aux_loss, output.attentions, hidden_states, exited_layer)
2023-09-13 02:28:07 +00:00
def sample(
    self,
    logits: list[Tensor],  # logit scores
    prev_list: list[Tensor] | None = None,  # previous tokens
    quant_levels: int | list[int] | Tensor | None = None,  # to-do: derive this from the prev_list
    **sampling_kwargs,
):
    """
    Turn per-sample logits into sampled token ids.

    Processing order:
      1. optional entropix sampling (AR, when ``attentions`` is given);
      2. trimming to the positions being predicted;
      3. stop-token banning (NAR on an AR-capable model);
      4. repetition penalty;
      5. length penalty (AR);
      6. min-p, then top-k/top-p;
      7. temperature (dynamic when ``min_temperature`` differs);
      8. top_no processing and DRY;
      9. token pick via mirostat, naive beam search, argmax
         (temperature <= 0) or categorical sampling.

    Args:
        logits: per-sample logit tensors.
        prev_list: per-sample previously generated tokens. For multi-column
            tensors only the last column is used.
        quant_levels: None selects the AR path; otherwise the NAR path.
        **sampling_kwargs: temperature, min_temperature, top_k, top_p,
            min_p, repetition_penalty, repetition_penalty_decay,
            length_penalty, beam_width, mirostat, dry_multiplier, dry_base,
            dry_allowed_length, top_no, attentions.

    Returns:
        Sampled(ids, logits, scores, entropy). ``scores`` holds one of:
        the mirostat states, the beam candidates' logit scores, or (basic
        sampling) the softmax probability of each chosen token per position.
    """
    # yikes
    temperature = sampling_kwargs.get("temperature", 1.0)
    min_temperature = sampling_kwargs.get("min_temperature", -1.0)
    top_k = sampling_kwargs.get("top_k", -100)
    top_p = sampling_kwargs.get("top_p", 1.0)
    min_p = sampling_kwargs.get("min_p", 0.0)
    # repetition penalty parameters
    repetition_penalty = sampling_kwargs.get("repetition_penalty", 1.0)
    repetition_penalty_decay = sampling_kwargs.get("repetition_penalty_decay", 0.0)
    # length penalty parameters
    length_penalty = sampling_kwargs.get("length_penalty", 0.0)
    # beam sampling parameters
    beam_width = sampling_kwargs.get("beam_width", 0)
    # mirostat sampling parameters
    mirostat = sampling_kwargs.get("mirostat", None)
    # DRY sampling parameters
    dry_multiplier = sampling_kwargs.get("dry_multiplier", 0.0)
    dry_base = sampling_kwargs.get("dry_base", 1.75)
    dry_allowed_length = sampling_kwargs.get("dry_allowed_length", 2)
    # top_no parameter
    top_no = sampling_kwargs.get("top_no", 1.0)
    # attention weights, only used by entropix sampling
    attentions = sampling_kwargs.get("attentions", None)

    if min_temperature < 0:
        min_temperature = temperature

    # pick last RVQ level (last column) of the token history
    if prev_list is not None:
        prev_list = [prevs if prevs.dim() == 1 else prevs[:, -1] for prevs in prev_list]

    scores = None
    entropy = None

    # (AR) entropix sampling.
    # This runs before any trimming so the logits for the entire sequence are
    # retained (even though it's still better to pass only the last token).
    if attentions is not None and quant_levels is None:
        seq_lens = [logit.shape[0] for logit in logits]
        # move to CPU for speedups; ( batch, layer, heads, seq_len, seq_len )
        attentions = torch.stack(attentions, dim=1).to(device="cpu")
        res = [
            sample_entropix(
                logit[:seq_lens[batch], :],  # ( seq_len, vocab )
                attentions[batch, :, :, :seq_lens[batch], :seq_lens[batch]],  # ( layer, heads, seq_len, seq_len )
                temperature,
                top_k,
                top_p,
                min_p,
            )
            for batch, logit in enumerate(logits)
        ]
        if res:
            return Sampled([r[0] for r in res], logits, scores, [r[1] for r in res])

    # (NAR) return the entire generated response.
    # Parallel decoding relies on the last N tokens in the logits, because each
    # token predicts the next RVQ layer in the same place.
    if quant_levels is not None:  # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
        logits = [logit[-l:] for logit, l in zip(logits, map(len, prev_list))]
    # (AR chunkwise) return the last chunkwise piece
    elif self.causal:
        logits = [logit[-self.causal_size:] for logit in logits]

    # (NAR) disable stop token
    if quant_levels is not None and "ar" in self.capabilities:
        logits = [ban_tokens(logit, tokens=[self.stop_token]) for logit in logits]

    # perform repetition penalizing
    if prev_list is not None and repetition_penalty != 1.0:
        logits = [reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip(logits, prev_list)]

    # (AR) perform length penalizing
    if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
        logits = [length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip(logits, map(len, prev_list))]

    # perform min_p filtering of our logits
    if min_p > 0.0:
        logits = [min_p_filtering(logit, min_p=min_p) for logit in logits]

    # perform top_k/top_p filtering of our logits
    if top_k > 0 or top_p < 1.0:
        logits = [top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits]

    # trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
    # epsilon float comparison because I don't trust Python
    if abs(temperature - min_temperature) >= 0.001:
        logits = [dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits]
    elif temperature > 0.0:
        logits = [logit / temperature for logit in logits]

    # do top-no logit processing
    # NOTE(review): only the sign of `top_no` is used (its value is never passed
    # on), and the 1.0 default enables this unless callers pass 0 — confirm intended.
    if top_no > 0.0:
        logits = [top_no_logits_processing(logit) for logit in logits]

    # do DRY sampling
    if dry_multiplier > 0.0 and prev_list is not None:
        logits = [dry_sampling(logit, previous=prevs, factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip(logits, prev_list)]

    # do mirostat sampling
    # currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
    if mirostat is not None:
        # the returned states go back to the caller via `scores`
        scores = [mirostat_sample(logit, state=state) for logit, state in zip(logits, mirostat)]
        res = [state["token"] for state in scores]
    # do beam search (naive implementation)
    # picks the top-k across all batches, and re-batches those resultant tokens
    # returns the logit scores as well to be P-concatted with the previous scores
    # to-do: not naively implement beam searching
    elif beam_width > 1:
        candidates = top_k_logits_list(logits, beam_width)
        res = [torch.tensor(token, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates]
        scores = [logits[batch].flatten()[token] for batch, token in candidates]
    # basic sampling
    else:
        # argmax instead
        if temperature <= 0.0:
            res = [logit.argmax(dim=-1) for logit in logits]
        else:
            res = [Categorical(logits=logit).sample() for logit in logits]

        # Per-position token probabilities. This belongs to this branch only, so
        # the mirostat states / beam scores above are not overwritten.
        scores = [
            [F.softmax(logit[i, :], dim=-1)[token].item() for i, token in enumerate(tokens)]
            for logit, tokens in zip(logits, res)
        ]

    return Sampled(res, logits, scores, entropy)