2023-08-02 21:53:35 +00:00
import math
import torch
import torch . nn . functional as F
import traceback
2023-09-13 02:28:07 +00:00
import numpy as np
2024-04-05 00:11:49 +00:00
import re
2023-08-02 21:53:35 +00:00
2024-05-04 18:07:45 +00:00
from typing import Literal , overload , Optional , Tuple
2023-08-02 21:53:35 +00:00
from functools import partial
from einops import rearrange
from torch import Tensor , einsum , nn
2024-04-17 02:04:48 +00:00
from torch . nn import Embedding
2023-08-02 21:53:35 +00:00
from torch . distributions import Categorical
from torch . nn . utils . rnn import pad_sequence
from torch . utils . checkpoint import checkpoint
from torchmetrics . classification import BinaryAccuracy , MulticlassAccuracy , MulticlassPrecision
2024-06-06 01:30:43 +00:00
from . arch import *
2024-05-10 01:28:20 +00:00
from . . utils import wrapper as ml
2023-10-11 17:25:31 +00:00
from . . samplers import reptition_penalize , length_penalize , top_k_top_p_filtering , dynamic_temperature , top_k_logits_list , mirostat_sample
2023-08-02 21:53:35 +00:00
def _create_mask ( l , device ) :
""" 1 is valid region and 0 is invalid. """
seq = torch . arange ( max ( l ) , device = device ) . unsqueeze ( 0 ) # (1 t)
stop = torch . tensor ( l , device = device ) . unsqueeze ( 1 ) # (b 1)
return ( seq < stop ) . float ( ) # (b t)
def _join ( x : tuple [ Tensor ] , sep : Tensor ) :
"""
Args :
x : ( k t d )
sep : ( d )
"""
ret = x [ 0 ]
for i in range ( 1 , len ( x ) ) :
ret = torch . cat ( ( ret , sep [ None ] , x [ i ] ) , dim = 0 )
return ret
def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
    """
    Pad a list of variable-length sequences into one tensor plus a validity mask.

    Args:
        x_list: list of (t, d) tensors.
        pattern: einops pattern applied to the padded (t, b, d) tensor and to the mask.
    Returns:
        (x, m): the padded tensor rearranged by `pattern`, and a mask with a trailing
        singleton axis, rearranged the same way and cast to x's dtype/device
        (1 on valid positions, 0 on padding).
    """
    lengths = [len(item) for item in x_list]
    padded = rearrange(pad_sequence(x_list), pattern)
    # (b t) -> (t b 1) so it follows the same pattern as the padded data
    mask = _create_mask(lengths, x_list[0].device).t().unsqueeze(-1)
    mask = rearrange(mask, pattern).to(padded)
    return padded, mask
2023-09-09 01:30:54 +00:00
# Disabled code, kept as a string literal so it is never executed.
# It is an older `Embedding` wrapper that embeds a batch-list and splits the result back into a list.
# The active `Embedding` name comes from the imports at the top of this file.
# NOTE(review): presumably torch.nn.Embedding, unless `.arch` re-exports one — unverified.
"""
class Embedding(nn.Embedding):
    def forward(self, x_list: list[Tensor]) -> list[Tensor]:
        if len(x_list) == 0:
            return []
        return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
"""
2023-08-02 21:53:35 +00:00
2024-04-29 23:24:05 +00:00
# Deprecated implementation
2023-09-13 18:19:11 +00:00
class MultiEmbedding ( nn . Module ) :
2023-09-08 06:03:24 +00:00
def __init__ ( self , max_n_levels , n_tokens , token_dim , monolithic = False ) :
2023-09-16 05:26:13 +00:00
super ( ) . __init__ ( )
2023-09-08 20:36:26 +00:00
self . monolithic = monolithic
2023-08-02 21:53:35 +00:00
self . max_n_levels = max_n_levels
self . n_tokens = n_tokens
2023-09-08 20:36:26 +00:00
self . weight = nn . Parameter ( torch . randn ( max_n_levels , n_tokens , token_dim ) )
2023-08-02 21:53:35 +00:00
2023-09-07 22:08:38 +00:00
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
# I imagine this is an oversight in the NAR.
def forward ( self , x_list : list [ Tensor ] , quant_levels : Tensor | None = None ) - > list [ Tensor ] :
2023-08-02 21:53:35 +00:00
if len ( x_list ) == 0 :
return [ ]
2023-09-08 20:36:26 +00:00
# this "strategy" will reserve the weight[0] for te AR and weight[1:] for the NAR
# the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
2023-09-08 06:03:24 +00:00
if self . monolithic :
2023-09-08 20:36:26 +00:00
w = self . weight [ : 1 ] if quant_levels is None else self . weight [ 1 : ]
2023-09-08 06:03:24 +00:00
else :
w = self . weight
2023-08-02 21:53:35 +00:00
2023-09-11 19:13:42 +00:00
padded_x_list = [ ]
2023-08-02 21:53:35 +00:00
2023-09-08 20:36:26 +00:00
for i , xi in enumerate ( x_list ) :
2023-08-02 21:53:35 +00:00
xi = F . one_hot ( xi . to ( torch . int64 ) , num_classes = self . n_tokens ) # t l' k
2023-09-11 19:13:42 +00:00
wi = w . shape [ 0 ] - xi . shape [ 1 ]
xi = F . pad ( xi , ( 0 , 0 , 0 , wi ) ) # t l k
2023-08-02 21:53:35 +00:00
padded_x_list . append ( xi . to ( w ) )
2023-09-11 19:13:42 +00:00
x = torch . cat ( padded_x_list ) # n l k
x = einsum ( " l k d, n l k -> n d " , w , x )
2023-08-02 21:53:35 +00:00
x_list = x . split ( [ * map ( len , x_list ) ] )
2023-09-07 21:48:02 +00:00
2023-09-11 19:13:42 +00:00
return x_list
2023-09-07 21:48:02 +00:00
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
2023-09-11 19:13:42 +00:00
class AudioEmbedding ( nn . Module ) :
2024-04-29 23:24:05 +00:00
def __init__ (
self ,
l_tokens : int , # list of number of tokens (needed because AR resps includes stop token)
token_dim : int , # dimensionality of the embedding
levels : int | None = None , # number of RVQ-bins (I don't remember the specifics)
sums : bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
) :
2023-09-07 14:14:03 +00:00
super ( ) . __init__ ( )
2024-04-29 23:24:05 +00:00
# array of embeddings
# proms are [0, prom_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
2023-09-21 00:10:59 +00:00
self . embeddings = nn . ModuleList ( [ nn . Embedding ( n_tokens , token_dim ) for n_tokens in l_tokens ] )
2024-04-29 23:24:05 +00:00
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
2024-01-25 18:18:48 +00:00
self . weight = nn . ParameterList ( [ nn . Parameter ( torch . Tensor ( [ 1 ] ) ) for i in range ( levels ) ] ) if levels is not None else None
2024-04-29 23:24:05 +00:00
#
self . sums = sums
2023-09-07 14:14:03 +00:00
2024-04-17 02:04:48 +00:00
def forward ( self , xi : Tensor , quant_levels : Tensor | None = None ) - > Tensor :
# prom
if quant_levels is None and xi . shape [ - 1 ] > 1 :
2024-04-29 23:24:05 +00:00
if self . sums :
x = sum ( [ self . embeddings [ k ] ( xi [ : , k ] ) * ( self . weight [ k ] if self . weight is not None else 1 ) for k in range ( xi . shape [ - 1 ] ) ] )
else :
k = 0 # only use the most significant RVQ bin level for the input prom
x = self . embeddings [ k ] ( xi [ : , k ] ) * ( self . weight [ k ] if self . weight is not None else 1 )
2024-04-17 02:04:48 +00:00
# AR resp
elif quant_levels is None or quant_levels == 0 :
2024-05-31 01:50:45 +00:00
x = self . embeddings [ 0 ] ( xi if len ( xi . shape ) == 1 else xi [ : , 0 ] )
2024-04-17 02:04:48 +00:00
# NAR resp
else :
2024-04-29 23:24:05 +00:00
if self . sums :
x = sum ( [ self . embeddings [ k + 1 ] ( xi [ : , k ] ) * ( self . weight [ k + 1 ] if self . weight is not None else 1 ) for k in range ( xi . shape [ - 1 ] ) ] )
else :
k = xi . shape [ - 1 ] - 1 # only use the previous RVQ bin level for the current resp embedding
x = self . embeddings [ k + 1 ] ( xi [ : , k ] ) * ( self . weight [ k + 1 ] if self . weight is not None else 1 )
2024-04-17 02:04:48 +00:00
return x
2023-09-07 21:48:02 +00:00
2023-08-02 21:53:35 +00:00
class Base ( nn . Module ) :
@property
def causal(self) -> bool:
    """Whether this model decodes autoregressively (and so has a stop token). Must be overridden."""
    raise NotImplementedError()

@property
def arch_type(self) -> str:
    """Backbone selected in `__init__` (e.g. "transformer", "llama", "retnet"). Must be overridden."""
    raise NotImplementedError()

@property
def norm_type(self):
    """Normalization type handed to the "transformer" arch's blocks. Must be overridden."""
    raise NotImplementedError()

@property
def n_prom_levels(self) -> int:
    """Number of RVQ levels embedded for the input prompt. Must be overridden."""
    raise NotImplementedError()

@property
def n_resp_levels(self) -> int:
    """Number of RVQ levels in the response. Must be overridden."""
    raise NotImplementedError()

@property
def n_max_levels(self) -> int:
    """Maximum number of RVQ levels. Must be overridden."""
    raise NotImplementedError()

@property
def n_langs(self) -> int:
    """Size of the language embedding table (0 disables it). Must be overridden."""
    raise NotImplementedError()

@property
def n_tasks(self) -> int:
    """Size of the task embedding table (0 disables it). Must be overridden."""
    raise NotImplementedError()

@property
def n_tones(self) -> int:
    """Size of the tone embedding table (0 disables it). Must be overridden."""
    raise NotImplementedError()

@property
def recurrent_chunk_size(self) -> int:
    """Chunk size for RetNet's chunkwise-recurrent mode (0 disables it). Must be overridden."""
    raise NotImplementedError()

@property
def rotary_embedding_base(self) -> float:
    """Base passed to RetNet as `rotary_embedding_base`."""
    return 10000

@property
def interleave(self) -> bool:
    """Interleave flag; off unless a subclass overrides it."""
    return False

@property
def monolithic(self) -> bool:
    """Whether the legacy resps MultiEmbedding splits its tables between AR and NAR; off by default."""
    return False

@property
def version(self) -> int:
    """Embedding-layout version; 1 is the legacy MultiEmbedding layout."""
    return 1

@property
def stop_token(self):
    """Id of the stop token (one past the audio codebook); only valid for causal models."""
    if self.causal:
        return self.n_audio_tokens
    raise ValueError("Not using stop token!")

@property
def ignore_index(self):
    """Target value ignored by the loss and metrics."""
    return -100
2024-05-19 16:23:56 +00:00
def loss_factor(self, k):
    """
    Return the weight applied to the loss term named `k`.

    Looks `k` up in `self.config.loss_factors`. Falls back to 1.0 when there is no
    config or the key is absent.
    """
    if self.config is not None and k in self.config.loss_factors:
        return self.config.loss_factors[k]
    return 1.0
2024-05-19 16:23:56 +00:00
2023-08-02 21:53:35 +00:00
def __init__(
    self,
    n_text_tokens: int = 256,
    n_audio_tokens: int = 1024,

    d_model: int = 512,
    n_heads: int = 8,
    n_layers: int = 12,
    p_dropout: float = 0.1,

    n_experts: int = 1,
    l_padding: int = 0,

    training=True,
    config=None,
):
    """
    Build the input embeddings, the backbone selected by `self.arch_type`, and the output head.

    Args:
        n_text_tokens: size of the text token vocabulary.
        n_audio_tokens: size of the audio codebook; responses get one extra slot for the stop token.
        d_model: hidden size.
        n_heads: attention heads (retention heads for RetNet).
        n_layers: number of layers (mamba is built with twice this many).
        p_dropout: dropout probability; forced to 0.0 when `training` is False.
        n_experts: >1 selects the mixture-of-experts variant where supported.
        l_padding: stored on the instance; not used by this constructor.
        training: initial training flag; also enables MoE router logits for mixtral.
        config: model config. May be None, in which case gradient checkpointing is on,
            audio embeddings are summed, and no attention-backend override is applied.

    Raises:
        ValueError: if the configured attention backend is not in AVAILABLE_ATTENTIONS.
        RuntimeError: if `self.arch_type` is not recognized.
    """
    super().__init__()
    self.training = training
    self.config = config
    self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True

    self.n_text_tokens = n_text_tokens
    self.n_audio_tokens = n_audio_tokens

    self.d_model = d_model
    self.n_heads = n_heads
    self.n_layers = n_layers
    self.n_experts = n_experts

    self.l_padding = l_padding

    # +1 to include the stop token
    n_prom_tokens = n_audio_tokens
    n_resp_tokens = n_audio_tokens + 1  # (1 if self.causal else 0) interoperability

    self.text_emb = Embedding(n_text_tokens, d_model)
    self.langs_emb = None
    self.tones_emb = None
    self.tasks_emb = None
    self.rvq_level_emb = None

    if self.version == 1:  # legacy
        n_prom_tokens += (self.n_tasks - 1)  # old models have the task tokens in the prom
        self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
        self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
    else:
        # [1024] * 8
        self.proms_emb = AudioEmbedding(
            [n_prom_tokens] * self.n_prom_levels, d_model,
            levels=self.n_prom_levels if self.version > 3 else None,
            sums=self.config.audio_embedding_sums if self.config is not None else True,
        )
        # [1024 + STOP] + [1024] * 8
        self.resps_emb = AudioEmbedding(
            [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
            levels=self.n_resp_levels if self.version > 3 else None,
            sums=self.config.audio_embedding_sums if self.config is not None else True
        )

    # useless since I actually removed using these with the input processing overhaul...
    if self.version >= 3:
        self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None
        self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None
    # never actually got added... I kept forgetting to classify all my audio for speaker's tone
    if self.version >= 4:
        self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None

    # mamba requires this if a model does both AR and NAR tasks
    # this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
    # this ***might*** let me also unify the proms_emb and resps_embedding
    if self.version >= 5:
        self.rvq_level_emb = Embedding(self.n_resp_levels, d_model)

    # this would be nicer to be a stop token or live inside an embedding
    self.sep = nn.Parameter(torch.randn(d_model))

    # ick, there has to be a better way
    # `hf_attention` is forwarded to HF models as `attn_implementation`. It is cleared when one
    # of our own backends is selected; those are patched in after the backbone is built.
    hf_attention = self.config.attention if self.config is not None else None

    # without a config there is no backend override to resolve or validate
    if self.config is not None:
        # resolve "auto" to the best available backend (note: this mutates the shared config)
        if self.config.attention == "auto":
            if "flash" in AVAILABLE_ATTENTIONS:
                self.config.attention = "flash"
            elif "xformers" in AVAILABLE_ATTENTIONS:
                self.config.attention = "xformers"
            else:
                self.config.attention = "mem_efficient"

        if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]:
            hf_attention = None
            if self.config.attention not in AVAILABLE_ATTENTIONS:
                raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")

    if self.arch_type == "transformer":
        self.sin_emb = SinusoidalEmbedding(d_model)
        self.blocks = nn.ModuleList([TransformerBlock(
            d_model=d_model,
            n_heads=n_heads,
            p_dropout=p_dropout if training else 0.0,
            causal=self.causal,
            norm_type=self.norm_type,
            n_levels=self.n_resp_levels,
        ) for _ in range(n_layers)])
    elif self.arch_type in ["mistral", "mixtral"]:
        # a missing config or a non-positive kv_heads falls back to full multi-head attention
        kv_heads = self.config.kv_heads if self.config is not None and self.config.kv_heads > 0 else n_heads
        if n_experts <= 1:
            self.model = MistralModel(MistralConfig(
                vocab_size=n_resp_tokens,
                hidden_size=d_model,
                max_position_embeddings=75 * 60,  # max-length of 60 seconds
                intermediate_size=d_model * 4,
                num_hidden_layers=n_layers,
                num_attention_heads=n_heads,
                attention_dropout=p_dropout if training else 0.0,
                num_key_value_heads=kv_heads,
                hidden_act="gelu",
                is_encoder_decoder=False,
                is_decoder=True,
                attn_implementation=hf_attention,
                #gradient_checkpointing=self.gradient_checkpointing,
            ))
        else:
            self.model = MixtralModel(MixtralConfig(
                vocab_size=n_resp_tokens,
                hidden_size=d_model,
                max_position_embeddings=75 * 60,  # max-length of 60 seconds
                intermediate_size=d_model * 4,
                num_hidden_layers=n_layers,
                num_attention_heads=n_heads,
                attention_dropout=p_dropout if training else 0.0,
                num_key_value_heads=kv_heads,
                sliding_window=75 * 12,  # 12 second context window
                output_router_logits=training,
                hidden_act="gelu",
                is_encoder_decoder=False,
                is_decoder=True,
                num_local_experts=n_experts,
                num_experts_per_tok=min(2, n_experts),
                attn_implementation=hf_attention,
                #gradient_checkpointing=self.gradient_checkpointing,
            ))

        if self.gradient_checkpointing and not self.model.gradient_checkpointing:
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
                use_reentrant=False
            ))

        #if training:
        #    self.model.training = True
    elif self.arch_type == "llama":
        if n_experts <= 1:
            self.model = LlamaModel(LlamaConfig(
                vocab_size=n_resp_tokens,
                hidden_size=d_model,
                max_position_embeddings=75 * 60,  # max-length of 60 seconds
                intermediate_size=d_model * 4,
                num_hidden_layers=n_layers,
                num_attention_heads=n_heads,
                attention_dropout=p_dropout if training else 0.0,
                num_key_value_heads=n_heads,
                sliding_window=75 * 12,  # 12 second context window
                hidden_act="gelu",
                is_encoder_decoder=False,
                is_decoder=True,
                attn_implementation=hf_attention,
                #gradient_checkpointing=self.gradient_checkpointing,
            ))
        else:
            self.model = MixtralModel(MixtralConfig(
                vocab_size=n_resp_tokens,
                hidden_size=d_model,
                max_position_embeddings=75 * 60,  # max-length of 60 seconds
                intermediate_size=d_model * 4,
                num_hidden_layers=n_layers,
                num_attention_heads=n_heads,
                attention_dropout=p_dropout if training else 0.0,
                num_key_value_heads=n_heads,
                sliding_window=75 * 12,  # 12 second context window
                output_router_logits=training,
                hidden_act="gelu",
                is_encoder_decoder=False,
                is_decoder=True,
                num_local_experts=n_experts,
                num_experts_per_tok=min(2, n_experts),
                attn_implementation=hf_attention,
                #gradient_checkpointing=self.gradient_checkpointing,
            ))

        if self.gradient_checkpointing and not self.model.gradient_checkpointing:
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
                use_reentrant=False
            ))

        #if training:
        #    self.model.training = True
    elif self.arch_type == "retnet":
        kwargs = dict(
            vocab_size=n_resp_tokens,
            decoder_embed_dim=d_model,
            decoder_value_embed_dim=d_model * 2,
            decoder_retention_heads=n_heads,
            decoder_ffn_embed_dim=d_model * 4,
            decoder_layers=n_layers,
            dropout=p_dropout if training else 0.0,
            checkpoint_activations=self.gradient_checkpointing,
            activation_fn="gelu",
            use_layernorm=self.version < 3,
            use_biases=self.version < 3,
            use_glu=self.version >= 3,

            chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
            recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
            no_output_layer=True,
            decoder_normalize_before=True,

            rotary_embedding_base=self.rotary_embedding_base,  # 10000
        )
        if n_experts > 1:
            kwargs.update(dict(
                use_xmoe=True,
                moe_freq=1,
                moe_expert_count=n_experts,
                moe_gating_use_fp32=False,
            ))

        self.model = RetNetDecoder(RetNetConfig(**kwargs))
    elif self.arch_type == "retnet-hf":
        kwargs = dict(
            vocab_size=n_resp_tokens,
            decoder_embed_dim=d_model,
            decoder_value_embed_dim=d_model * 2,
            decoder_retention_heads=n_heads,
            decoder_ffn_embed_dim=d_model * 4,
            decoder_layers=n_layers,
            dropout=p_dropout if training else 0.0,
            checkpoint_activations=self.gradient_checkpointing,
            activation_fn="gelu",
            use_glu=False,  # self.version >= 3,

            recurrent_chunk_size=self.recurrent_chunk_size if self.causal else 0,
            decoder_normalize_before=True,

            deepnorm=False,
            subln=True,
        )
        self.model = RetNetDecoder_HF(RetNetConfig_HF(**kwargs))

        if self.gradient_checkpointing and not self.model.gradient_checkpointing:
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
                use_reentrant=False
            ))
    elif self.arch_type == "bitnet":
        self.model = BitNetTransformer(
            num_tokens=n_resp_tokens,
            dim=d_model,
            depth=n_layers,
            heads=n_heads,
            ff_mult=4,
            gradient_checkpointing=self.gradient_checkpointing,
        )
    elif self.arch_type in ["mamba", "mamba2"]:
        self.model = MambaMixelModel(
            vocab_size=n_resp_tokens,
            d_model=d_model,
            n_layer=n_layers * 2,
            d_intermediate=0,
            ssm_cfg={"layer": "Mamba2", "chunk_size": 64} if self.arch_type == "mamba2" else {},
            rms_norm=True,
            fused_add_norm=True,
            residual_in_fp32=True,
            #attn_layer_idx=attn_layer_idx,
            #attn_cfg=attn_cfg,
            #initializer_cfg=initializer_cfg,
        )
        self.model.gradient_checkpointing = self.gradient_checkpointing
    else:
        raise RuntimeError(f'Unknown arch specified: {self.arch_type}')

    # swap in our own attention implementation for the selected backend.
    # Skipped without a config, and for the "transformer" arch, which has no `self.model` to patch.
    if (
        self.config is not None
        and self.arch_type != "transformer"
        and self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]
    ):
        self.model = ml.replace_attention(self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention)

    self.classifier = nn.Linear(d_model, n_resp_tokens)
    self.accuracy_metric = MulticlassAccuracy(
        n_resp_tokens,
        top_k=10,
        average="micro",
        multidim_average="global",
        ignore_index=self.ignore_index,
    )
    self.precision_metric = MulticlassPrecision(
        n_resp_tokens,
        top_k=10,
        average="micro",
        multidim_average="global",
        ignore_index=self.ignore_index,
    )
2024-04-16 00:54:32 +00:00
def _forward (
2023-08-02 21:53:35 +00:00
self ,
2024-04-16 00:54:32 +00:00
inputs ,
mask = None ,
state = None ,
2023-08-02 21:53:35 +00:00
) :
2024-04-16 00:54:32 +00:00
x = inputs
m = mask . squeeze ( - 1 ) . int ( )
2023-12-23 01:27:36 +00:00
aux_loss = None
2024-04-14 18:12:50 +00:00
"""
# Broken
if state is not None and ( self . arch_type == " retnet " or self . arch_type == " retnet-hf " ) :
2023-09-02 01:58:29 +00:00
# prefill
if len ( state ) == 0 :
2023-09-05 20:38:21 +00:00
prefill_size = x . shape [ 1 ]
# run the initial prompt to fill the KV cache
2024-04-09 01:14:51 +00:00
if self . arch_type == " retnet " :
for n in range ( prefill_size ) :
xi = x [ : , n , : ] . unsqueeze ( 1 )
self . model ( xi , incremental_state = state , token_embeddings = xi , features_only = True )
elif self . arch_type == " retnet-hf " :
2024-04-14 18:12:50 +00:00
state = None
2024-04-09 01:14:51 +00:00
for n in range ( prefill_size ) :
xi = x [ : , n , : ] . unsqueeze ( 1 )
kwargs = dict (
2024-04-16 00:54:32 +00:00
attention_mask = m ,
2024-04-14 18:12:50 +00:00
inputs_embeds = xi ,
past_key_values = state ,
use_cache = True ,
forward_impl = ' recurrent ' ,
2024-04-09 01:14:51 +00:00
# return_dict=True,
)
out = self . model ( * * kwargs )
2024-04-14 18:12:50 +00:00
state = out . past_key_values
2023-09-02 01:58:29 +00:00
# grab last token(s)
x = x [ : , - 1 , : ] . unsqueeze ( 1 )
2024-04-14 18:12:50 +00:00
"""
2023-12-23 01:27:36 +00:00
# HF transformer derived model
2024-04-16 00:54:32 +00:00
if self . arch_type in [ " llama " , " mistral " , " mixtral " ] :
2023-12-23 01:27:36 +00:00
kwargs = dict (
2024-04-16 00:54:32 +00:00
attention_mask = m ,
2023-12-23 01:27:36 +00:00
inputs_embeds = x ,
2024-02-01 03:48:36 +00:00
past_key_values = state ,
2024-04-14 18:12:50 +00:00
use_cache = True ,
2024-02-01 03:48:36 +00:00
# return_dict=True,
2023-12-23 01:27:36 +00:00
)
2024-02-01 03:48:36 +00:00
if self . n_experts > 1 and targ_list is not None :
2023-12-23 01:27:36 +00:00
kwargs [ " output_router_logits " ] = True
2023-09-02 01:58:29 +00:00
2023-12-23 01:27:36 +00:00
t = self . model ( * * kwargs )
2024-02-01 03:48:36 +00:00
2023-12-23 01:27:36 +00:00
x = t [ 0 ]
2024-02-01 03:48:36 +00:00
if state is not None :
state = t [ 1 ]
if self . n_experts > 1 and targ_list is not None :
2023-12-23 01:27:36 +00:00
router_logits = t [ - 1 ]
aux_loss = self . model . config . router_aux_loss_coef * load_balancing_loss_func ( router_logits , self . model . config . num_local_experts , self . model . config . num_experts_per_tok )
elif self . arch_type == " transformer " :
2023-09-12 21:04:45 +00:00
# ensures we specify a quant_level for the transformer implementation's AdaLN
2023-09-07 14:14:03 +00:00
l = torch . zeros ( ( batch_size , ) , dtype = torch . int32 ) if quant_levels is None else quant_levels
l = l . to ( device )
2023-09-12 21:04:45 +00:00
# inject position information
x = self . sin_emb . add_pe ( x )
# pass our inputs through the transformer
2023-08-02 21:53:35 +00:00
for block in self . blocks :
2024-04-16 00:54:32 +00:00
x = block ( x , m , l )
2023-08-02 21:53:35 +00:00
elif self . arch_type == " retnet " :
2023-09-12 21:04:45 +00:00
# pass our inputs through the RetNet
2023-12-23 01:27:36 +00:00
x , _ = self . model ( x , incremental_state = state , token_embeddings = x , features_only = True )
2023-12-26 03:20:32 +00:00
if _ is not None and " l_aux " in _ and self . n_experts > 1 :
2023-12-23 01:27:36 +00:00
aux_loss = torch . sum ( torch . stack ( [ t for t in _ [ " l_aux " ] if t is not None ] ) ) * 0.001
2024-04-09 01:14:51 +00:00
elif self . arch_type == " retnet-hf " :
2024-04-14 18:12:50 +00:00
first = state is None or len ( state ) == 0
2024-04-09 01:14:51 +00:00
kwargs = dict (
2024-04-16 00:54:32 +00:00
attention_mask = m ,
2024-04-14 18:12:50 +00:00
inputs_embeds = x if first else x [ : , - 1 , : ] . unsqueeze ( 1 ) ,
past_key_values = None if first else state ,
use_cache = True ,
forward_impl = ' parallel ' if first else ' recurrent ' ,
return_dict = True ,
2024-04-09 01:14:51 +00:00
)
2024-04-14 18:12:50 +00:00
out = self . model ( * * kwargs )
x = out . last_hidden_state
2024-04-09 01:14:51 +00:00
if state is not None :
2024-04-14 18:12:50 +00:00
state = out . past_key_values
2024-06-05 03:41:22 +00:00
elif self . arch_type in [ " mamba " , " mamba2 " ] :
x = self . model ( hidden_states = x )
2024-03-01 02:29:17 +00:00
elif self . arch_type == " bitnet " :
x = self . model ( x )
2024-04-14 18:12:50 +00:00
2023-09-12 21:04:45 +00:00
# output projection layer with masking
2024-04-16 00:54:32 +00:00
x = self . classifier ( x ) * mask
2024-04-14 18:12:50 +00:00
2024-04-16 00:54:32 +00:00
return x , state , aux_loss
2024-04-17 02:04:48 +00:00
def inputs (
2024-04-16 00:54:32 +00:00
self ,
text_list : list [ Tensor ] ,
proms_list : list [ Tensor ] ,
resps_list : list [ Tensor ] ,
targ_list : list [ Tensor ] | None = None ,
2024-04-17 02:04:48 +00:00
2024-04-16 00:54:32 +00:00
lang_list : list [ Tensor ] | None = None ,
tone_list : list [ Tensor ] | None = None ,
2024-06-05 04:23:31 +00:00
quant_levels : Tensor | None = None
2024-04-16 00:54:32 +00:00
) :
device = text_list [ 0 ] . device
batch_size = len ( text_list )
2024-04-17 02:04:48 +00:00
inputs = [ [ ] for _ in range ( batch_size ) ]
for i in range ( batch_size ) :
2024-06-05 04:23:31 +00:00
quant_level = quant_levels [ i ] if quant_levels is not None else 0
2024-04-17 02:04:48 +00:00
if text_list is not None :
inputs [ i ] . append ( ( " text " , text_list [ i ] ) )
2024-06-05 04:23:31 +00:00
2024-06-06 01:53:10 +00:00
if self . rvq_level_emb is not None :
inputs [ i ] . append ( ( " quant_level " , torch . Tensor ( [ quant_level ] ) . to ( device = device , dtype = torch . int16 ) ) )
2024-06-05 04:23:31 +00:00
2024-04-17 02:04:48 +00:00
if proms_list is not None :
inputs [ i ] . append ( ( " prom " , proms_list [ i ] ) )
if resps_list is not None :
inputs [ i ] . append ( ( " resp " , resps_list [ i ] ) )
if targ_list is not None :
inputs [ i ] . append ( ( " targ " , targ_list [ i ] ) )
2024-04-16 00:54:32 +00:00
2024-04-17 02:04:48 +00:00
return inputs
def inputs_to_embeddings(
    self,
    inputs: list,
    quant_levels: Tensor | None = None
):
    """
    Embed packed (name, tensor) inputs and join each sample's pieces with `self.sep`.

    "text", "prom" and "resp" always use their embedding. "quant_level", "lang" and "tone"
    are embedded only when the matching table exists, otherwise skipped. Any other name
    (e.g. "targ") is skipped.

    Args:
        inputs: output of `inputs()`, one list of (name, tensor) pairs per sample.
        quant_levels: optional per-sample RVQ level forwarded to `resps_emb`; defaults to 0.
    Returns:
        one joined embedding tensor per sample.
    """
    # optional inputs -> attribute holding their (possibly None) embedding table
    optional_tables = {
        "quant_level": "rvq_level_emb",
        "lang": "langs_emb",
        "tone": "tones_emb",
    }

    embedded = []
    for index, sample in enumerate(inputs):
        level = 0 if quant_levels is None else quant_levels[index]
        pieces = []
        for kind, value in sample:
            if kind == "text":
                pieces.append(self.text_emb(value))
            elif kind == "prom":
                pieces.append(self.proms_emb(value))
            elif kind == "resp":
                pieces.append(self.resps_emb(value, level))
            elif kind in optional_tables:
                table = getattr(self, optional_tables[kind])
                if table is not None:
                    pieces.append(table(value))
        embedded.append(_join(pieces, self.sep))

    return embedded
2024-05-29 00:29:54 +00:00
def calc_loss (
2024-04-17 02:04:48 +00:00
self ,
inputs : list ,
2024-05-29 00:29:54 +00:00
logits ,
quant_levels : Tensor | None = None ,
2024-04-17 02:04:48 +00:00
) :
2024-05-29 00:29:54 +00:00
# old, "naive" way, no loss factoring
2024-06-06 14:48:43 +00:00
if not self . config . loss_factors :
2024-05-29 00:29:54 +00:00
target_list = [ ]
2024-06-05 04:23:31 +00:00
for batch_index , batch in enumerate ( inputs ) :
2024-05-29 00:29:54 +00:00
target = [ ]
for name , input in batch :
if name == " prom " :
target . append ( torch . full_like ( input [ . . . , 0 ] , self . ignore_index ) )
2024-06-05 04:23:31 +00:00
elif name in [ " text " , " quant_level " , " lang " , " tone " , " targ " ] :
2024-05-29 00:29:54 +00:00
target . append ( input )
target_list . append ( _join ( target , torch . tensor ( self . ignore_index , device = target [ - 1 ] . device ) ) )
2024-06-06 02:02:05 +00:00
batch_size = len ( target_list )
2024-05-29 00:29:54 +00:00
# modify only for the AR so it can properly behave like a transformer
2024-06-06 02:02:05 +00:00
for i in range ( batch_size ) :
2024-05-29 00:29:54 +00:00
if quant_levels is not None and quant_levels [ i ] > 0 :
continue
2024-04-17 02:04:48 +00:00
2024-05-29 00:29:54 +00:00
logits [ i ] = logits [ i ] [ . . . , : - 1 , : ] # shift the target so that token n...
target_list [ i ] = target_list [ i ] [ . . . , 1 : ] # predicts token n + 1
2024-04-17 02:04:48 +00:00
2024-06-06 00:50:06 +00:00
# see comments for the split-loss calc cross_entropy call
if False :
target = torch . cat ( target_list )
inputs = torch . cat ( logits )
self . loss = dict (
# "nll" was in the original implementation and should actually just be called something else
nll = F . cross_entropy ( inputs , target , ignore_index = self . ignore_index )
)
self . stats = dict (
acc = self . accuracy_metric ( inputs , target ) ,
# precision = self.precision_metric( inputs, target ),
)
else :
self . loss = dict (
2024-06-06 02:02:05 +00:00
nll = sum ( [ F . cross_entropy ( inputs , targets , ignore_index = self . ignore_index ) * loss_factor for targets , inputs in zip ( target_list , logits ) ] ) / batch_size
2024-06-06 00:50:06 +00:00
)
self . stats = dict (
2024-06-06 02:02:05 +00:00
acc = sum ( [ self . accuracy_metric ( inputs , targets ) for targets , inputs in zip ( target_list , logits ) ] ) / batch_size
2024-06-06 00:50:06 +00:00
)
2024-04-17 02:04:48 +00:00
2024-05-29 00:29:54 +00:00
return
2024-05-19 16:23:56 +00:00
2024-06-06 00:50:06 +00:00
"""
# considerations:
# * split losses does not maintain the entire sequence
# * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
# + the other way at least should keep it intact this way
# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
"""
2024-05-29 00:29:54 +00:00
self . loss = dict ( )
self . stats = dict ( acc = dict ( ) )
2024-05-19 16:23:56 +00:00
2024-05-29 00:29:54 +00:00
info = { }
2024-06-06 02:02:05 +00:00
batch_size = len ( inputs )
2024-05-29 00:29:54 +00:00
for i , batch in enumerate ( inputs ) :
quant_level = quant_levels [ i ] if quant_levels is not None else None
2024-05-19 16:23:56 +00:00
2024-05-29 00:29:54 +00:00
it = 0
for name , input in batch :
# do not use resp
if name == " resp " :
continue
# rename to resp
if name == " targ " :
name = " resp "
# select prom level
elif name == " prom " and quant_level is not None :
input = input [ : , quant_level ]
seq_len = input . shape [ 0 ]
2024-06-05 04:23:31 +00:00
2024-05-29 00:29:54 +00:00
logit = logits [ i ] [ it : it + seq_len ]
it + = seq_len + 1 # +1 to incorporate the separator
# for the AR, shift sequence so that it predicts the next token
2024-06-06 00:50:06 +00:00
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
2024-05-29 00:29:54 +00:00
if quant_level is None or quant_level == 0 :
logit = logit [ . . . , : - 1 , : ] # get all but the final logit
input = input [ . . . , 1 : ] # shift sequence to the right by one
if name not in info :
info [ name ] = {
" targets " : [ ] ,
" logits " : [ ] ,
}
2024-06-06 00:50:06 +00:00
# modeling_llama.py has some comment about requiring .contiguous() but I feel it's a spook since that incurs a memory allocation
info [ name ] [ " targets " ] . append ( input . long ( ) )
info [ name ] [ " logits " ] . append ( logit )
2024-05-29 00:29:54 +00:00
for name , batch in info . items ( ) :
loss_factor = self . loss_factor ( name )
2024-06-05 04:48:51 +00:00
if name not in [ " text " , " prom " , " resp " ] :
continue
2024-05-29 00:29:54 +00:00
if loss_factor == 0.0 :
continue
2024-05-27 13:43:00 +00:00
2024-06-06 00:50:06 +00:00
# "faster" if cross_entropy has speedups for processing an entire batch, but torch.cat allocates new tensors
# to-do: set this to a var
if False :
targets = torch . cat ( batch [ " targets " ] ) . long ( )
inputs = torch . cat ( batch [ " logits " ] )
self . loss [ name ] = F . cross_entropy ( inputs , targets , ignore_index = self . ignore_index ) * loss_factor
2024-06-05 04:48:51 +00:00
self . stats [ " acc " ] [ name ] = self . accuracy_metric ( inputs , targets )
2024-06-06 00:50:06 +00:00
# probably consumes less memory due to not having to allocate memory
# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
else :
2024-06-06 02:02:05 +00:00
self . loss [ name ] = sum ( [ F . cross_entropy ( inputs , targets , ignore_index = self . ignore_index ) * loss_factor for targets , inputs in zip ( batch [ " targets " ] , batch [ " logits " ] ) ] ) / batch_size
self . stats [ " acc " ] [ name ] = sum ( [ self . accuracy_metric ( inputs , targets ) for targets , inputs in zip ( batch [ " targets " ] , batch [ " logits " ] ) ] ) / batch_size
2024-06-06 00:50:06 +00:00
# accuracy sometimes breaks for mamba
2024-05-19 16:23:56 +00:00
2024-05-29 00:29:54 +00:00
# to-do: compute loss per individual batch to scale per RVQ level
"""
rvq_loss_factor = self . loss_factor ( " quant " )
if isinstance ( rvq_loss_factor , list ) :
. . .
"""
2024-05-19 16:23:56 +00:00
2024-04-17 02:04:48 +00:00
def forward(
    self,
    inputs: list,
    quant_levels: Tensor | None = None,
    state: dict | list | None = None,
):
    """
    Embed ``inputs``, run the backbone and return per-entry logits (plus losses when training).

    Training mode is inferred from the presence of any ``"targ"`` pair in ``inputs``. In that
    case ``calc_loss`` is run, and a non-``None`` auxiliary loss from ``_forward`` is stored
    as ``self.loss["aux_loss"]``.

    Args:
        inputs: one list of ``(name, tensor)`` pairs per batch entry.
        quant_levels: optional per-entry RVQ level, forwarded to embedding and loss code.
        state: optional state forwarded to ``self._forward``.

    Returns:
        ``logits`` (one tensor per entry, trimmed to that entry's unpadded length), or
        ``(logits, state)`` when ``_forward`` returned a non-``None`` state.
    """
    x_list = self.inputs_to_embeddings(inputs, quant_levels)
    x, m = list_to_tensor(x_list)

    # yes, there's a better way: the presence of any "targ" pair means targets were supplied
    seen_names = {name for entry in inputs for name, _ in entry}
    training = "targ" in seen_names

    # pad input and mask along dim 1 up to a multiple of l_padding; the original lengths are
    # recovered from x_list afterwards
    pad_multiple = self.l_padding
    if pad_multiple and x.shape[1] % pad_multiple != 0:
        pad_shape = list(x.shape)
        pad_shape[1] = pad_multiple - pad_shape[1] % pad_multiple
        x = torch.cat([x, x.new_zeros(pad_shape)], dim=1)
        # the mask carries a single trailing channel
        pad_shape[2] = 1
        m = torch.cat([m, x.new_zeros(pad_shape)], dim=1)

    x, state, aux_loss = self._forward(
        inputs=x,
        mask=m,
        state=state,
    )

    # Remove padding
    lengths = [len(piece) for piece in x_list]
    logits = [row[:length] for row, length in zip(x, lengths)]

    # compute loss if the target is given
    if training:
        self.calc_loss(inputs=inputs, logits=logits, quant_levels=quant_levels)
        # include any additional losses (for example: MoE router)
        if aux_loss is not None:
            self.loss["aux_loss"] = aux_loss

    if state is None:
        return logits
    return logits, state
2023-09-13 02:28:07 +00:00
def sample(
    self,
    logits: list[Tensor],
    resps_list: list[Tensor],
    quant_levels: Tensor | None = None,
    temperature: float = 1.0,
    min_temperature: float = -1.0,
    top_k: int = -100,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    repetition_penalty_decay: float = 0.0,
    length_penalty: float = 0.0,

    beam_width: int = 0,

    mirostat: list[dict] | None = None,
):
    """
    Turn per-entry logits into sampled tokens.

    Args:
        logits: one tensor per batch entry, indexed by sequence position.
        resps_list: responses generated so far, one per entry; ``resps[:, -1]`` is passed to
            the repetition penalizer, and ``len(resps)`` selects NAR positions / AR length.
        quant_levels: when given (NAR), the last ``len(resps)`` positions of each entry are
            sampled; otherwise (AR) only the trailing position, or the trailing
            ``self.recurrent_chunk_size`` positions when chunkwise recurrent decoding is on.
        temperature: logits are divided by this unless dynamic temperature is active.
        min_temperature: negative means "same as ``temperature``"; when it differs from
            ``temperature`` by at least 0.001, ``dynamic_temperature`` is used instead.
        top_k, top_p: filtering is applied when ``top_k > 0`` or ``top_p < 1.0``.
        repetition_penalty, repetition_penalty_decay: forwarded to ``reptition_penalize``.
        length_penalty: forwarded to ``length_penalize`` (causal AR only).
        beam_width: when > 1, a naive beam search returns ``(tokens, scores)``.
        mirostat: per-entry mirostat states; when given, ``mirostat_sample`` results are
            returned directly (incompatible with beam search).

    Returns:
        One sampled token tensor per entry; or ``(tokens, scores)`` for beam search; or the
        per-entry ``mirostat_sample`` results. Logits are moved to the CPU (float16 upcast to
        float32) before any sampler runs.
    """
    if min_temperature < 0:
        min_temperature = temperature

    # (NAR) return the entire generated response
    # Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
    if quant_levels is not None:
        logits = [logit[-l:] for logit, l in zip(logits, map(len, resps_list))]
    # (AR chunkwise) return the last chunkwise piece
    # recurrent_chunk_size is a single int shared by every entry (it was previously zipped
    # against the logits, which raised TypeError)
    elif self.causal and self.recurrent_chunk_size > 0:
        chunk_size = self.recurrent_chunk_size
        logits = [logit[-chunk_size:] for logit in logits]
    # (AR) return just the last code
    # Recurrent decoding relies on the last token in the logits, because each token predicts the next token in the sequence (obviously)
    else:
        logits = [logit[-1:] for logit in logits]

    # sampling runs on the CPU; float16 is upcast for numerical headroom
    logits = [logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits]

    # perform repetition penalizing
    logits = [reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip(logits, resps_list)]

    # (AR) perform length penalizing
    if quant_levels is None and self.causal:
        logits = [length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip(logits, map(len, resps_list))]

    # perform top_k/top_p filtering of our logits
    if top_k > 0 or top_p < 1.0:
        logits = [top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits]

    # trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
    # epsilon float comparison because I don't trust Python
    if abs(temperature - min_temperature) >= 0.001:
        logits = [dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits]
    else:
        logits = [logit / temperature for logit in logits]

    # do mirostat sampling
    # currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
    if mirostat is not None:
        # mirostat sampling
        return [mirostat_sample(logit, state=state) for logit, state in zip(logits, mirostat)]

    # do beam search (naive implementation)
    # picks the top-k across all batches, and re-batches those resultant tokens
    # returns the logit scores as well to be P-concatted with the previous scores
    # to-do: not naively implement beam searching
    if beam_width > 1:
        candidates = top_k_logits_list(logits, beam_width)
        res = [torch.tensor(token, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates]
        scores = [logits[batch].flatten()[token] for batch, token in candidates]
        return res, scores

    # and sample
    return [Categorical(logits=logit).sample() for logit in logits]