2024-06-09 01:30:15 +00:00
"""
Core model for handling all VALL - E tasks .
This should handle all the " low " level things such as :
* parsing inputs to sequences
* converting sequences to embeddings
* forward pass
* processing loss and returning logits
Additional functionality (preparing inputs, generating full audio) should be delegated to classes that inherit from the base model
"""
2023-08-02 21:53:35 +00:00
import math
import torch
import torch . nn . functional as F
import traceback
2023-09-13 02:28:07 +00:00
import numpy as np
2024-04-05 00:11:49 +00:00
import re
2023-08-02 21:53:35 +00:00
2024-05-04 18:07:45 +00:00
from typing import Literal , overload , Optional , Tuple
2023-08-02 21:53:35 +00:00
from functools import partial
from einops import rearrange
from torch import Tensor , einsum , nn
2024-04-17 02:04:48 +00:00
from torch . nn import Embedding
2023-08-02 21:53:35 +00:00
from torch . distributions import Categorical
from torch . nn . utils . rnn import pad_sequence
from torch . utils . checkpoint import checkpoint
from torchmetrics . classification import BinaryAccuracy , MulticlassAccuracy , MulticlassPrecision
2024-06-06 01:30:43 +00:00
from . arch import *
2024-05-10 01:28:20 +00:00
from . . utils import wrapper as ml
2024-06-18 03:14:43 +00:00
from . . samplers import reptition_penalize , length_penalize , ban_tokens , top_k_top_p_filtering , dynamic_temperature , top_k_logits_list , mirostat_sample
2023-08-02 21:53:35 +00:00
def _create_mask(l, device):
    """Build a float validity mask from a batch of sequence lengths.

    Args:
        l: sequence lengths, one entry per batch item.
        device: device on which the mask tensors are created.

    Returns:
        Float tensor of shape (b, t) with t = max(l); entry [i, j] is 1.0
        when j < l[i] (valid region) and 0.0 otherwise (invalid region).
    """
    lengths = torch.tensor(l, device=device)
    positions = torch.arange(max(l), device=device)
    # broadcast positions (1 t) against lengths (b 1) -> (b t)
    valid = positions[None, :] < lengths[:, None]
    return valid.float()
def _join(x: tuple[Tensor], sep: Tensor):
    """Concatenate the items of ``x`` along dim 0, placing ``sep`` between neighbours.

    Args:
        x: (k t d)
        sep: (d)

    Returns:
        The first item unchanged when ``x`` holds a single item, otherwise
        the items joined along dim 0 with one separator row between each pair.
    """
    joined = x[0]
    divider = sep.unsqueeze(0)  # (1 d)
    for piece in x[1:]:
        joined = torch.cat((joined, divider, piece), dim=0)
    return joined
def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
    """Pad a list of sequences into one batch tensor and build the matching mask.

    Args:
        x_list: [(t d)]
        pattern: einops pattern applied to both the padded batch and the mask.

    Returns:
        x: the padded batch (``pad_sequence`` output) rearranged by ``pattern``
        m: mask rearranged by the same ``pattern`` and converted to x's
           dtype/device; it marks the valid (non-padding) positions
    """
    lengths = [len(item) for item in x_list]
    padded = pad_sequence(x_list)
    x = rearrange(padded, pattern)

    mask = _create_mask(lengths, x_list[0].device)  # (b t)
    mask = mask.t().unsqueeze(-1)  # (t b 1)
    m = rearrange(mask, pattern).to(x)
    return x, m
2023-09-09 01:30:54 +00:00
# automagically parses a batch-list and returns it as a list
2024-04-17 02:04:48 +00:00
"""
2023-08-02 21:53:35 +00:00
class Embedding ( nn . Embedding ) :
def forward ( self , x_list : list [ Tensor ] ) - > list [ Tensor ] :
if len ( x_list ) == 0 :
return [ ]
return super ( ) . forward ( torch . cat ( x_list ) ) . split ( [ * map ( len , x_list ) ] )
2024-04-17 02:04:48 +00:00
"""
2023-08-02 21:53:35 +00:00
2024-04-29 23:24:05 +00:00
# Deprecated implementation
2023-09-13 18:19:11 +00:00
class MultiEmbedding ( nn . Module ) :
2023-09-08 06:03:24 +00:00
def __init__ ( self , max_n_levels , n_tokens , token_dim , monolithic = False ) :
2023-09-16 05:26:13 +00:00
super ( ) . __init__ ( )
2023-09-08 20:36:26 +00:00
self . monolithic = monolithic
2023-08-02 21:53:35 +00:00
self . max_n_levels = max_n_levels
self . n_tokens = n_tokens
2023-09-08 20:36:26 +00:00
self . weight = nn . Parameter ( torch . randn ( max_n_levels , n_tokens , token_dim ) )
2023-08-02 21:53:35 +00:00
2023-09-07 22:08:38 +00:00
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
# I imagine this is an oversight in the NAR.
2024-06-08 01:46:22 +00:00
def forward ( self , x_list : list [ Tensor ] , quant_level : int | list [ int ] | Tensor | None = None ) - > list [ Tensor ] :
2023-08-02 21:53:35 +00:00
if len ( x_list ) == 0 :
return [ ]
2023-09-08 20:36:26 +00:00
# this "strategy" will reserve the weight[0] for te AR and weight[1:] for the NAR
# the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
2023-09-08 06:03:24 +00:00
if self . monolithic :
2024-06-08 01:46:22 +00:00
w = self . weight [ : 1 ] if quant_level is None or quant_level == 0 else self . weight [ 1 : ]
2023-09-08 06:03:24 +00:00
else :
w = self . weight
2023-08-02 21:53:35 +00:00
2023-09-11 19:13:42 +00:00
padded_x_list = [ ]
2023-08-02 21:53:35 +00:00
2023-09-08 20:36:26 +00:00
for i , xi in enumerate ( x_list ) :
2023-08-02 21:53:35 +00:00
xi = F . one_hot ( xi . to ( torch . int64 ) , num_classes = self . n_tokens ) # t l' k
2023-09-11 19:13:42 +00:00
wi = w . shape [ 0 ] - xi . shape [ 1 ]
xi = F . pad ( xi , ( 0 , 0 , 0 , wi ) ) # t l k
2023-08-02 21:53:35 +00:00
padded_x_list . append ( xi . to ( w ) )
2023-09-11 19:13:42 +00:00
x = torch . cat ( padded_x_list ) # n l k
x = einsum ( " l k d, n l k -> n d " , w , x )
2023-08-02 21:53:35 +00:00
x_list = x . split ( [ * map ( len , x_list ) ] )
2023-09-07 21:48:02 +00:00
2023-09-11 19:13:42 +00:00
return x_list
2023-09-07 21:48:02 +00:00
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
2024-06-09 01:30:15 +00:00
# _Old, to preserve compat with previous models.
2024-06-06 23:52:41 +00:00
class AudioEmbedding_Old ( nn . Module ) :
2024-04-29 23:24:05 +00:00
def __init__ (
self ,
l_tokens : int , # list of number of tokens (needed because AR resps includes stop token)
token_dim : int , # dimensionality of the embedding
levels : int | None = None , # number of RVQ-bins (I don't remember the specifics)
) :
2023-09-07 14:14:03 +00:00
super ( ) . __init__ ( )
2024-04-29 23:24:05 +00:00
# array of embeddings
# proms are [0, prom_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
2023-09-21 00:10:59 +00:00
self . embeddings = nn . ModuleList ( [ nn . Embedding ( n_tokens , token_dim ) for n_tokens in l_tokens ] )
2024-04-29 23:24:05 +00:00
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
2024-06-07 00:41:26 +00:00
self . weight = nn . ParameterList ( [ nn . Parameter ( torch . Tensor ( [ 1 ] ) ) for i in range ( levels ) ] ) if levels is not None else None
2024-06-06 23:52:41 +00:00
2024-06-09 01:30:15 +00:00
def forward ( self , xi : Tensor , quant_level : Tensor | None = None ) - > Tensor :
2024-04-17 02:04:48 +00:00
# prom
2024-06-09 01:30:15 +00:00
if quant_level is None and xi . shape [ - 1 ] > 1 :
2024-06-07 00:41:26 +00:00
x = sum ( [ self . embeddings [ k ] ( xi [ : , k ] ) * ( self . weight [ k ] if self . weight is not None else 1 ) for k in range ( xi . shape [ - 1 ] ) ] )
2024-06-09 01:30:15 +00:00
# prom / AR resp
elif quant_level is None or quant_level == 0 :
2024-06-08 20:42:02 +00:00
x = self . embeddings [ 0 ] ( xi if xi . dim ( ) == 1 else xi [ : , 0 ] )
2024-04-17 02:04:48 +00:00
# NAR resp
else :
2024-06-07 00:41:26 +00:00
x = sum ( [ self . embeddings [ k + 1 ] ( xi [ : , k ] ) * ( self . weight [ k + 1 ] if self . weight is not None else 1 ) for k in range ( xi . shape [ - 1 ] ) ] )
2024-04-17 02:04:48 +00:00
return x
2023-09-07 21:48:02 +00:00
2024-06-09 01:30:15 +00:00
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
# Mostly to handle some oversights and errors during testing
2024-06-06 23:52:41 +00:00
class AudioEmbedding(nn.Module):
    """Per-level audio embedding that can sum the tables of preceding RVQ levels."""

    def __init__(
        self,
        l_tokens: list[int],  # list of number of tokens (needed because AR resps includes stop token)
        token_dim: int,  # dimensionality of the embedding
        sums: bool = True  # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
    ):
        super().__init__()
        # array of embeddings
        #   proms are [0, prom_levels]
        #   resp are split to where [0] is for the AR, and [1:] are reserved for NAR
        #   + resps cannot share the AR and NAR embeddings, since they do encode whether to predict
        #     the same level but in the next token or predict in place but the next level
        tables = [nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens]
        self.embeddings = nn.ModuleList(tables)
        self.sums = sums

    def forward(self, xi: Tensor, offset: int = 0) -> Tensor:
        """Embed one code tensor.

        Args:
            xi: codes, either 1-D (t) or 2-D (t, levels).
            offset: shift added to every table index.

        Returns:
            Tensor of shape (t, token_dim).
        """
        is_flat = xi.dim() == 1
        quant_level = 0 if is_flat else xi.shape[-1] - 1

        # single-table lookup: summing disabled, or only one level present
        if not (self.sums and quant_level > 0):
            codes = xi if is_flat else xi[:, quant_level]
            return self.embeddings[quant_level + offset](codes)

        # accumulate columns 0 .. quant_level - 1 (the last column is not included)
        total = 0
        for k in range(quant_level):
            total = total + self.embeddings[k + offset](xi[:, k])
        return total
2024-06-12 03:28:59 +00:00
# per-level classification
class AudioClassifier(nn.Module):
    """Bank of linear output heads, one per entry of ``l_tokens``."""

    def __init__(
        self,
        l_tokens: list[int],  # list of number of tokens (needed because AR resps includes stop token)
        token_dim: int,  # dimensionality of the embedding
    ):
        super().__init__()
        heads = [nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens]
        self.proj = nn.ModuleList(heads)

    def forward(self, xi: Tensor, levels: list[int]) -> Tensor:
        """Project each batch item with the head chosen by its level.

        Args:
            xi: batch of hidden states, iterated along dim 0.
            levels: head index for each batch item, paired positionally with ``xi``.

        Returns:
            The per-item projections stacked along a new leading dimension.
        """
        logits = [self.proj[level](sample) for sample, level in zip(xi, levels)]
        return torch.stack(logits)
class Metrics(nn.Module):
    """Per-level top-k accuracy and precision metrics.

    One ``MulticlassAccuracy`` and one ``MulticlassPrecision`` is created for
    each entry of ``l_tokens``; an item is scored with the metric selected by
    its quant level.
    """

    def __init__(
        self,
        l_tokens: int | list[int],
        top_k=10,
        average="micro",
        multidim_average="global",
        ignore_index=-100
    ):
        super().__init__()
        # The signature advertises a bare int; treat it as a single level
        # instead of failing when iterating over it.
        if isinstance(l_tokens, int):
            l_tokens = [l_tokens]

        metric_kwargs = dict(
            top_k=top_k,
            average=average,
            multidim_average=multidim_average,
            ignore_index=ignore_index,
        )
        self.accuracy = nn.ModuleList(
            [MulticlassAccuracy(n_tokens, **metric_kwargs) for n_tokens in l_tokens]
        )
        self.precision = nn.ModuleList(
            [MulticlassPrecision(n_tokens, **metric_kwargs) for n_tokens in l_tokens]
        )

    def calc_accuracy(self, inputs, targets, quant_levels):
        """Return the mean of the per-item accuracies, each from its level's metric."""
        scores = [
            self.accuracy[level](logit, target)
            for target, logit, level in zip(targets, inputs, quant_levels)
        ]
        return sum(scores) / len(inputs)

    def calc_precision(self, inputs, targets, quant_levels):
        """Return the mean of the per-item precisions, each from its level's metric."""
        scores = [
            self.precision[level](logit, target)
            for target, logit, level in zip(targets, inputs, quant_levels)
        ]
        return sum(scores) / len(inputs)

    def __call__(self, *args, **kwargs):
        """Return ``dict(acc=...)``; arguments are forwarded to ``calc_accuracy``.

        Only accuracy is reported here; precision is available through
        ``calc_precision``.
        """
        return dict(
            acc=self.calc_accuracy(*args, **kwargs),
        )
2023-08-02 21:53:35 +00:00
class Base ( nn . Module ) :
2024-06-09 01:30:15 +00:00
# to-do: clean up this property mess
2023-08-02 21:53:35 +00:00
@property
def causal(self) -> bool:
    """Abstract: subclasses must override; the base implementation always raises.

    Read by ``stop_token`` and forwarded to the backbone constructors in ``__init__``.
    """
    raise NotImplementedError
@property
def arch_type(self) -> str:
    """Abstract: subclasses must override; the base implementation always raises.

    ``__init__`` compares this string against the supported backbone names.
    """
    raise NotImplementedError
@property
def norm_type(self):
    """Abstract: subclasses must override; the base implementation always raises.

    Forwarded as ``norm_type`` to ``TransformerBlock`` in ``__init__``.
    """
    raise NotImplementedError
@property
def n_prom_levels(self) -> int:
    """Abstract: subclasses must override; the base implementation always raises.

    Used in ``__init__`` as the number of tables in the prompt embedding.
    """
    raise NotImplementedError
2023-08-19 20:06:33 +00:00
2023-09-06 23:58:35 +00:00
@property
def n_resp_levels(self) -> int:
    """Abstract: subclasses must override; the base implementation always raises.

    Used in ``__init__`` to size the response embeddings and classifier heads.
    """
    raise NotImplementedError
2023-08-19 20:06:33 +00:00
@property
def n_max_levels(self) -> int:
    """Abstract: subclasses must override; the base implementation always raises."""
    raise NotImplementedError
2023-08-19 01:58:07 +00:00
2023-09-21 00:10:59 +00:00
@property
def n_langs(self) -> int:
    """Abstract: subclasses must override; the base implementation always raises.

    ``__init__`` creates a language embedding of this size when it is positive
    (model version 3 and up).
    """
    raise NotImplementedError
2024-04-16 00:54:32 +00:00
2023-08-19 01:58:07 +00:00
@property
def n_tasks(self) -> int:
    """Abstract: subclasses must override; the base implementation always raises.

    ``__init__`` uses it to widen the version-1 prompt vocabulary and to size
    the task embedding (model version 3 and up).
    """
    raise NotImplementedError
2023-08-02 21:53:35 +00:00
2024-04-16 00:54:32 +00:00
@property
def n_tones(self) -> int:
    """Abstract: subclasses must override; the base implementation always raises.

    ``__init__`` creates a tone embedding of this size when it is positive
    (model version 4 and up).
    """
    raise NotImplementedError
2023-09-02 01:58:29 +00:00
@property
def causal_size(self) -> int:
    """Abstract: subclasses must override; the base implementation always raises.

    ``__init__`` adds it to the audio vocabulary size for the first response
    level and passes it to the RetNet backbones as the recurrent chunk size.
    """
    raise NotImplementedError
2023-09-21 00:10:59 +00:00
2023-09-04 03:46:08 +00:00
@property
def interleave(self) -> bool:
    """Default flag, ``False`` unless a subclass overrides it."""
    return False
2023-09-02 01:58:29 +00:00
2023-09-07 00:33:39 +00:00
@property
def monolithic(self) -> bool:
    """Default flag, ``False`` unless a subclass overrides it.

    Forwarded to the version-1 response ``MultiEmbedding`` in ``__init__``.
    """
    return False
2023-09-07 14:14:03 +00:00
@property
def version(self) -> int:
    """Model format version, 1 unless a subclass overrides it.

    ``__init__`` branches on it to choose the embedding implementation and
    which optional embeddings are created.
    """
    return 1
2023-09-07 14:14:03 +00:00
2024-06-08 20:42:02 +00:00
@property
def capabilities(self) -> list[str]:
    """Abstract: subclasses must override; the base implementation always raises.

    The only value tested in this class is ``"len"``.
    """
    raise NotImplementedError
2023-09-06 23:58:35 +00:00
@property
def stop_token(self):
    """Return the id used as the stop token.

    Returns:
        0 when ``"len"`` is among the capabilities, otherwise
        ``n_audio_tokens`` for a causal model.

    Raises:
        ValueError: the model is not causal and lacks the ``"len"`` capability.
    """
    if "len" in self.capabilities:
        return 0
    if self.causal:
        return self.n_audio_tokens
    raise ValueError("Not using stop token!")
2023-09-06 23:58:35 +00:00
@property
def ignore_index(self):
    """Target value passed to the metrics as ``ignore_index`` (always -100)."""
    return -100
2024-05-19 16:23:56 +00:00
def loss_factor(self, k):
    """Look up the loss multiplier configured under key ``k``.

    Returns:
        ``config.loss_factors[k]`` when a config is present and contains
        ``k``; 1.0 otherwise.
    """
    cfg = self.config
    if cfg is not None and k in cfg.loss_factors:
        return cfg.loss_factors[k]
    return 1.0
2024-05-19 16:23:56 +00:00
2023-08-02 21:53:35 +00:00
def __init__(
    self,
    n_text_tokens: int = 256,
    n_audio_tokens: int = 1024,
    d_model: int = 512,
    n_heads: int = 8,
    n_layers: int = 12,
    p_dropout: float = 0.1,

    n_experts: int = 1,
    l_padding: int = 0,

    training=True,
    config=None,
):
    """Build the embeddings, the backbone and the output heads.

    Args:
        n_text_tokens: size of the text vocabulary.
        n_audio_tokens: size of the audio-code vocabulary.
        d_model: model width; also the embedding dimensionality.
        n_heads: number of attention / retention heads.
        n_layers: number of backbone layers.
        p_dropout: dropout rate; replaced by 0.0 when ``training`` is false.
        n_experts: values above 1 select the mixture-of-experts variants.
        l_padding: stored on the instance; not otherwise used here.
        training: initial value of ``self.training``.
        config: optional model config. When ``None``, gradient checkpointing,
            summed audio embeddings and split classifiers default to on, no
            attention backend is selected or patched in, and ``n_heads`` is
            used for the key/value head count.

    Raises:
        ValueError: a non-HF attention backend was requested but is not
            listed in ``AVAILABLE_ATTENTIONS``.
        RuntimeError: ``arch_type`` names an unknown architecture.
        Exception: whatever ``ERROR_ARCHES`` holds for an unavailable arch.
    """
    super().__init__()
    self.training = training
    self.config = config
    self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True

    self.n_text_tokens = n_text_tokens
    self.n_audio_tokens = n_audio_tokens

    self.d_model = d_model
    self.n_heads = n_heads
    self.n_layers = n_layers
    self.n_experts = n_experts

    self.l_padding = l_padding

    n_prom_tokens = n_audio_tokens

    # check if requested arch is unavailable
    if self.arch_type in ERROR_ARCHES:
        raise ERROR_ARCHES[self.arch_type]

    if "len" not in self.capabilities:
        # original note: "+1 to include the stop token"; the first level gets
        # `causal_size` extra entries, every later level gets one fewer than the first
        n_resp_tokens = n_audio_tokens + self.causal_size
        l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
    else:
        n_resp_tokens = n_audio_tokens
        l_tokens = [n_resp_tokens] * self.n_resp_levels

    audio_embedding_sums = self.config.audio_embedding_sums if self.config is not None else True
    split_classifiers = self.config.split_classifiers if self.config is not None else True

    self.text_emb = Embedding(n_text_tokens, d_model)
    self.langs_emb = None
    self.tones_emb = None
    self.tasks_emb = None
    self.rvq_l_emb = None
    self.len_emb = None

    if self.version == 1:  # legacy
        n_prom_tokens += (self.n_tasks - 1)  # old models have the task tokens in the prom
        self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
        self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
    elif self.version < 5:
        # [1024] * 8
        self.proms_emb = AudioEmbedding_Old(
            [n_prom_tokens] * self.n_prom_levels, d_model,
            levels=self.n_prom_levels if self.version > 3 else None,
        )
        # [1024 + STOP] + [1024] * 8
        self.resps_emb = AudioEmbedding_Old(
            l_tokens, d_model,
            levels=self.n_resp_levels if self.version > 3 else None,
        )
    else:
        self.proms_emb = AudioEmbedding(
            [n_prom_tokens] * self.n_prom_levels, d_model,
            sums=audio_embedding_sums,
        )
        self.resps_emb = AudioEmbedding(
            l_tokens, d_model,
            sums=audio_embedding_sums,
        )

    # useless since I actually removed using these with the input processing overhaul...
    if self.version >= 3:
        self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None
        self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None
    # never actually got added... I kept forgetting to classify all my audio for speaker's tone
    if self.version >= 4:
        self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None

    # mamba requires this if a model does both AR and NAR tasks
    # this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
    # this ***might*** let me also unify the proms_emb and resps_embedding
    if self.version >= 5:
        self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)

        # experimental NAR-only mode
        # NOTE(review): assumed to sit under the version check — confirm against the original indentation
        self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None

    # this would be nicer to be a stop token or live inside an embedding
    self.sep = nn.Parameter(torch.randn(d_model))

    # Resolve the attention backend. `config` defaults to None, so never
    # dereference it unguarded: without a config no backend is selected.
    attention = self.config.attention if self.config is not None else None
    hf_attention = attention

    if attention == "auto":
        if "flash" in AVAILABLE_ATTENTIONS:
            attention = "flash"
        elif "xformers" in AVAILABLE_ATTENTIONS:
            attention = "xformers"
        else:
            attention = "mem_efficient"
        # keep the existing side effect: the resolved backend is written back to the config
        self.config.attention = attention

    if attention in ["xformers", "mem_efficient", "math", "flash"]:
        # these backends are patched in after construction, not handed to HF
        hf_attention = None
        # NOTE(review): availability check assumed to apply only to these backends — confirm
        if attention not in AVAILABLE_ATTENTIONS:
            raise ValueError(f"Requesting attention `{attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")

    dropout = p_dropout if training else 0.0
    # NOTE(review): the original comment calls this "max-length of 60 seconds";
    # at the 75-per-second rate implied by the sliding-window comment it is 5 * 60 seconds
    max_positions = 75 * 60 * 5
    sliding_window = 75 * 12  # 12 second context window

    def enable_checkpointing(model):
        # turn on HF-style gradient checkpointing unless the model already has it enabled
        if self.gradient_checkpointing and not model.gradient_checkpointing:
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
                use_reentrant=False
            ))

    if self.arch_type == "transformer":
        self.sin_emb = SinusoidalEmbedding(d_model)
        self.blocks = nn.ModuleList([TransformerBlock(
            d_model=d_model,
            n_heads=n_heads,
            p_dropout=dropout,
            causal=self.causal,
            norm_type=self.norm_type,
            n_levels=self.n_resp_levels,
        ) for _ in range(n_layers)])
    elif self.arch_type in ["mistral", "mixtral"]:
        # fall back to n_heads when no config (or no positive kv_heads) is given
        kv_heads = self.config.kv_heads if self.config is not None else 0
        n_kv_heads = kv_heads if kv_heads > 0 else n_heads

        if n_experts <= 1:
            self.model = MistralModel(MistralConfig(
                vocab_size=n_resp_tokens,
                hidden_size=d_model,
                max_position_embeddings=max_positions,
                intermediate_size=d_model * 4,
                num_hidden_layers=n_layers,
                num_attention_heads=n_heads,
                attention_dropout=dropout,
                num_key_value_heads=n_kv_heads,
                hidden_act="gelu",
                is_encoder_decoder=False,
                is_decoder=True,
                attn_implementation=hf_attention,
            ))
        else:
            self.model = MixtralModel(MixtralConfig(
                vocab_size=n_resp_tokens,
                hidden_size=d_model,
                max_position_embeddings=max_positions,
                intermediate_size=d_model * 4,
                num_hidden_layers=n_layers,
                num_attention_heads=n_heads,
                attention_dropout=dropout,
                num_key_value_heads=n_kv_heads,
                sliding_window=sliding_window,
                output_router_logits=training,
                hidden_act="gelu",
                is_encoder_decoder=False,
                is_decoder=True,
                num_local_experts=n_experts,
                num_experts_per_tok=min(2, n_experts),
                attn_implementation=hf_attention,
            ))

        enable_checkpointing(self.model)
    elif self.arch_type == "llama":
        if n_experts <= 1:
            self.model = LlamaModel(LlamaConfig(
                vocab_size=n_resp_tokens,
                hidden_size=d_model,
                max_position_embeddings=max_positions,
                intermediate_size=d_model * 4,
                num_hidden_layers=n_layers,
                num_attention_heads=n_heads,
                attention_dropout=dropout,
                num_key_value_heads=n_heads,
                sliding_window=sliding_window,
                hidden_act="gelu",
                is_encoder_decoder=False,
                is_decoder=True,
                attn_implementation=hf_attention,
            ))
        else:
            self.model = MixtralModel(MixtralConfig(
                vocab_size=n_resp_tokens,
                hidden_size=d_model,
                max_position_embeddings=max_positions,
                intermediate_size=d_model * 4,
                num_hidden_layers=n_layers,
                num_attention_heads=n_heads,
                attention_dropout=dropout,
                num_key_value_heads=n_heads,
                sliding_window=sliding_window,
                output_router_logits=training,
                hidden_act="gelu",
                is_encoder_decoder=False,
                is_decoder=True,
                num_local_experts=n_experts,
                num_experts_per_tok=min(2, n_experts),
                attn_implementation=hf_attention,
            ))

        enable_checkpointing(self.model)
    elif self.arch_type == "retnet":
        kwargs = dict(
            vocab_size=n_resp_tokens,
            decoder_embed_dim=d_model,
            decoder_value_embed_dim=d_model * 2,
            decoder_retention_heads=n_heads,
            decoder_ffn_embed_dim=d_model * 4,
            decoder_layers=n_layers,
            dropout=dropout,
            checkpoint_activations=self.gradient_checkpointing,
            activation_fn="gelu",
            use_layernorm=self.version < 3,
            use_biases=self.version < 3,
            use_glu=self.version >= 3,

            chunkwise_recurrent=self.causal and self.causal_size > 0,
            recurrent_chunkwise_size=self.causal_size if self.causal else 0,
            no_output_layer=True,
            decoder_normalize_before=True,

            rotary_embedding_base=10000
        )

        if n_experts > 1:
            kwargs.update(dict(
                use_xmoe=True,
                moe_freq=1,
                moe_expert_count=n_experts,
                moe_gating_use_fp32=False,
            ))

        self.model = RetNetDecoder(RetNetConfig(**kwargs))
    elif self.arch_type == "retnet-hf":
        kwargs = dict(
            vocab_size=n_resp_tokens,
            decoder_embed_dim=d_model,
            decoder_value_embed_dim=d_model * 2,
            decoder_retention_heads=n_heads,
            decoder_ffn_embed_dim=d_model * 4,
            decoder_layers=n_layers,
            dropout=dropout,
            checkpoint_activations=self.gradient_checkpointing,
            activation_fn="gelu",
            use_glu=False,  # self.version >= 3,

            recurrent_chunk_size=self.causal_size if self.causal else 0,
            decoder_normalize_before=True,

            deepnorm=False,
            subln=True,
        )

        self.model = RetNetDecoder_HF(RetNetConfig_HF(**kwargs))

        enable_checkpointing(self.model)
    elif self.arch_type == "bitnet":
        self.model = BitNetTransformer(
            num_tokens=n_resp_tokens,
            dim=d_model,
            depth=n_layers,
            heads=n_heads,
            ff_mult=4,
            gradient_checkpointing=self.gradient_checkpointing,
        )
    elif self.arch_type in ["mamba", "mamba2"]:
        self.model = MambaMixelModel(
            vocab_size=n_resp_tokens,
            d_model=d_model,
            n_layer=n_layers,
            d_intermediate=d_model * 4,
            ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": False} if self.arch_type == "mamba2" else {},
            rms_norm=True,
            fused_add_norm=True,
            residual_in_fp32=False,
        )

        self.model.gradient_checkpointing = self.gradient_checkpointing
    elif self.arch_type in ["mamba2-hf"]:
        self.model = Mamba2Model_HF(Mamba2Config_HF(
            vocab_size=n_resp_tokens,
            hidden_size=d_model,
            max_position_embeddings=max_positions,
            expand=4,
            num_hidden_layers=n_layers,
            is_encoder_decoder=False,
            is_decoder=True,
            use_triton_kernels=False,  # the entire reason is to NOT use triton (because V100s hate it)
            residual_in_fp32=False,  # breaks for AMP inference
        ))

        enable_checkpointing(self.model)
    elif self.arch_type == "mmfreelm":
        self.model = HGRNBitModel(HGRNBitConfig(
            vocab_size=n_resp_tokens,
            hidden_size=d_model,
            max_position_embeddings=max_positions,
            intermediate_size=d_model * 4,
            num_hidden_layers=n_layers,
            num_heads=n_heads,
            attn_mode=hf_attention,
        ))

        enable_checkpointing(self.model)
    else:
        raise RuntimeError(f'Unknown arch specified: {self.arch_type}')

    # The "transformer" arch registers `self.blocks` and no `self.model`;
    # look the attribute up defensively so that arch does not fail here.
    model = getattr(self, "model", None)
    if model is not None:
        if hasattr(model, "embeddings"):
            del model.embeddings

        if attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
            self.model = ml.replace_attention(model, klass=LlamaAttention, target=LlamaAttention_Base, mode=attention)

    if not split_classifiers:
        # one shared head and one shared pair of metrics for every level
        self.classifier = nn.Linear(d_model, n_resp_tokens)
        self.classifiers = None

        self.accuracy_metric = MulticlassAccuracy(
            n_resp_tokens,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=self.ignore_index,
        )

        self.precision_metric = MulticlassPrecision(
            n_resp_tokens,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=self.ignore_index,
        )

        self.metrics = None
    else:
        # one head and one set of metrics per response level
        self.classifier = None
        self.classifiers = AudioClassifier(l_tokens, d_model)
        self.accuracy_metric = None
        self.precision_metric = None
        self.metrics = Metrics(l_tokens)
2023-08-02 21:53:35 +00:00
2024-04-16 00:54:32 +00:00
def _forward(
    self,
    inputs,
    mask=None,
    state=None,
    quant_levels=None,
):
    """Run the backbone selected by `self.arch_type` over already-embedded inputs.

    Args:
        inputs: embedded input batch. It is indexed as `x[:, -1, :]` in the
            `retnet-hf` branch, so it is treated as (batch, time, dim).
        mask: validity mask that is squeezed on its last axis for the backbones
            and multiplied against the classifier output. When omitted, an
            all-ones mask is built.
        state: optional cache / incremental state handed to the backbone. It is
            only refreshed from the backbone output when it was not None.
        quant_levels: optional per-sample RVQ levels, consumed only by the
            `transformer` arch. Defaults to level 0 for every sample.

    Returns:
        A `(x, state, aux_loss)` tuple. `x` is the backbone output, projected by
        `self.classifier` and masked when that classifier exists. `aux_loss` is
        None unless an MoE auxiliary loss was computed.
    """
    x = inputs

    if mask is None:
        # NOTE(review): assumes a missing mask means "every position is valid" and
        # that the mask layout is (batch, time, 1) - confirm against list_to_tensor
        mask = torch.ones((*x.shape[:2], 1), dtype=x.dtype, device=x.device)

    m = mask.squeeze(-1).int()
    aux_loss = None

    # HF transformer derived model
    if self.arch_type in ["llama", "mistral", "mixtral"]:
        # the router logits are only requested (and consumed) for MoE models while training
        wants_router_logits = self.n_experts > 1 and self.training

        kwargs = dict(
            attention_mask=m,
            inputs_embeds=x,
            past_key_values=state,
            use_cache=True,
        )
        if wants_router_logits:
            kwargs["output_router_logits"] = True

        t = self.model(**kwargs)
        x = t[0]

        if state is not None:
            state = t[1]

        if wants_router_logits:
            router_logits = t[-1]
            aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func(
                router_logits,
                self.model.config.num_local_experts,
                self.model.config.num_experts_per_tok,
            )
    elif self.arch_type == "transformer":
        # ensures we specify a quant_level for the transformer implementation's AdaLN
        # (previously this branch read names that were never defined in this scope)
        if quant_levels is None:
            levels = torch.zeros((x.shape[0],), dtype=torch.int32)
        else:
            levels = torch.as_tensor(quant_levels)
        levels = levels.to(x.device)

        # inject position information
        x = self.sin_emb.add_pe(x)
        # pass our inputs through the transformer
        for block in self.blocks:
            x = block(x, m, levels)
    elif self.arch_type == "retnet":
        # pass our inputs through the RetNet
        x, extra = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
        if extra is not None and "l_aux" in extra and self.n_experts > 1:
            aux_loss = torch.sum(torch.stack([t for t in extra["l_aux"] if t is not None])) * 0.001
    elif self.arch_type == "retnet-hf":
        # the first pass runs in parallel over the whole sequence,
        # later passes feed only the newest position recurrently
        first = state is None or len(state) == 0

        kwargs = dict(
            attention_mask=m,
            inputs_embeds=x if first else x[:, -1, :].unsqueeze(1),
            past_key_values=None if first else state,
            use_cache=True,
            forward_impl='parallel' if first else 'recurrent',
            return_dict=True,
        )

        out = self.model(**kwargs)
        x = out.last_hidden_state
        if state is not None:
            state = out.past_key_values
    elif self.arch_type in ["mamba", "mamba2"]:
        x = self.model(hidden_states=x)
    elif self.arch_type == "mamba2-hf":
        kwargs = dict(
            inputs_embeds=x,
            cache_params=state,
            return_dict=True,
        )

        out = self.model(**kwargs)
        x = out.last_hidden_state
        if state is not None:
            state = out.cache_params
    elif self.arch_type == "bitnet":
        x = self.model(x)
    elif self.arch_type == "mmfreelm":
        x = self.model(
            attention_mask=m,
            inputs_embeds=x,
        )
        x = x[0]

    # output projection layer with masking
    if self.classifier is not None:
        x = self.classifier(x) * mask

    return x, state, aux_loss
2024-04-17 02:04:48 +00:00
def inputs (
2024-04-16 00:54:32 +00:00
self ,
text_list : list [ Tensor ] ,
proms_list : list [ Tensor ] ,
resps_list : list [ Tensor ] ,
2024-04-17 02:04:48 +00:00
2024-04-16 00:54:32 +00:00
lang_list : list [ Tensor ] | None = None ,
tone_list : list [ Tensor ] | None = None ,
2024-06-08 20:42:02 +00:00
len_list : list [ Tensor ] | None = None ,
task_list : list [ str ] | None = None ,
2024-06-05 04:23:31 +00:00
2024-06-08 01:46:22 +00:00
quant_levels : int | list [ int ] | Tensor | None = None
2024-04-16 00:54:32 +00:00
) :
device = text_list [ 0 ] . device
batch_size = len ( text_list )
2024-04-17 02:04:48 +00:00
inputs = [ [ ] for _ in range ( batch_size ) ]
for i in range ( batch_size ) :
2024-06-05 04:23:31 +00:00
quant_level = quant_levels [ i ] if quant_levels is not None else 0
2024-06-08 20:42:02 +00:00
task_type = task_list [ i ] if task_list is not None else " tts "
inputs [ i ] . append ( ( " task " , task_type ) )
# <text><sep><rvq lvl><sep><prom><sep><resp>
if task_type == " tts " :
if text_list is not None :
inputs [ i ] . append ( ( " text " , text_list [ i ] ) )
if self . rvq_l_emb is not None :
inputs [ i ] . append ( ( " quant_level " , torch . Tensor ( [ quant_level ] ) . to ( device = device , dtype = torch . int16 ) ) )
if proms_list is not None :
inputs [ i ] . append ( ( " prom " , proms_list [ i ] ) )
if resps_list is not None :
inputs [ i ] . append ( ( " resp " , resps_list [ i ] ) )
# <text><sep><rvq lvl><prom><sep><len>
elif task_type == " len " :
# throw an error so we don't silently train without this
if self . len_emb is None :
raise Exception ( f " Requesting task ` { task_type } ` but corresponding embedding is not defined. " )
if text_list is not None :
inputs [ i ] . append ( ( " text " , text_list [ i ] ) )
# technically will always be level 0 but for the sake of keeing the input formatting coherent...
if self . rvq_l_emb is not None :
# override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference)
quant_levels [ i ] = 0
2024-06-08 21:01:34 +00:00
inputs [ i ] . append ( ( " quant_level " , torch . Tensor ( [ 0 ] ) . to ( device = device , dtype = torch . int16 ) ) )
2024-06-08 20:42:02 +00:00
if proms_list is not None :
inputs [ i ] . append ( ( " prom " , proms_list [ i ] ) )
if len_list is not None :
inputs [ i ] . append ( ( " len " , len_list [ i ] ) )
# "encode" length to tokens for 0-9 + stop
elif resps_list is not None :
# yes this could be encoded better
inputs [ i ] . append ( ( " len " , torch . Tensor ( [ 0 ] + [ int ( i ) for i in str ( resps_list [ i ] . shape [ 0 ] ) ] + [ 10 ] ) . to ( device = device , dtype = torch . int16 ) ) )
2024-04-16 00:54:32 +00:00
2024-04-17 02:04:48 +00:00
return inputs
def inputs_to_embeddings(
    self,
    inputs: list,
    quant_levels: int | list[int] | Tensor | None = None
):
    """Embed every named input piece and join the pieces of each sample.

    Args:
        inputs: one list of `(name, value)` tuples per sample.
        quant_levels: per-sample RVQ level, indexed by sample position; level 0
            is used for every sample when this is None.

    Returns:
        One tensor per sample: that sample's embedded pieces passed through
        `_join` with `self.sep` as the separator.
    """
    x_list = []

    for batch_index, batch_input in enumerate(inputs):
        quant_level = 0 if quant_levels is None else quant_levels[batch_index]
        pieces = []

        for name, value in batch_input:
            # a plain name => embedding map would not do, some pieces need extra processing
            if name == "task":
                # meta entry, nothing to embed (a task token *could* be injected here)
                continue

            if name == "text":
                embedding = self.text_emb(value)
            elif name == "quant_level" and self.rvq_l_emb is not None:
                embedding = self.rvq_l_emb(value)
            elif name == "lang" and self.langs_emb is not None:
                embedding = self.langs_emb(value)
            elif name == "prom":
                # RVQ level 0, or everything up to the targeted RVQ level
                if self.version <= 4:
                    codes = value if quant_level == 0 else value[:, :quant_level]
                    embedding = self.proms_emb(codes)
                else:
                    if value.dim() == 1:
                        codes = value
                    elif quant_level == 0:
                        codes = value[:, :1]
                    else:
                        codes = value[:, :quant_level]
                    embedding = self.proms_emb(codes, offset=0)
            elif name == "tone" and self.tones_emb is not None:
                embedding = self.tones_emb(value)
            elif name == "resp":
                if "len" in self.capabilities and quant_level == 0:
                    # NAR-only model: the first level is seeded with "stop" tokens
                    first_level = value if value.dim() == 1 else value[..., 0]
                    embedding = self.resps_emb(
                        torch.full_like(first_level, self.stop_token),
                        offset=0
                    )
                elif self.version <= 4:
                    codes = value if quant_level == 0 else value[:, :quant_level]
                    embedding = self.resps_emb(codes, quant_level)
                else:
                    codes = value if value.dim() == 1 or quant_level == 0 else value[:, :quant_level]
                    embedding = self.resps_emb(
                        codes,
                        offset=0 if quant_level == 0 or "len" in self.capabilities else 1
                    )
            elif name == "len" and self.len_emb is not None:
                embedding = self.len_emb(value)
            else:
                # unknown names, or names whose embedding is not defined, are dropped silently
                # (should probably raise an exception so things aren't processed silently)
                continue

            pieces.append(embedding)

        x_list.append(_join(pieces, self.sep))

    return x_list
2024-05-29 00:29:54 +00:00
def calc_loss (
2024-04-17 02:04:48 +00:00
self ,
inputs : list ,
2024-05-29 00:29:54 +00:00
logits ,
2024-06-08 01:46:22 +00:00
quant_levels : int | list [ int ] | Tensor | None = None ,
2024-04-17 02:04:48 +00:00
) :
2024-05-29 00:29:54 +00:00
# old, "naive" way, no loss factoring
2024-06-06 14:48:43 +00:00
if not self . config . loss_factors :
2024-05-29 00:29:54 +00:00
target_list = [ ]
2024-06-08 20:42:02 +00:00
task_list = [ ]
2024-06-08 01:34:36 +00:00
2024-06-05 04:23:31 +00:00
for batch_index , batch in enumerate ( inputs ) :
2024-06-08 01:34:36 +00:00
quant_level = quant_levels [ batch_index ]
2024-05-29 00:29:54 +00:00
target = [ ]
for name , input in batch :
2024-06-08 20:42:02 +00:00
if name == " task " :
task_list . append ( input )
elif name == " prom " :
2024-06-08 01:29:25 +00:00
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
if self . version < 4 or ( self . version > = 5 and self . config . audio_embedding_sums ) :
target . append ( torch . full_like ( input [ . . . , 0 ] , self . ignore_index ) )
# we *CAN* directly map to proms
else :
2024-06-08 20:42:02 +00:00
target . append ( input if input . dim ( ) == 1 else input [ : , quant_level ] )
2024-06-06 23:52:41 +00:00
elif name == " resp " :
2024-06-08 01:34:36 +00:00
target . append ( input if input . dim ( ) == 1 else input [ : , quant_level ] )
2024-06-08 20:42:02 +00:00
elif name in [ " text " , " quant_level " , " lang " , " tone " , " len " ] :
2024-05-29 00:29:54 +00:00
target . append ( input )
target_list . append ( _join ( target , torch . tensor ( self . ignore_index , device = target [ - 1 ] . device ) ) )
2024-06-06 02:02:05 +00:00
batch_size = len ( target_list )
2024-05-29 00:29:54 +00:00
# modify only for the AR so it can properly behave like a transformer
2024-06-06 02:02:05 +00:00
for i in range ( batch_size ) :
2024-06-08 20:42:02 +00:00
if " len " in self . capabilities :
if task_list [ i ] != " len " :
continue
else :
if quant_levels is not None and quant_levels [ i ] > 0 :
continue
2024-04-17 02:04:48 +00:00
2024-06-07 01:51:31 +00:00
l = self . causal_size
logits [ i ] = logits [ i ] [ . . . , : - l , : ] # shift the target so that token n...
target_list [ i ] = target_list [ i ] [ . . . , l : ] # predicts token n + 1
2024-04-17 02:04:48 +00:00
2024-06-06 00:50:06 +00:00
# see comments for the split-loss calc cross_entropy call
if False :
target = torch . cat ( target_list )
inputs = torch . cat ( logits )
self . loss = dict (
# "nll" was in the original implementation and should actually just be called something else
nll = F . cross_entropy ( inputs , target , ignore_index = self . ignore_index )
)
2024-06-12 03:28:59 +00:00
self . stats = self . metrics ( inputs , targets , quant_levels ) if self . metrics is not None else dict (
2024-06-06 00:50:06 +00:00
acc = self . accuracy_metric ( inputs , target ) ,
# precision = self.precision_metric( inputs, target ),
)
else :
self . loss = dict (
2024-06-08 01:34:36 +00:00
nll = sum ( [ F . cross_entropy ( inputs , targets , ignore_index = self . ignore_index ) for targets , inputs in zip ( target_list , logits ) ] ) / batch_size
2024-06-06 00:50:06 +00:00
)
2024-06-12 04:59:28 +00:00
self . stats = self . metrics ( logits , target_list , quant_levels ) if self . metrics is not None else dict (
2024-06-06 02:02:05 +00:00
acc = sum ( [ self . accuracy_metric ( inputs , targets ) for targets , inputs in zip ( target_list , logits ) ] ) / batch_size
2024-06-06 00:50:06 +00:00
)
2024-04-17 02:04:48 +00:00
2024-05-29 00:29:54 +00:00
return
2024-05-19 16:23:56 +00:00
2024-06-06 00:50:06 +00:00
"""
# considerations:
# * split losses does not maintain the entire sequence
# * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
# + the other way at least should keep it intact this way
# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
"""
2024-05-29 00:29:54 +00:00
self . loss = dict ( )
self . stats = dict ( acc = dict ( ) )
2024-05-19 16:23:56 +00:00
2024-05-29 00:29:54 +00:00
info = { }
2024-06-06 02:02:05 +00:00
batch_size = len ( inputs )
2024-05-29 00:29:54 +00:00
for i , batch in enumerate ( inputs ) :
2024-06-08 01:34:36 +00:00
quant_level = quant_levels [ i ]
2024-05-19 16:23:56 +00:00
2024-05-29 00:29:54 +00:00
it = 0
for name , input in batch :
# do not use resp
if name == " resp " :
2024-06-06 23:52:41 +00:00
input = input if input . dim ( ) == 1 else input [ : , quant_level ]
2024-05-29 00:29:54 +00:00
# select prom level
2024-06-08 01:34:36 +00:00
elif name == " prom " :
2024-06-08 20:42:02 +00:00
input = input [ : , quant_level ]
# meta-input, no corresponding token at the moment
elif name == " task " :
continue
2024-05-29 00:29:54 +00:00
seq_len = input . shape [ 0 ]
2024-06-05 04:23:31 +00:00
2024-05-29 00:29:54 +00:00
logit = logits [ i ] [ it : it + seq_len ]
it + = seq_len + 1 # +1 to incorporate the separator
# for the AR, shift sequence so that it predicts the next token
2024-06-06 00:50:06 +00:00
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
2024-06-08 20:42:02 +00:00
if quant_level == 0 and seq_len > 1 :
2024-06-07 01:51:31 +00:00
l = self . causal_size
logit = logit [ . . . , : - l , : ]
input = input [ . . . , l : ] # shift sequence to the right by one (or causal chunk size)
2024-05-29 00:29:54 +00:00
if name not in info :
info [ name ] = {
" targets " : [ ] ,
" logits " : [ ] ,
}
2024-06-06 00:50:06 +00:00
# modeling_llama.py has some comment about requiring .contiguous() but I feel it's a spook since that incurs a memory allocation
info [ name ] [ " targets " ] . append ( input . long ( ) )
info [ name ] [ " logits " ] . append ( logit )
2024-05-29 00:29:54 +00:00
for name , batch in info . items ( ) :
loss_factor = self . loss_factor ( name )
2024-06-08 20:42:02 +00:00
if name not in [ " text " , " prom " , " resp " , " len " ] :
2024-06-05 04:48:51 +00:00
continue
2024-05-29 00:29:54 +00:00
if loss_factor == 0.0 :
continue
2024-05-27 13:43:00 +00:00
2024-06-06 00:50:06 +00:00
# "faster" if cross_entropy has speedups for processing an entire batch, but torch.cat allocates new tensors
# to-do: set this to a var
if False :
targets = torch . cat ( batch [ " targets " ] ) . long ( )
inputs = torch . cat ( batch [ " logits " ] )
self . loss [ name ] = F . cross_entropy ( inputs , targets , ignore_index = self . ignore_index ) * loss_factor
2024-06-05 04:48:51 +00:00
self . stats [ " acc " ] [ name ] = self . accuracy_metric ( inputs , targets )
2024-06-06 00:50:06 +00:00
# probably consumes less memory due to not having to allocate memory
# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
else :
2024-06-06 02:02:05 +00:00
self . loss [ name ] = sum ( [ F . cross_entropy ( inputs , targets , ignore_index = self . ignore_index ) * loss_factor for targets , inputs in zip ( batch [ " targets " ] , batch [ " logits " ] ) ] ) / batch_size
2024-06-12 03:28:59 +00:00
if self . metrics is not None :
metrics = self . metrics ( batch [ " logits " ] , batch [ " targets " ] , quant_levels )
self . stats [ " acc " ] [ name ] = metrics [ " acc " ]
else :
self . stats [ " acc " ] [ name ] = sum ( [ self . accuracy_metric ( inputs , targets ) for targets , inputs in zip ( batch [ " targets " ] , batch [ " logits " ] ) ] ) / batch_size
2024-05-19 16:23:56 +00:00
2024-04-17 02:04:48 +00:00
def forward(
    self,
    inputs: list,
    quant_levels: int | list[int] | Tensor | None = None,
    state: dict | list | None = None,
):
    """Embed the inputs, run the backbone, and (when training) compute the loss.

    Args:
        inputs: one list of `(name, value)` tuples per sample.
        quant_levels: per-sample RVQ level; None means a pure AR pass and is
            replaced by level 0 for every sample after the inputs are embedded.
        state: optional cache / incremental state forwarded to `_forward`.

    Returns:
        The per-sample logits, trimmed back to each sample's unpadded length.
        Returned as `(logits, state)` when the state coming out of `_forward`
        is not None.
    """
    embedded = self.inputs_to_embeddings(inputs, quant_levels)
    x, m = list_to_tensor(embedded)

    # loss is computed whenever the module is in training mode
    training = self.training
    batch_size = len(embedded)

    # pure AR
    if quant_levels is None:
        quant_levels = [0 for _ in range(batch_size)]

    # pad our input and mask to a multiple of l_padding; the original lengths are
    # restored afterwards when the logits are trimmed
    if self.l_padding and x.shape[1] % self.l_padding != 0:
        pad_shape = list(x.shape)
        pad_shape[1] = self.l_padding - pad_shape[1] % self.l_padding

        # pad input
        x = torch.cat([x, torch.zeros(pad_shape, dtype=x.dtype, device=x.device)], dim=1)

        # pad mask
        pad_shape[2] = 1
        m = torch.cat([m, torch.zeros(pad_shape, dtype=x.dtype, device=x.device)], dim=1)

    x, state, aux_loss = self._forward(
        inputs=x,
        mask=m,
        state=state,
    )

    if self.classifiers is not None:
        x = self.classifiers(x, levels=quant_levels) * m

    # remove padding
    lengths = [len(sequence) for sequence in embedded]
    logits = [hidden[:length] for hidden, length in zip(x, lengths)]

    if training:
        self.calc_loss(inputs=inputs, logits=logits, quant_levels=quant_levels)

        # include any additional losses (for example: MoE router)
        if aux_loss is not None:
            self.loss["aux_loss"] = aux_loss

    if state is not None:
        return (logits, state)
    return logits
2023-09-13 02:28:07 +00:00
def sample (
self ,
logits : list [ Tensor ] ,
resps_list : list [ Tensor ] ,
2024-06-08 01:46:22 +00:00
quant_levels : int | list [ int ] | Tensor | None = None ,
2023-09-13 02:28:07 +00:00
temperature : float = 1.0 ,
2023-10-10 22:02:33 +00:00
min_temperature : float = - 1.0 ,
2023-09-13 02:28:07 +00:00
top_k : int = - 100 ,
top_p : float = 1.0 ,
repetition_penalty : float = 1.0 ,
repetition_penalty_decay : float = 0.0 ,
length_penalty : float = 0.0 ,
2023-09-09 01:30:54 +00:00
2023-09-13 02:28:07 +00:00
beam_width : int = 0 ,
2023-09-18 23:55:41 +00:00
mirostat : list [ dict ] | None = None ,
2023-09-13 02:28:07 +00:00
) :
2023-10-10 22:02:33 +00:00
if min_temperature < 0 :
min_temperature = temperature
2024-06-06 00:50:06 +00:00
2023-09-09 01:30:54 +00:00
# (NAR) return the entire generated response
2024-06-06 00:50:06 +00:00
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
2023-09-09 01:30:54 +00:00
if quant_levels is not None :
logits = [ logit [ - l : ] for logit , l in zip ( logits , map ( len , resps_list ) ) ]
# (AR chunkwise) return the last chunkwise piece
2024-06-07 01:51:31 +00:00
elif self . causal :
logits = [ logit [ - self . causal_size : ] for logit in logits ]
2023-09-09 01:30:54 +00:00
2023-10-09 19:46:17 +00:00
devices = [ logit . device for logit in logits ]
2023-10-13 03:21:43 +00:00
logits = [ logit . to ( device = " cpu " , dtype = logit . dtype if logit . dtype != torch . float16 else torch . float32 ) for logit in logits ]
2023-10-09 19:46:17 +00:00
2023-09-09 01:30:54 +00:00
# perform repetition penalizing
2024-06-14 01:08:22 +00:00
if " len " not in self . capabilities :
logits = [ reptition_penalize ( logit , previous = resps [ : , - 1 ] , factor = repetition_penalty , decay = repetition_penalty_decay ) for logit , resps in zip ( logits , resps_list ) ]
2023-09-09 01:30:54 +00:00
2024-06-09 22:11:38 +00:00
# argmax instead
if temperature < = 0.0 :
return [ logit . argmax ( dim = 1 ) for logit in logits ]
2023-09-09 01:43:36 +00:00
# (AR) perform length penalizing
2023-09-09 01:30:54 +00:00
if quant_levels is None and self . causal :
2023-09-13 02:28:07 +00:00
logits = [ length_penalize ( logit , length = l + 1 , factor = length_penalty , token = self . stop_token ) for logit , l in zip ( logits , map ( len , resps_list ) ) ]
2024-06-18 03:14:43 +00:00
# (NAR) disable stop token
else :
logits = [ ban_tokens ( logit , tokens = [ self . stop_token ] ) for logit , l in zip ( logits , map ( len , resps_list ) ) ]
2023-09-09 01:30:54 +00:00
# perform top_k/top_p filtering of our logits
2023-09-13 02:28:07 +00:00
if top_k > 0 or top_p < 1.0 :
logits = [ top_k_top_p_filtering ( logit , top_k = top_k , top_p = top_p ) for logit in logits ]
2023-10-10 22:02:33 +00:00
# trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
2024-01-27 01:41:12 +00:00
# epsilon float comparison because I don't trust Python
2023-10-10 22:02:33 +00:00
if abs ( temperature - min_temperature ) > = 0.001 :
logits = [ dynamic_temperature ( logit , temperature = temperature , min_temperature = min_temperature ) for logit in logits ]
2023-10-09 18:01:40 +00:00
else :
logits = [ logit / temperature for logit in logits ]
2023-09-18 23:55:41 +00:00
# do mirostat sampling
# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
if mirostat is not None :
# mirostat sampling
return [ mirostat_sample ( logit , state = state ) for logit , state in zip ( logits , mirostat ) ]
2023-09-13 02:28:07 +00:00
# do beam search (naive implementation)
# picks the top-k across all batches, and re-batches those resultant tokens
2023-09-13 18:19:11 +00:00
# returns the logit scores as well to be P-concatted with the previous scores
2023-09-13 02:28:07 +00:00
# to-do: not naively implement beam searching
if beam_width > 1 :
2023-09-13 18:19:11 +00:00
candidates = top_k_logits_list ( logits , beam_width )
2023-10-11 17:25:31 +00:00
res = [ torch . tensor ( token , dtype = torch . int16 ) . unsqueeze ( dim = - 1 ) for batch , token in candidates ]
scores = [ logits [ batch ] . flatten ( ) [ token ] for batch , token in candidates ]
2023-09-13 18:19:11 +00:00
return res , scores
2023-09-13 02:28:07 +00:00
2023-09-09 01:30:54 +00:00
# and sample
2024-05-04 17:05:41 +00:00
return [ Categorical ( logits = logit ) . sample ( ) for logit in logits ]