import math
import torch
import torch.nn.functional as F
import traceback
import numpy as np
import re

from typing import Literal, overload, Optional, Tuple
from functools import partial

from einops import rearrange
from torch import Tensor, einsum, nn
from torch.nn import Embedding
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision

from ..utils import wrapper as ml
from ..samplers import reptition_penalize, length_penalize, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample

try:
    from .transformer import SinusoidalEmbedding, Block as TransformerBlock
except Exception as e:
    print("Error importing `transformer` arch:", e)
    pass

try:
    #from .retnet import RetNetDecoder, RetNetConfig
    from .retnet_ts import RetNetDecoder, RetNetConfig
except Exception as e:
    print("Error importing `retnet` arch:", e)
    pass

from .retnet_hf import RetNetDecoder as RetNetDecoder_HF, RetNetConfig as RetNetConfig_HF
"""
try:
except Exception as e:
    print("Error importing `retnet-hf` arch:", e)
    pass
"""

try:
    from transformers import LlamaModel, LlamaConfig
except Exception as e:
    print("Error importing `llama` arch:", e)
    pass

try:
    from transformers import MistralModel, MistralConfig
except Exception as e:
    print("Error importing `mistral` arch:", e)
    pass

try:
    from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm

    # re-enable logging, since importing zetascale clobbers the global logging config
    import logging
    logging.getLogger().setLevel(logging.DEBUG)

    # override the forward pass so it can be wrapped with gradient checkpointing
    def BitNetTransformerBlock_forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        skip = x
        for attn, ffn in zip(self.layers, self.ffn_layers):
            if x.requires_grad and self.gradient_checkpointing:
                x, _ = checkpoint(attn, x, x, x, is_causal=True, *args, **kwargs, use_reentrant=False)
            else:
                x, _ = attn(x, x, x, is_causal=True, *args, **kwargs)
            x = x + skip
            x = ffn(x) + x
        return x

    BitNetTransformerBlock.forward = BitNetTransformerBlock_forward

    # override, because bitnet's BitNetTransformer bundles an embedding input / classifier output layer inside of it, which isn't favorable here
    class BitNetTransformer(nn.Module):
        def __init__(
            self,
            dim: int,
            depth: int,
            num_tokens: int,
            heads=8,
            ff_mult=4,
            gradient_checkpointing=True
        ):
            super().__init__()

            self.transformer = BitNetTransformerBlock(dim=dim, depth=depth, heads=heads, ff_mult=ff_mult)
            self.norm = BitNetRMSNorm(dim)
            self.transformer.gradient_checkpointing = gradient_checkpointing

        def forward(self, x):
            x = self.transformer(x)
            return self.norm(x)

    """
    from bitnet import BitNetTransformer

    def NoEmbedding_BitNetTransformer_Forward(self, x):
        x = self.transformer(x)
        return self.to_logits[0](x)

    BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
    """
except Exception as e:
    print("Error importing `bitnet` arch:", e)
    pass

try:
    from transformers import MixtralModel, MixtralConfig
    from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock

    # This override is required because batch sizes > 1 throw errors with the stock implementation
    def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_dim)  # was view()
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # one-hot mask of which experts each token was routed to
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                continue

            top_x_list = top_x.tolist()
            idx_list = idx.tolist()

            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

    Original_MixtralSparseMoeBlock_forward = MixtralSparseMoeBlock.forward
    MixtralSparseMoeBlock.forward = Fixed_MixtralSparseMoeBlock_forward
except Exception as e:
    print("Error importing `mixtral` arch:", e)

AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]

try:
    from xformers.ops import LowerTriangularMask
    from xformers.ops.fmha import memory_efficient_attention

    AVAILABLE_ATTENTIONS.append("xformers")
except Exception as e:
    print("Error while importing `xformers`:", e)

try:
    from transformers.utils import is_flash_attn_2_available

    if is_flash_attn_2_available():
        AVAILABLE_ATTENTIONS.append("flash")
except Exception as e:
    raise e

try:
    from transformers.cache_utils import Cache
    from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

    class Llama_Attention(LlamaAttention):
        def __init__(self, *args, **kwargs):
            if 'mode' in kwargs:
                self.mode = kwargs['mode']
                kwargs.pop("mode")
            else:
                self.mode = "math"

            super().__init__(*args, **kwargs)

        def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Cache] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
            cache_position: Optional[torch.LongTensor] = None,
            **kwargs,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
            bsz, q_len, _ = hidden_states.size()

            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

            past_key_value = getattr(self, "past_key_value", past_key_value)

            if past_key_value is not None:
                # sin and cos are specific to RoPE models; cache_position is needed for the static cache
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)

            dropout_rate = self.attention_dropout if self.training else 0.0

            if self.mode == "xformers":
                if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
                    attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None, p=dropout_rate)
                else:
                    attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask(), p=dropout_rate)
            else:
                #torch.nn.attention.sdpa_kernel
                with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
                    attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=dropout_rate)

            attn_weights = None

            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
            attn_output = self.o_proj(attn_output)

            return attn_output, attn_weights, past_key_value
except Exception as e:
    print("Error creating modified `LlamaAttention`:", e)

def _create_mask(l, device):
    """1 is valid region and 0 is invalid."""
    seq = torch.arange(max(l), device=device).unsqueeze(0)  # (1 t)
    stop = torch.tensor(l, device=device).unsqueeze(1)  # (b 1)
    return (seq < stop).float()  # (b t)
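
# For illustration (hypothetical values): _create_mask([2, 3], "cpu") would yield
#   tensor([[1., 1., 0.],
#           [1., 1., 1.]])
# i.e. one row per sequence, with padded positions zeroed out.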

def _join(x: tuple[Tensor], sep: Tensor):
    """
    Args:
        x: (k t d)
        sep: (d)
    """
    ret = x[0]
    for i in range(1, len(x)):
        ret = torch.cat((ret, sep[None], x[i]), dim=0)
    return ret
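
# For illustration (hypothetical shapes): given segments of shape (t1, d) and (t2, d) and a
# separator of shape (d,), _join returns one (t1 + 1 + t2, d) tensor with the separator
# embedding spliced between the segments.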

def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
    """
    Args:
        x_list: [(t d)]
    Returns:
        x: (b t d) with the default pattern, padded and stacked
        m: (b t 1) with the default pattern, marking valid (non-padded) positions, same device/dtype as x
    """
    l = list(map(len, x_list))
    x = rearrange(pad_sequence(x_list), pattern)
    m = _create_mask(l, x_list[0].device)
    m = m.t().unsqueeze(-1)  # (t b 1)
    m = rearrange(m, pattern)
    m = m.to(x)
    return x, m
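
# For illustration (hypothetical shapes): with the default pattern, two sequences of shape
# (t1, d) and (t2, d) become x of shape (2, max(t1, t2), d), while m of shape
# (2, max(t1, t2), 1) is 1 over real positions and 0 over padding.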

# automagically parses a batch-list and returns it as a list
"""
class Embedding(nn.Embedding):
    def forward(self, x_list: list[Tensor]) -> list[Tensor]:
        if len(x_list) == 0:
            return []
        return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
"""

# Deprecated implementation
class MultiEmbedding(nn.Module):
    def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
        super().__init__()
        self.monolithic = monolithic
        self.max_n_levels = max_n_levels
        self.n_tokens = n_tokens
        self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))

    # to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
    # I imagine this is an oversight in the NAR.
    def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
        if len(x_list) == 0:
            return []

        # this "strategy" will reserve weight[0] for the AR and weight[1:] for the NAR
        # the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
        if self.monolithic:
            w = self.weight[:1] if quant_levels is None else self.weight[1:]
        else:
            w = self.weight

        padded_x_list = []

        for i, xi in enumerate(x_list):
            xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens)  # t l' k
            wi = w.shape[0] - xi.shape[1]
            xi = F.pad(xi, (0, 0, 0, wi))  # t l k
            padded_x_list.append(xi.to(w))

        x = torch.cat(padded_x_list)  # n l k
        x = einsum("l k d, n l k -> n d", w, x)
        x_list = x.split([*map(len, x_list)])

        return x_list
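
    # For illustration (hypothetical sizing): in monolithic mode the AR embeds its single RVQ
    # level with weight[:1], while the NAR embeds its levels with weight[1:], so the two tasks
    # never share an embedding table for level 0.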

# Embedding that sums each RVQ-bin level within a given input acoustic prompt
class AudioEmbedding(nn.Module):
    def __init__(
        self,
        l_tokens: list[int],  # number of tokens per level (needed because the AR resps level includes a stop token)
        token_dim: int,  # dimensionality of the embedding
        levels: int | None = None,  # number of RVQ-bins (if provided, enables the per-level weights below)
        sums: bool = True  # whether to sum all previous levels of embeddings to factor in other RVQ bin levels (it is unclear which way is better)
    ):
        super().__init__()

        # array of embeddings
        #   proms are [0, prom_levels]
        #   resps are split so that [0] is for the AR and [1:] are reserved for the NAR
        self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
        # per-level scaling weights (arguably redundant, since the embedding weights themselves can absorb this)
        self.weight = nn.ParameterList([nn.Parameter(torch.Tensor([1])) for i in range(levels)]) if levels is not None else None
        #
        self.sums = sums

    def forward(self, xi: Tensor, quant_levels: Tensor | None = None) -> Tensor:
        # prom
        if quant_levels is None and xi.shape[-1] > 1:
            if self.sums:
                x = sum([self.embeddings[k](xi[:, k]) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1])])
            else:
                k = 0  # only use the most significant RVQ bin level for the input prom
                x = self.embeddings[k](xi[:, k]) * (self.weight[k] if self.weight is not None else 1)
        # AR resp
        elif quant_levels is None or quant_levels == 0:
            x = self.embeddings[0](xi if len(xi.shape) == 1 else xi[:, 0])
        # NAR resp
        else:
            if self.sums:
                x = sum([self.embeddings[k + 1](xi[:, k]) * (self.weight[k + 1] if self.weight is not None else 1) for k in range(xi.shape[-1])])
            else:
                k = xi.shape[-1] - 1  # only use the previous RVQ bin level for the current resp embedding
                x = self.embeddings[k + 1](xi[:, k]) * (self.weight[k + 1] if self.weight is not None else 1)

        return x
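
    # A minimal usage sketch (hypothetical sizes: 1024 codes, 8 RVQ levels, d_model=512):
    #   prom_emb = AudioEmbedding([1024] * 8, 512)
    #   resp_emb = AudioEmbedding([1025] + [1024] * 7, 512)
    #   x = prom_emb(prom)        # prom: (t, 8) -> (t, 512), all levels summed
    #   y = resp_emb(resp, 0)     # resp: (t,)   -> AR path, embeddings[0] only
    #   z = resp_emb(resp, 3)     # resp: (t, 3) -> NAR path, embeddings[1:4] summed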

class Base(nn.Module):
    @property
    def causal(self) -> bool:
        raise NotImplementedError

    @property
    def arch_type(self) -> str:
        raise NotImplementedError

    @property
    def norm_type(self):
        raise NotImplementedError

    @property
    def n_prom_levels(self) -> int:
        raise NotImplementedError

    @property
    def n_resp_levels(self) -> int:
        raise NotImplementedError

    @property
    def n_max_levels(self) -> int:
        raise NotImplementedError

    @property
    def n_langs(self) -> int:
        raise NotImplementedError

    @property
    def n_tasks(self) -> int:
        raise NotImplementedError

    @property
    def n_tones(self) -> int:
        raise NotImplementedError

    @property
    def recurrent_chunk_size(self) -> int:
        raise NotImplementedError

    @property
    def rotary_embedding_base(self) -> float:
        return 10000

    @property
    def interleave(self) -> bool:
        return False

    @property
    def monolithic(self) -> bool:
        return False

    @property
    def version(self) -> int:
        return 1

    @property
    def stop_token(self):
        if not self.causal:
            raise ValueError("Not using stop token!")
        return self.n_tokens

    @property
    def ignore_index(self):
        return -100

    def loss_factor(self, k):
        if self.hyper_config is None:
            return 1.0
        return self.hyper_config.loss_factors[k] if k in self.hyper_config.loss_factors else 1.0

    def __init__(
        self,

        n_tokens: int = 1024,
        d_model: int = 512,
        n_heads: int = 8,
        n_layers: int = 12,
        p_dropout: float = 0.1,

        n_experts: int = 1,
        l_padding: int = 0,

        training=True,
        config=None,
    ):
        super().__init__()
        self.training = training
        self.hyper_config = config
        self.gradient_checkpointing = self.hyper_config.gradient_checkpointing if self.hyper_config is not None else True

        self.n_tokens = n_tokens
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.n_experts = n_experts

        self.l_padding = l_padding

        # +1 to include the stop token
        # to-do: undo this mistake; task tokens should be delegated to their own embedding
        n_prom_tokens = n_tokens
        n_resp_tokens = n_tokens + (1 if self.causal else 0)  # the AR requires a stop token to... know when to stop

        self.text_emb = Embedding(n_tokens, d_model)
        self.langs_emb = None
        self.tones_emb = None
        self.tasks_emb = None

        if self.version == 1:  # legacy
            n_prom_tokens += (self.n_tasks - 1)  # old models have the task tokens in the prom
            self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
            self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
        else:
            # [1024] * 8
            self.proms_emb = AudioEmbedding(
                [n_prom_tokens] * self.n_prom_levels, d_model,
                levels=self.n_prom_levels if self.version > 3 else None,
                sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True,
            )
            # [1025] + [1024] * 7
            self.resps_emb = AudioEmbedding(
                [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
                levels=self.n_resp_levels if self.version > 3 else None,
                sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True
            )

        if self.version >= 3:
            self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None
            self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None

        if self.version >= 4:
            self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None

        self.sep = nn.Parameter(torch.randn(d_model))

        # ick, there has to be a better way
        hf_attention = self.hyper_config.attention if self.hyper_config is not None else None

        if self.hyper_config.attention == "auto":
            if "flash" in AVAILABLE_ATTENTIONS:
                self.hyper_config.attention = "flash"
            elif "xformers" in AVAILABLE_ATTENTIONS:
                self.hyper_config.attention = "xformers"
            else:
                self.hyper_config.attention = "mem_efficient"

        if self.hyper_config.attention in ["xformers", "mem_efficient", "math", "flash"]:
            hf_attention = None
            if self.hyper_config.attention not in AVAILABLE_ATTENTIONS:
                raise ValueError(f"Requesting attention `{self.hyper_config.attention}` but it is not available. Currently available: {AVAILABLE_ATTENTIONS}")
if self . arch_type == " transformer " :
self . sin_emb = SinusoidalEmbedding ( d_model )
self . blocks = nn . ModuleList ( [ TransformerBlock (
d_model = d_model ,
n_heads = n_heads ,
2024-02-01 03:48:36 +00:00
p_dropout = p_dropout if training else 0.0 ,
2023-08-19 01:58:07 +00:00
causal = self . causal ,
2023-08-02 21:53:35 +00:00
norm_type = self . norm_type ,
n_levels = self . n_resp_levels ,
) for _ in range ( n_layers ) ] )
2024-05-10 01:28:20 +00:00
elif self . arch_type in [ " mistral " , " mixtral " ] :
2024-02-01 03:48:36 +00:00
if n_experts < = 1 :
self . model = MistralModel ( MistralConfig (
vocab_size = n_resp_tokens ,
hidden_size = d_model ,
max_position_embeddings = 75 * 60 , # max-length of 60 seconds
intermediate_size = d_model * 4 ,
num_hidden_layers = n_layers ,
num_attention_heads = n_heads ,
attention_dropout = p_dropout if training else 0.0 ,
2024-06-04 02:28:49 +00:00
num_key_value_heads = self . hyper_config . kv_heads if self . hyper_config . kv_heads > 0 else n_heads ,
2024-02-01 03:48:36 +00:00
hidden_act = " gelu " ,
is_encoder_decoder = False ,
is_decoder = True ,
2024-05-11 22:14:05 +00:00
attn_implementation = hf_attention ,
2024-06-04 02:28:49 +00:00
#gradient_checkpointing=self.gradient_checkpointing,
2024-02-01 03:48:36 +00:00
) )
else :
self . model = MixtralModel ( MixtralConfig (
vocab_size = n_resp_tokens ,
hidden_size = d_model ,
max_position_embeddings = 75 * 60 , # max-length of 60 seconds
intermediate_size = d_model * 4 ,
num_hidden_layers = n_layers ,
num_attention_heads = n_heads ,
attention_dropout = p_dropout if training else 0.0 ,
2024-06-04 02:28:49 +00:00
num_key_value_heads = self . hyper_config . kv_heads if self . hyper_config . kv_heads > 0 else n_heads ,
2024-02-01 03:48:36 +00:00
sliding_window = 75 * 12 , # 12 second context window
output_router_logits = training ,
hidden_act = " gelu " ,
is_encoder_decoder = False ,
is_decoder = True ,
num_local_experts = n_experts ,
num_experts_per_tok = min ( 2 , n_experts ) ,
2024-05-11 22:14:05 +00:00
attn_implementation = hf_attention ,
2024-06-04 02:28:49 +00:00
#gradient_checkpointing=self.gradient_checkpointing,
2024-02-01 03:48:36 +00:00
) )
2024-05-10 04:25:44 +00:00
2024-06-04 02:28:49 +00:00
if self . gradient_checkpointing and not self . model . gradient_checkpointing :
2024-05-11 21:47:19 +00:00
self . model . gradient_checkpointing_enable ( gradient_checkpointing_kwargs = dict (
use_reentrant = False
) )
2024-05-10 04:25:44 +00:00
2024-05-12 13:22:39 +00:00
#if training:
# self.model.training = True
2023-12-23 01:27:36 +00:00
elif self . arch_type == " llama " :
if n_experts < = 1 :
self . model = LlamaModel ( LlamaConfig (
vocab_size = n_resp_tokens ,
hidden_size = d_model ,
max_position_embeddings = 75 * 60 , # max-length of 60 seconds
intermediate_size = d_model * 4 ,
num_hidden_layers = n_layers ,
num_attention_heads = n_heads ,
2024-02-01 03:48:36 +00:00
attention_dropout = p_dropout if training else 0.0 ,
2023-12-23 01:27:36 +00:00
num_key_value_heads = n_heads ,
2024-02-01 03:48:36 +00:00
sliding_window = 75 * 12 , # 12 second context window
2023-12-23 01:27:36 +00:00
hidden_act = " gelu " ,
is_encoder_decoder = False ,
is_decoder = True ,
2024-05-11 22:14:05 +00:00
attn_implementation = hf_attention ,
2024-06-04 02:28:49 +00:00
#gradient_checkpointing=self.gradient_checkpointing,
2023-12-23 01:27:36 +00:00
) )
else :
self . model = MixtralModel ( MixtralConfig (
vocab_size = n_resp_tokens ,
hidden_size = d_model ,
max_position_embeddings = 75 * 60 , # max-length of 60 seconds
intermediate_size = d_model * 4 ,
num_hidden_layers = n_layers ,
num_attention_heads = n_heads ,
2024-02-01 03:48:36 +00:00
attention_dropout = p_dropout if training else 0.0 ,
2023-12-23 01:27:36 +00:00
num_key_value_heads = n_heads ,
2024-02-01 03:48:36 +00:00
sliding_window = 75 * 12 , # 12 second context window
output_router_logits = training ,
2023-12-23 01:27:36 +00:00
hidden_act = " gelu " ,
is_encoder_decoder = False ,
is_decoder = True ,
num_local_experts = n_experts ,
num_experts_per_tok = min ( 2 , n_experts ) ,
2024-05-11 22:14:05 +00:00
attn_implementation = hf_attention ,
2024-06-04 02:28:49 +00:00
#gradient_checkpointing=self.gradient_checkpointing,
2023-12-23 01:27:36 +00:00
) )
2024-05-10 04:15:52 +00:00
2024-06-04 02:28:49 +00:00
if self . gradient_checkpointing and not self . model . gradient_checkpointing :
2024-05-11 21:47:19 +00:00
self . model . gradient_checkpointing_enable ( gradient_checkpointing_kwargs = dict (
use_reentrant = False
) )
2024-05-10 04:15:52 +00:00
2024-05-12 13:22:39 +00:00
#if training:
# self.model.training = True
2023-08-02 21:53:35 +00:00
elif self . arch_type == " retnet " :
2023-12-26 03:20:32 +00:00
kwargs = dict (
2023-12-23 01:27:36 +00:00
vocab_size = n_resp_tokens ,
2023-08-02 21:53:35 +00:00
decoder_embed_dim = d_model ,
2023-10-05 21:39:46 +00:00
decoder_value_embed_dim = d_model * 2 ,
2023-08-02 21:53:35 +00:00
decoder_retention_heads = n_heads ,
decoder_ffn_embed_dim = d_model * 4 ,
decoder_layers = n_layers ,
2024-02-01 03:48:36 +00:00
dropout = p_dropout if training else 0.0 ,
2024-06-04 02:28:49 +00:00
checkpoint_activations = self . gradient_checkpointing ,
2023-10-05 21:39:46 +00:00
activation_fn = " gelu " ,
2024-05-10 01:28:20 +00:00
use_layernorm = self . version < 3 ,
use_biases = self . version < 3 ,
use_glu = self . version > = 3 ,
2023-08-02 21:53:35 +00:00
2023-09-02 01:58:29 +00:00
chunkwise_recurrent = self . causal and self . recurrent_chunk_size > 0 ,
recurrent_chunkwise_size = self . recurrent_chunk_size if self . causal else 0 ,
2023-08-02 21:53:35 +00:00
no_output_layer = True ,
decoder_normalize_before = True ,
2023-09-21 00:10:59 +00:00
rotary_embedding_base = self . rotary_embedding_base , # 10000
2023-12-26 03:20:32 +00:00
)
if n_experts > 1 :
kwargs . update ( dict (
use_xmoe = True ,
moe_freq = 1 ,
moe_expert_count = n_experts ,
moe_gating_use_fp32 = False ,
) )
2023-12-21 00:45:58 +00:00
2023-12-26 03:20:32 +00:00
self . model = RetNetDecoder ( RetNetConfig ( * * kwargs ) )
2024-04-09 01:14:51 +00:00
elif self . arch_type == " retnet-hf " :
kwargs = dict (
vocab_size = n_resp_tokens ,
decoder_embed_dim = d_model ,
decoder_value_embed_dim = d_model * 2 ,
decoder_retention_heads = n_heads ,
decoder_ffn_embed_dim = d_model * 4 ,
decoder_layers = n_layers ,
dropout = p_dropout if training else 0.0 ,
2024-06-04 02:28:49 +00:00
checkpoint_activations = self . gradient_checkpointing ,
2024-04-09 01:14:51 +00:00
activation_fn = " gelu " ,
use_glu = False , # self.version >= 3,
recurrent_chunk_size = self . recurrent_chunk_size if self . causal else 0 ,
decoder_normalize_before = True ,
deepnorm = False ,
subln = True ,
)
self . model = RetNetDecoder_HF ( RetNetConfig_HF ( * * kwargs ) )
2024-05-12 13:22:39 +00:00
2024-06-04 02:28:49 +00:00
if self . gradient_checkpointing and not self . model . gradient_checkpointing :
2024-05-12 13:22:39 +00:00
self . model . gradient_checkpointing_enable ( gradient_checkpointing_kwargs = dict (
use_reentrant = False
) )
2024-03-01 02:29:17 +00:00
elif self . arch_type == " bitnet " :
self . model = BitNetTransformer (
num_tokens = n_resp_tokens ,
dim = d_model ,
depth = n_layers ,
heads = n_heads ,
ff_mult = 4 ,
2024-06-04 02:28:49 +00:00
gradient_checkpointing = self . gradient_checkpointing ,
2024-03-01 02:29:17 +00:00
)
else :
raise RuntimeError ( f ' Unknown arch specified: { self . arch_type } ' )
2023-08-04 01:26:36 +00:00
2024-06-04 02:28:49 +00:00
if self . hyper_config . attention in [ " xformers " , " auto " , " mem_efficient " , " math " , " flash " ] :
self . model = ml . replace_attention ( self . model , klass = Llama_Attention , target = LlamaAttention , mode = self . hyper_config . attention )
2024-05-10 01:28:20 +00:00
2023-08-02 21:53:35 +00:00
self . classifier = nn . Linear ( d_model , n_resp_tokens )
self . accuracy_metric = MulticlassAccuracy (
n_resp_tokens ,
top_k = 10 ,
average = " micro " ,
multidim_average = " global " ,
ignore_index = self . ignore_index ,
)
self . precision_metric = MulticlassPrecision (
n_resp_tokens ,
top_k = 10 ,
average = " micro " ,
multidim_average = " global " ,
ignore_index = self . ignore_index ,
)

    def _forward(
        self,
        inputs,
        mask=None,
        state=None,
        quant_levels=None,  # needed by the legacy transformer arch for its AdaLN
    ):
        x = inputs
        m = mask.squeeze(-1).int()
        device = x.device
        batch_size = x.shape[0]

        aux_loss = None

        """
        # Broken
        if state is not None and (self.arch_type == "retnet" or self.arch_type == "retnet-hf"):
            # prefill
            if len(state) == 0:
                prefill_size = x.shape[1]
                # run the initial prompt to fill the KV cache
                if self.arch_type == "retnet":
                    for n in range(prefill_size):
                        xi = x[:, n, :].unsqueeze(1)
                        self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
                elif self.arch_type == "retnet-hf":
                    state = None
                    for n in range(prefill_size):
                        xi = x[:, n, :].unsqueeze(1)

                        kwargs = dict(
                            attention_mask=m,
                            inputs_embeds=xi,
                            past_key_values=state,
                            use_cache=True,
                            forward_impl='recurrent',
                            #return_dict=True,
                        )

                        out = self.model(**kwargs)
                        state = out.past_key_values
            # grab last token(s)
            x = x[:, -1, :].unsqueeze(1)
        """

        # HF transformer derived model
        if self.arch_type in ["llama", "mistral", "mixtral"]:
            kwargs = dict(
                attention_mask=m,
                inputs_embeds=x,
                past_key_values=state,
                use_cache=True,
                #return_dict=True,
            )
            # request the router logits during training so the MoE load-balancing loss can be computed
            if self.n_experts > 1 and self.training:
                kwargs["output_router_logits"] = True

            t = self.model(**kwargs)

            x = t[0]

            if state is not None:
                state = t[1]

            if self.n_experts > 1 and self.training:
                router_logits = t[-1]
                aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func(router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok)
        elif self.arch_type == "transformer":
            # ensures we specify a quant_level for the transformer implementation's AdaLN
            l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
            l = l.to(device)
            # inject position information
            x = self.sin_emb.add_pe(x)
            # pass our inputs through the transformer
            for block in self.blocks:
                x = block(x, m, l)
        elif self.arch_type == "retnet":
            # pass our inputs through the RetNet
            x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
            if _ is not None and "l_aux" in _ and self.n_experts > 1:
                aux_loss = torch.sum(torch.stack([t for t in _["l_aux"] if t is not None])) * 0.001
        elif self.arch_type == "retnet-hf":
            first = state is None or len(state) == 0

            kwargs = dict(
                attention_mask=m,
                inputs_embeds=x if first else x[:, -1, :].unsqueeze(1),
                past_key_values=None if first else state,
                use_cache=True,
                forward_impl='parallel' if first else 'recurrent',
                return_dict=True,
            )

            out = self.model(**kwargs)
            x = out.last_hidden_state
            if state is not None:
                state = out.past_key_values
        elif self.arch_type == "bitnet":
            x = self.model(x)

        # output projection layer with masking
        x = self.classifier(x) * mask

        return x, state, aux_loss

    def inputs(
        self,
        text_list: list[Tensor],
        proms_list: list[Tensor],
        resps_list: list[Tensor],
        targ_list: list[Tensor] | None = None,

        lang_list: list[Tensor] | None = None,
        tone_list: list[Tensor] | None = None,
    ):
        device = text_list[0].device
        batch_size = len(text_list)

        inputs = [[] for _ in range(batch_size)]
        for i in range(batch_size):
            if text_list is not None:
                inputs[i].append(("text", text_list[i]))
            if proms_list is not None:
                inputs[i].append(("prom", proms_list[i]))
            if resps_list is not None:
                inputs[i].append(("resp", resps_list[i]))
            if targ_list is not None:
                inputs[i].append(("targ", targ_list[i]))

        return inputs
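
    # For illustration: each batch entry ends up as an ordered list of (name, tensor) pairs, e.g.
    # (hypothetically) [("text", ...), ("prom", ...), ("resp", ...), ("targ", ...)], which
    # inputs_to_embeddings() and calc_loss() below walk in that order.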

    def inputs_to_embeddings(
        self,
        inputs: list,
        quant_levels: Tensor | None = None
    ):
        x_list = []
        for batch_index, batch_input in enumerate(inputs):
            batch = []
            quant_level = quant_levels[batch_index] if quant_levels is not None else None
            for name, input in batch_input:
                embedding = None
                if name == "text":
                    embedding = self.text_emb(input)
                elif name == "lang":
                    embedding = self.langs_emb(input)
                elif name == "prom":
                    embedding = self.proms_emb(input)
                elif name == "tone":
                    embedding = self.tones_emb(input)
                elif name == "resp":
                    embedding = self.resps_emb(input, quant_level)
                else:
                    continue

                batch.append(embedding)

            x_list.append(_join(batch, self.sep))

        return x_list
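
    # For illustration (hypothetical batch entry): the embedded segments are joined with the learned
    # separator, so one sample becomes [text ; sep ; prom ; sep ; resp] along the time axis;
    # calc_loss() relies on this layout when it advances through the logits with `it += seq_len + 1`.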

    def calc_loss(
        self,
        inputs: list,
        logits,
        quant_levels: Tensor | None = None,
    ):
        # old, "naive" way, no loss factoring
        if not self.hyper_config.loss_factors:
            target_list = []
            for batch in inputs:
                target = []
                for name, input in batch:
                    if name == "prom":
                        target.append(torch.full_like(input[..., 0], self.ignore_index))
                    elif name in ["text", "lang", "tone", "targ"]:
                        target.append(input)

                target_list.append(_join(target, torch.tensor(self.ignore_index, device=target[-1].device)))

            # modify only for the AR so it can properly behave like a transformer
            for i in range(len(target_list)):
                if quant_levels is not None and quant_levels[i] > 0:
                    continue

                logits[i] = logits[i][..., :-1, :]  # shift the target so that token n...
                target_list[i] = target_list[i][..., 1:]  # ...predicts token n + 1

            target = torch.cat(target_list)
            inputs = torch.cat(logits)

            self.loss = dict(
                # "nll" was in the original implementation and should actually just be called something else
                nll=F.cross_entropy(inputs, target, ignore_index=self.ignore_index)
            )
            self.stats = dict(
                acc=self.accuracy_metric(inputs, target),
                # precision = self.precision_metric(inputs, target),
            )

            return

        self.loss = dict()
        self.stats = dict(acc=dict())

        info = {}
        for i, batch in enumerate(inputs):
            quant_level = quant_levels[i] if quant_levels is not None else None

            it = 0
            for name, input in batch:
                # do not use resp
                if name == "resp":
                    continue
                # rename to resp
                if name == "targ":
                    name = "resp"
                # select prom level
                elif name == "prom" and quant_level is not None:
                    input = input[:, quant_level]

                seq_len = input.shape[0]

                logit = logits[i][it:it + seq_len]
                it += seq_len + 1  # +1 to incorporate the separator

                # for the AR, shift the sequence so that it predicts the next token
                if quant_level is None or quant_level == 0:
                    logit = logit[..., :-1, :]  # get all but the final logit
                    input = input[..., 1:]  # shift the sequence to the right by one

                if name not in info:
                    info[name] = {
                        "targets": [],
                        "logits": [],
                    }

                info[name]["targets"].append(input.contiguous())
                info[name]["logits"].append(logit.contiguous())

        for name, batch in info.items():
            loss_factor = self.loss_factor(name)
            if loss_factor == 0.0:
                continue

            targets = torch.cat(batch["targets"]).long()
            inputs = torch.cat(batch["logits"])

            self.loss[name] = F.cross_entropy(inputs, targets, ignore_index=self.ignore_index) * loss_factor
            self.stats["acc"][name] = self.accuracy_metric(inputs, targets)

        # to-do: compute loss per individual batch to scale per RVQ level
        """
        rvq_loss_factor = self.loss_factor("quant")
        if isinstance(rvq_loss_factor, list):
            ...
        """

    def forward(
        self,
        inputs: list,
        quant_levels: Tensor | None = None,
        state: dict | list | None = None,
    ):
        x_list = self.inputs_to_embeddings(inputs, quant_levels)
        x, m = list_to_tensor(x_list)

        # yes, there's a better way.
        training = False
        for b_i in range(len(inputs)):
            for i in range(len(inputs[b_i])):
                name, input = inputs[b_i][i]
                if name == "targ":
                    training = True

        device = x.device
        batch_size = len(x_list)

        # pad our input and mask, but retain the original length by doing it after
        if self.l_padding and x.shape[1] % self.l_padding != 0:
            # pad input
            shape = list(x.shape)
            shape[1] = self.l_padding - shape[1] % self.l_padding

            padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
            x = torch.cat([x, padding], dim=1)

            # pad mask
            shape[2] = 1
            padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
            m = torch.cat([m, padding], dim=1)

        x, state, aux_loss = self._forward(
            inputs=x,
            mask=m,
            state=state,
            quant_levels=quant_levels,
        )

        # remove padding
        logits = [hi[:li] for hi, li in zip(x, map(len, x_list))]

        # compute loss if the target is given
        if training:
            self.calc_loss(inputs=inputs, logits=logits, quant_levels=quant_levels)

            # include any additional losses (for example: MoE router loss)
            if aux_loss is not None:
                self.loss["aux_loss"] = aux_loss

        return (logits, state) if state is not None else logits

    def sample(
        self,
        logits: list[Tensor],
        resps_list: list[Tensor],
        quant_levels: Tensor | None = None,

        temperature: float = 1.0,
        min_temperature: float = -1.0,
        top_k: int = -100,
        top_p: float = 1.0,

        repetition_penalty: float = 1.0,
        repetition_penalty_decay: float = 0.0,
        length_penalty: float = 0.0,

        beam_width: int = 0,

        mirostat: list[dict] | None = None,
    ):
        if min_temperature < 0:
            min_temperature = temperature

        # (NAR) return the entire generated response
        if quant_levels is not None:
            logits = [logit[-l:] for logit, l in zip(logits, map(len, resps_list))]
        # (AR chunkwise) return the last chunkwise piece
        elif self.causal and self.recurrent_chunk_size > 0:
            logits = [logit[-self.recurrent_chunk_size:] for logit in logits]
        # (AR) return just the last code
        else:
            logits = [logit[-1:] for logit in logits]

        devices = [logit.device for logit in logits]
        logits = [logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits]

        # perform repetition penalizing
        logits = [reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip(logits, resps_list)]

        # (AR) perform length penalizing
        if quant_levels is None and self.causal:
            logits = [length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip(logits, map(len, resps_list))]

        # perform top_k/top_p filtering of our logits
        if top_k > 0 or top_p < 1.0:
            logits = [top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits]

        # trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
        # epsilon float comparison because I don't trust Python
        if abs(temperature - min_temperature) >= 0.001:
            logits = [dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits]
        else:
            logits = [logit / temperature for logit in logits]

        # do mirostat sampling
        # currently incompatible with beam searching with the way the two are implemented; making them cooperate would take more work
        if mirostat is not None:
            return [mirostat_sample(logit, state=state) for logit, state in zip(logits, mirostat)]

        # do beam search (naive implementation)
        # picks the top-k across all batches, and re-batches those resultant tokens
        # returns the logit scores as well, to be concatenated with the previous scores
        # to-do: not naively implement beam searching
        if beam_width > 1:
            candidates = top_k_logits_list(logits, beam_width)
            res = [torch.tensor(token, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates]
            scores = [logits[batch].flatten()[token] for batch, token in candidates]
            return res, scores

        # and sample
        return [Categorical(logits=logit).sample() for logit in logits]
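
    # A minimal usage sketch (hypothetical values; AR sampling for a single step):
    #   logits = model(inputs)                                   # list of (seq_len, n_tokens) logits
    #   tokens = model.sample(logits, resps_list, temperature=0.95, top_p=0.9)
    #   # `tokens` holds one sampled token id tensor per batch entry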