"""
Core model for handling all VALL-E tasks.

This should handle all the "low"-level things such as:
* parsing inputs to sequences
* converting sequences to embeddings
* forward pass
* processing loss and returning logits

Additional functionality (preparing inputs, generating full audio) should be delegated to classes that inherit the base model
"""
import math
import torch
import torch.nn.functional as F
import random
import numpy as np
import re

from time import perf_counter
from collections import namedtuple
from typing import Literal, overload, Optional, Tuple
from functools import partial

from einops import rearrange
from torch import Tensor, einsum, nn
from torch.nn import Embedding
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision

from .arch import *
from ..utils import wrapper as ml, clamp
from ..samplers import *
from ..emb.qnt import encode_as_embedding

# yuck, kind of needed
from ..data import get_task_symmap

# these seem more elegant than a dict
Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states', 'exited_layer'])
Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
LossStats = namedtuple('LossStats', ['loss', 'stats'])
"""
from ..utils.pattern import DelayedPatternProvider, VALLEPattern
"""

summed_embeddings_task = [ "stt" ]
special_tasks = [ "len", "stt" ]
non_tokened_names = [ "task", "dropout_mask", "classifier_level" ]
task_outputs = {
	"tts": "resp",
	"stt": "text",
	"len": "len",
}
# yuck
def _get_offsets():
	return {
		"text": (0, 256),
		"quant_level": (256, 264),
		"lang": (264, 270),
		"task": (270, 279),
		"len": (279, 290),
		"tone": (290, 291),
		"sep": (291, 292),
		"prom|0": (292, 1316),
		"prom|1": (1316, 2340),
		"prom|2": (2340, 3364),
		"prom|3": (3364, 4388),
		"prom|4": (4388, 5412),
		"prom|5": (5412, 6436),
		"prom|6": (6436, 7460),
		"prom|7": (7460, 8484),
		"resps|AR:0:0": (8484, 9509),
		"resps|NAR:0:1": (9509, 10533),
		"resps|NAR:1:2": (10533, 11557),
		"resps|NAR:2:3": (11557, 12581),
		"resps|NAR:3:4": (12581, 13605),
		"resps|NAR:4:5": (13605, 14629),
		"resps|NAR:5:6": (14629, 15653),
		"resps|NAR:6:7": (15653, 16677),
		"resps|NAR:0:0": (16677, 17702),
	}
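# A minimal sketch of how these offsets are meant to be read (assumed values, not part of the
# original code): each (start, end) pair marks a slice of the unified vocabulary, so a token id
# local to one stream gets shifted by its slice's start when a single monolithic classifier is used.
# >>> start, end = _get_offsets()["prom|0"]   # (292, 1316)
# >>> unified_id = 5 + start                  # local prom token 5 -> 297 in the unified vocab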
def _dropout_mask(input, p=None):
	# cosine scheduling
	if p is None:
		t = random.random()
		p = math.cos(t * math.pi * 0.5)

	seq = [ random.random() < p for _ in range(input.shape[0]) ]
	mask = torch.tensor(seq, dtype=torch.bool, device=input.device)
	return mask
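# Usage sketch (assumed values, purely illustrative): the mask flags which timesteps of a
# response get replaced with the mask/stop token during NAR-len training.
# >>> resp = torch.zeros(4, dtype=torch.long)
# >>> _dropout_mask(resp, p=1.0)   # tensor([True, True, True, True]) -- everything masked
# >>> _dropout_mask(resp, p=0.0)   # tensor([False, False, False, False]) -- nothing masked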
def _create_mask(l, device):
	""" 1 is valid region and 0 is invalid. """
	seq = torch.arange(max(l), device=device).unsqueeze(0)  # (1 t)
	stop = torch.tensor(l, device=device).unsqueeze(1)  # (b 1)
	return (seq < stop).float()  # (b t)
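# Example sketch (assumed values): for batch lengths [3, 5] the shorter entry is padded out.
# >>> _create_mask([3, 5], device="cpu")
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1.]])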
def _join(x: tuple[Tensor], sep: Tensor):
	"""
	Args:
		x: (k t d)
		sep: (d)
	"""
	ret = x[0]
	for i in range(1, len(x)):
		ret = torch.cat((ret, sep[None], x[i]), dim=0)
	return ret
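# Sketch (assumed shapes): joining a (2, 4) and a (3, 4) sequence with a (4,) separator yields
# a (2 + 1 + 3, 4) = (6, 4) tensor, i.e. the separator embedding is spliced between segments.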
def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
	"""
	Args:
		x_list: [(t d)]
	Returns:
		x: (? ? ?)
		m: (? ? ?), same as x
	"""
	l = list(map(len, x_list))
	x = rearrange(pad_sequence(x_list), pattern)
	m = _create_mask(l, x_list[0].device)
	"""
	m = m.t().unsqueeze(-1) # (t b 1)
	m = rearrange(m, pattern)
	"""
	m = m.to(x).int()

	return x, m
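# Usage sketch (assumed shapes): two sequences of lengths 3 and 5 with d_model=8 pad up to the
# longest sequence, with the mask marking the valid region.
# >>> x, m = list_to_tensor([ torch.zeros(3, 8), torch.zeros(5, 8) ])
# >>> x.shape, m.shape   # (torch.Size([2, 5, 8]), torch.Size([2, 5]))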
def _interleave_sequence_reshape(input: list[torch.Tensor], dim=-1):
	shape = (input[0].shape[0] * len(input), input[0].shape[dim])
	return torch.concat([ i.t() for i in input ]).t().reshape(shape)

def _interleave_sequence_flatten(input: list[torch.Tensor]):
	return torch.concat([ i.t() for i in input ]).t().flatten()
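# Sketch of what the reshape variant does (assumed 2D inputs of shape (t, d)): timesteps are
# interleaved across the provided sequences, so two (2, 2) inputs a and b come out as
# [a[0], b[0], a[1], b[1]] with shape (4, 2).
# >>> a = torch.tensor([[1, 2], [3, 4]]); b = torch.tensor([[5, 6], [7, 8]])
# >>> _interleave_sequence_reshape([a, b])
# tensor([[1, 2],
#         [5, 6],
#         [3, 4],
#         [7, 8]])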
# automagically parses a batch-list and returns it as a list
"""
class Embedding(nn.Embedding):
	def forward(self, x_list: list[Tensor]) -> list[Tensor]:
		if len(x_list) == 0:
			return []
		return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
"""

# Deprecated implementation
class MultiEmbedding(nn.Module):
	def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
		super().__init__()
		self.monolithic = monolithic
		self.max_n_levels = max_n_levels
		self.n_tokens = n_tokens
		self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))

	# to-do: select quant level from given quant_levels tensor if given (i.e. through the resps_emb)
	# I imagine this is an oversight in the NAR.
	def forward(self, x_list: list[Tensor], quant_level: int | list[int] | Tensor | None = None) -> list[Tensor]:
		if len(x_list) == 0:
			return []

		# this "strategy" will reserve the weight[0] for the AR and weight[1:] for the NAR
		# the NAR cannot share RVQ-bin level 0 with the AR for the resps_emb
		if self.monolithic:
			w = self.weight[:1] if quant_level is None or quant_level == 0 else self.weight[1:]
		else:
			w = self.weight

		padded_x_list = []

		for i, xi in enumerate(x_list):
			xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens)  # t l' k
			wi = w.shape[0] - xi.shape[1]
			xi = F.pad(xi, (0, 0, 0, wi))  # t l k
			padded_x_list.append(xi.to(w))

		x = torch.cat(padded_x_list)  # n l k
		x = einsum("l k d, n l k -> n d", w, x)
		x_list = x.split([*map(len, x_list)])

		return x_list
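# Worked sketch (assumed values): with max_n_levels=8 the einsum sums the per-level one-hot
# selections against the (8, n_tokens, d) weight, i.e. a (t, 2) input holding quant levels 0..1
# becomes the sum of the level-0 and level-1 embeddings for each timestep, yielding a (t, d) tensor.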
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
# _Old, to preserve compat with previous models.
class AudioEmbedding_Old(nn.Module):
	def __init__(
		self,
		l_tokens: int,  # list of number of tokens (needed because AR resps includes stop token)
		token_dim: int,  # dimensionality of the embedding
		levels: int | None = None,  # number of RVQ-bins (I don't remember the specifics)
	):
		super().__init__()
		# array of embeddings
		# proms are [0, resp_levels]
		# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
		self.embeddings = nn.ModuleList([ nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens ])
		# weight influencer for the influence of each level (desu this should be really useless because the weights in the embedding themselves should factor this)
		self.weight = nn.ParameterList([ nn.Parameter(torch.tensor([1])) for i in range(levels) ]) if levels is not None else None

	def forward(self, xi: Tensor, quant_level: Tensor | None = None) -> Tensor:
		# prom
		if quant_level is None and xi.shape[-1] > 1:
			x = sum([ self.embeddings[k](xi[:, k]) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ])
		# prom / AR resp
		elif quant_level is None or quant_level == 0:
			x = self.embeddings[0](xi if xi.dim() == 1 else xi[:, 0])
		# NAR resp
		else:
			x = sum([ self.embeddings[k+1](xi[:, k]) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ])

		return x
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
# Mostly to handle some oversights and errors during testing
class AudioEmbedding(nn.Module):
	def __init__(
		self,
		l_tokens: list[int],  # list of number of tokens (needed because AR resps includes stop token)
		token_dim: int,  # dimensionality of the embedding
		sums: bool = True,  # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
		l_names: list[str] = [],  # names to map to indices
	):
		super().__init__()
		# array of embeddings
		# proms are [0, resp_levels]
		# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
		self.embeddings = nn.ModuleList([ nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens ])
		# further experimentation is needed to see if this actually is useful
		self.sums = sums
		self.names = l_names

	def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, name: str | None = None, sums = None) -> Tensor:
		if sums is None:
			sums = self.sums

		if quant_level is None:
			quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1

		# handle mapping from name
		if name in self.names:
			offset = self.names.index(name)
			offset -= quant_level  # offset by quant level since it'll iterate up that many levels

		if self.sums and quant_level > 0:
			x = sum([ self.embeddings[k + offset](xi[:, k]) for k in range(quant_level) ])
		else:
			k = quant_level
			x = self.embeddings[k + offset](xi if xi.dim() == 1 else xi[:, k])

		return x
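# Rough usage sketch (assumed shapes/names, purely illustrative): with sums enabled, the
# embeddings for the lower RVQ levels of the input are accumulated into one (t, d) tensor.
# >>> emb = AudioEmbedding([1025, 1024, 1024], 512, l_names=['AR:0:0', 'NAR:0:1', 'NAR:1:2'])
# >>> codes = torch.zeros(75, 3, dtype=torch.long)  # (t, levels) codes
# >>> emb(codes, name='NAR:1:2').shape              # torch.Size([75, 512])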
# time-step embedding
# for the NAR-len, since it probably most likely requires encoding the timestep
class TimeEmbedding(nn.Module):
	def __init__(
		self,
		d_model
	):
		super().__init__()
		self.emb = SinusoidalEmbedding(d_model)
		self.mlp = nn.Sequential(
			nn.Linear(d_model, d_model * 4),
			nn.SiLU(),
			nn.Linear(d_model * 4, d_model),
		)

	def forward(self, t):
		t = self.emb(t)
		t = self.mlp(t)
		return t
# per-level classification
# it might actually be "better" in the long run to only have one output head like a traditional LM, and just de-stitch it here instead of doing modulus math and whatever like the HF/experimental impl
class Classifiers(nn.Module):
	def __init__(
		self,
		l_tokens: list[int],  # list of number of tokens (needed because AR resps includes stop token)
		token_dim: int,  # dimensionality of the embedding
		l_names: list[str] | None = None,  # list of names to map to each classifier
		bias: bool = True,
	):
		super().__init__()
		self.proj = nn.ModuleList([ nn.Linear(token_dim, n_tokens, bias=bias) for n_tokens in l_tokens ])
		self.names = l_names

	def indices(
		self,
		names
	):
		if isinstance(names[-1], int):
			return names
		return [ self.names.index(name) for name in names ]

	def forward(self, xi: Tensor, levels: list[int] | None = None, names: list[str] | None = None, stack = False) -> Tensor:
		dtype = xi.dtype
		device = xi.device

		if levels and isinstance(levels[-1], str):
			names = levels
			levels = []

		# map names to levels
		if names and not levels:
			levels = [ self.names.index(name) for name in names ]

		xi = [ self.proj[l](x) for x, l in zip(xi, levels) ]

		if not stack:
			return xi

		# pad if needed
		# to-do: validate that this causes ZERO issues
		# addendum: this does cause problems
		max_size = max([ x.shape[-1] for x in xi ])
		xi = [
			#x if l == 0 else
			x if x.shape[-1] == max_size else
			torch.cat([ x, torch.full((x.shape[0], max_size - x.shape[-1]), -float("inf"), device=device, dtype=dtype) ], dim=-1)
			for x, l in zip(xi, levels)
		]
		return torch.stack(xi)
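# Sketch (assumed values): each batch entry is projected by the head matching its classifier
# level; when stacked, narrower heads are right-padded with -inf so the logits share one shape.
# >>> heads = Classifiers([1025, 1024], 512, l_names=['AR:0:0', 'NAR:0:1'])
# >>> hidden = torch.zeros(2, 75, 512)   # (batch, t, d_model)
# >>> [ l.shape for l in heads(hidden, names=['AR:0:0', 'NAR:0:1']) ]
# [torch.Size([75, 1025]), torch.Size([75, 1024])]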
class Metrics(nn.Module):
	def __init__(
		self,
		l_tokens: int | list[int],
		top_k = 10,
		average = "micro",
		multidim_average = "global",
		ignore_index = -100
	):
		super().__init__()
		self.accuracy = nn.ModuleList([ MulticlassAccuracy(
			n_tokens,
			top_k=top_k,
			average=average,
			multidim_average=multidim_average,
			ignore_index=ignore_index,
		) for n_tokens in l_tokens ])
		self.precision = nn.ModuleList([ MulticlassPrecision(
			n_tokens,
			top_k=top_k,
			average=average,
			multidim_average=multidim_average,
			ignore_index=ignore_index,
		) for n_tokens in l_tokens ])

	def calc_accuracy(self, inputs, targets, classifier_levels):
		return sum([ self.accuracy[l](input[:, :self.accuracy[l].num_classes], target) for target, input, l in zip(targets, inputs, classifier_levels) ]) / len(inputs)

	def calc_precision(self, inputs, targets, classifier_levels):
		return sum([ self.precision[l](input[:, :self.precision[l].num_classes], target) for target, input, l in zip(targets, inputs, classifier_levels) ]) / len(inputs)

	def __call__(self, *args, **kwargs):
		return dict(
			acc=self.calc_accuracy(*args, **kwargs),
		)
class Base(nn.Module):
	def loss_factor(self, k):
		if self.config is None:
			return 1.0
		return self.config.loss_factor(k)

	def _prune(self, l: Tensor, stop = None):
		if stop is None:
			stop = self.stop_token

		indices = (l == stop).nonzero()

		if len(indices) == 0:
			return l

		return l[:indices.min().item()]
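	# Sketch (assumed values, with self.stop_token == 1024): trims a decoded sequence at the
	# first stop token, e.g. tensor([12, 53, 1024, 7]) -> tensor([12, 53]).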
	# these probably need to live in an interleaved model, as pattern-ing is targeted for a sole AR model
	"""
	def codes_to_pattern(self, codes):
		# expand if not batched
		if codes.dim() == 2:
			codes = codes.unsqueeze(0)
		# [batch, timestep, rvq level] (B, T, K) => [batch, rvq level, timestep] (B, K, T)
		codes = codes.permute(0, 2, 1)

		B, K, T = codes.shape
		# map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
		pattern = self.pattern_provider.get_pattern(T)
		sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
			codes.contiguous(), self.stop_token, keep_only_valid_steps=False,
		)
		# (B, K, T) => (B, T, K)
		return sequence_codes.permute(0, 2, 1)

	def logits_from_pattern(self, logits, pattern):
		logits = logits.permute(0, 3, 1, 2)  # [B, card, K, S]
		logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
			logits, float('nan'), keep_only_valid_steps=False
		)
		logits = logits.permute(0, 2, 3, 1)  # [B, K, T, card]
		logits_mask = logits_mask[None, :, :].expand(B, -1, -1)  # [K, T] -> [B, K, T]
		return logits, logits_mask
	"""
	def __init__(
		self,

		n_text_tokens: int = 256,
		n_audio_tokens: int = 1024,

		d_model: int = 512,
		n_heads: int = 8,
		n_layers: int = 12,
		p_dropout: float = 0.1,

		n_experts: int = 1,

		l_padding: int = 0,

		training = True,
		attention = None,
		config = None,
	):
		super().__init__()
		self.training = training
		self.teaching = False
		self.config = config

		self.n_text_tokens = n_text_tokens
		self.n_audio_tokens = n_audio_tokens

		self.d_model = d_model
		self.n_heads = n_heads
		self.n_layers = n_layers
		self.n_experts = n_experts

		self.l_padding = l_padding
		self.ignore_index = -100

		self.n_resp_levels = self.config.resp_levels if self.config else n_resp_levels
		self.n_max_levels = self.config.max_levels if self.config else n_resp_levels
		self.capabilities = self.config.capabilities if self.config else ["ar", "nar"]
		self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True

		self.stop_token = self.n_audio_tokens  # id 1024
		self.causal = "ar" in self.capabilities or "len" in self.capabilities
		self.version = self.config.version if self.config is not None else 5
		self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)

		self.arch_type = self.config.arch_type if self.config is not None else "llama"

		# check if requested arch is unavailable
		if self.arch_type in ERROR_ARCHES:
			raise ERROR_ARCHES[self.arch_type]

		if not attention:
			attention = self.config.attention if self.config is not None else "auto"

		# crunge
		if self.config is not None and config.teacher:
			self.teaching = True
			self.training = False

		attention_backend = attention
		audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
		split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
		tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
		audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
		unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
		interleave = self.config.experimental.interleave if self.config is not None else False
		noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
		classifiers_bias = self.config.experimental.classifiers_bias if self.config is not None else False

		masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
		ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False

		layerskip = self.config.experimental.layerskip if self.config is not None else False
		layerskip_r = self.config.experimental.layerskip_r if self.config is not None else 2
		layerskip_p_max = self.config.experimental.layerskip_p_max if self.config is not None else 0.1
		layerskip_e_scale = self.config.experimental.layerskip_e_scale if self.config is not None else 0.1

		n_tasks = self.config.tasks if self.config is not None else 8
		n_langs = self.config.langs if self.config is not None else 2
		n_tones = self.config.tones if self.config is not None else 1
		# pure AR
		if "nar" not in self.capabilities:
			n_resp_tokens = n_audio_tokens + 1
			l_tokens = [n_resp_tokens] * self.n_resp_levels
			resp_l_names = [f'AR:{i}:{i}' for i in range(self.n_resp_levels)]
			classifier_l_tokens = [n_resp_tokens] * self.n_resp_levels
		# NAR-len model
		elif "len" in self.capabilities:
			# +1 to include the stop or mask token
			n_resp_tokens = n_audio_tokens + (1 if self.causal_size > 0 else 0)
			if "ar" in self.capabilities:
				l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
				classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
				resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range(self.n_resp_levels - 1)] + ['NAR:0:0']
			else:
				l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
				classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
				resp_l_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range(self.n_resp_levels - 1)]
		# AR+NAR model
		else:
			# +1 to include the stop or mask token
			n_resp_tokens = n_audio_tokens + (1 if self.causal_size > 0 else 0)
			l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
			resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range(self.n_resp_levels - 1)]
			classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)

		classifier_l_tokens += [n_text_tokens]
		classifier_l_names = resp_l_names + ["stt"]

		if "len" in self.capabilities:
			classifier_l_tokens += [11]
			classifier_l_names += ["len"]
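		# Illustrative sketch (assumed config with n_resp_levels=8 and AR+NAR+len capabilities):
		# resp_l_names comes out as ['AR:0:0', 'NAR:0:1', ..., 'NAR:6:7', 'NAR:0:0'], with the
		# AR and NAR-len heads sized 1025 (audio tokens plus stop/mask) and the pure NAR heads 1024.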
		n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1
		self.n_vocab = n_vocab

		self.unified_position_ids = unified_position_ids
		self.interleave = interleave
		self.layerskip = layerskip
		self.inject_timestep_embedding = False  # results in bad output
		self.masking_ratio = masking_ratio
		self.ignore_inputs_for_loss = ignore_inputs_for_loss
		self.noncausal_masks = noncausal_masks

		# use internal attention mechanism for now because I don't have a better way to handle mixed causal/noncausal masks for other attention backends
		"""
		if noncausal_masks:
			attention_backend = "default"
		"""
		self.text_emb = Embedding(n_text_tokens, d_model)
		self.langs_emb = None
		self.tones_emb = None
		self.tasks_emb = None
		self.rvq_l_emb = None
		self.len_emb = None

		# it would be nicer for these to be a token or live inside an embedding
		self.sep = nn.Parameter(torch.randn(d_model))
		self.dropout_token = nn.Parameter(torch.randn(d_model))
		if self.version == 1:  # legacy
			n_audio_tokens += (n_tasks - 1)  # old models have the task tokens in the prom
			self.proms_emb = MultiEmbedding(self.n_resp_levels, n_audio_tokens, d_model)
			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
		elif self.version < 5:
			# [1024] * 8
			self.proms_emb = AudioEmbedding_Old(
				[n_audio_tokens] * self.n_resp_levels, d_model,
				levels=self.n_resp_levels if self.version > 3 else None,
			)
			# [1024 + STOP] + [1024] * 8
			self.resps_emb = AudioEmbedding_Old(
				l_tokens, d_model,
				levels=self.n_resp_levels if self.version > 3 else None,
			)
		else:
			self.proms_emb = AudioEmbedding(
				[n_audio_tokens] * self.n_resp_levels, d_model,
				sums=audio_embedding_sums == "prom" or audio_embedding_sums == True,
			)
			self.resps_emb = AudioEmbedding(
				l_tokens, d_model,
				sums=audio_embedding_sums == "resp" or audio_embedding_sums == True,
				l_names=resp_l_names,
			)

		if self.version >= 3:
			self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
			self.tasks_emb = Embedding(n_tasks, d_model) if n_tasks > 0 else None

			self.capabilities += ["lang"]

		# never actually got added... I kept forgetting to classify all my audio for speaker's tone
		if self.version >= 4:
			self.tones_emb = Embedding(n_tones, d_model) if n_tones > 0 else None
		# mamba requires this if a model does both AR and NAR tasks
		# this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
		# this ***might*** let me also unify the proms_emb and resps_embedding
		if self.version >= 5:
			# "len" RVQ level-0 gets an additional token
			self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)

			# experimental NAR-only mode
			self.len_emb = Embedding(11, d_model)
			self.time_emb = None  # TimeEmbedding(d_model) # if not masking_ratio else None

		if attention_backend == "auto":
			attention_backend = "sdpa"
			"""
			if AVAILABLE_ATTENTIONS:
				attention_backend = AVAILABLE_ATTENTIONS[0]
			else:
				attention_backend = "default"
			"""

		hf_attention = attention_backend
		HF_ATTENTIONS = ["eager", "sdpa", "flash_attention_2"]

		if attention_backend not in HF_ATTENTIONS:
			hf_attention = None
			if attention_backend not in AVAILABLE_ATTENTIONS:
				raise ValueError(f"Requesting attention `{attention_backend}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")

		# override any requested padding size
		if attention_backend == "flash_attn_v100":
			self.l_padding = 32
		elif attention_backend == "fused_attn":
			self.l_padding = 128
		if self.arch_type == "transformer":
			self.sin_emb = SinusoidalEmbedding(d_model)
			self.blocks = nn.ModuleList([ TransformerBlock(
				d_model=d_model,
				n_heads=n_heads,
				p_dropout=p_dropout if training else 0.0,
				causal=self.causal,
				norm_type="ln",  # adaln
				n_levels=self.n_resp_levels,
			) for _ in range(n_layers) ])
		elif self.arch_type in ["mistral", "mixtral"]:
			if n_experts <= 1:
				self.model = MistralModel(MistralConfig(
					vocab_size=n_vocab,
					hidden_size=d_model,
					max_position_embeddings=75 * 60 * 5,  # max-length of 5 minutes
					intermediate_size=d_model * 4,
					num_hidden_layers=n_layers,
					num_attention_heads=n_heads,
					attention_dropout=p_dropout if training else 0.0,
					num_key_value_heads=self.config.experimental.kv_heads if self.config is not None and self.config.experimental.kv_heads > 0 else n_heads,
					hidden_act="gelu",
					is_encoder_decoder=False,
					is_decoder=True,
					attn_implementation=hf_attention,
					#gradient_checkpointing=self.gradient_checkpointing,
				))
			else:
				self.model = MixtralModel(MixtralConfig(
					vocab_size=n_resp_tokens,
					hidden_size=d_model,
					max_position_embeddings=75 * 60 * 5,  # max-length of 5 minutes
					intermediate_size=d_model * 4,
					num_hidden_layers=n_layers,
					num_attention_heads=n_heads,
					attention_dropout=p_dropout if training else 0.0,
					num_key_value_heads=self.config.experimental.kv_heads if self.config is not None and self.config.experimental.kv_heads > 0 else n_heads,
					sliding_window=75 * 12,  # 12 second context window
					output_router_logits=training,
					hidden_act="gelu",
					is_encoder_decoder=False,
					is_decoder=True,
					num_local_experts=n_experts,
					num_experts_per_tok=min(2, n_experts),
					attn_implementation=hf_attention,
					#gradient_checkpointing=self.gradient_checkpointing,
				))

			if attention_backend not in HF_ATTENTIONS:
				self.model = ml.replace_attention(self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend)

			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
					use_reentrant=False
				))
		elif self.arch_type == "llama":
			LlamaClass = LlamaModel_Adapted  # if (self.layerskip or "len" in self.capabilities) else LlamaModel

			if n_experts <= 1:
				config = LlamaConfig(
					vocab_size=n_vocab,
					hidden_size=d_model,
					max_position_embeddings=75 * 60 * 5,  # max-length of 5 minutes
					intermediate_size=d_model * 4,
					num_hidden_layers=n_layers,
					num_attention_heads=n_heads,
					attention_dropout=p_dropout if training else 0.0,
					num_key_value_heads=n_heads,
					sliding_window=75 * 12,  # 12 second context window
					hidden_act="gelu",
					is_encoder_decoder=False,
					is_decoder=True,
					attn_implementation=hf_attention,
					#gradient_checkpointing=self.gradient_checkpointing,
				)
				print(config)
				self.model = LlamaClass(config)

				# replace with desired attention
				if attention_backend not in HF_ATTENTIONS:
					self.model = ml.replace_attention(self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend)
			else:
				self.model = MixtralModel(MixtralConfig(
					vocab_size=n_resp_tokens,
					hidden_size=d_model,
					max_position_embeddings=75 * 60 * 5,  # max-length of 5 minutes
					intermediate_size=d_model * 4,
					num_hidden_layers=n_layers,
					num_attention_heads=n_heads,
					attention_dropout=p_dropout if training else 0.0,
					num_key_value_heads=n_heads,
					sliding_window=75 * 12,  # 12 second context window
					output_router_logits=training,
					hidden_act="gelu",
					is_encoder_decoder=False,
					is_decoder=True,
					num_local_experts=n_experts,
					num_experts_per_tok=min(2, n_experts),
					attn_implementation=hf_attention,
					#gradient_checkpointing=self.gradient_checkpointing,
				))
				if attention_backend not in HF_ATTENTIONS:
					self.model = ml.replace_attention(self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend)

			if self.layerskip:
				self.model.layer_dropout_p = layerskip_p_max
				self.model.early_exit_scale = layerskip_e_scale
				self.model.early_exit_r = layerskip_r

			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
					use_reentrant=False
				))
		elif self.arch_type == "retnet":
			kwargs = dict(
				vocab_size=n_vocab,
				decoder_embed_dim=d_model,
				decoder_value_embed_dim=d_model * 2,
				decoder_retention_heads=n_heads,
				decoder_ffn_embed_dim=d_model * 4,
				decoder_layers=n_layers,
				dropout=p_dropout if training else 0.0,
				checkpoint_activations=self.gradient_checkpointing,
				activation_fn="gelu",
				use_layernorm=self.version < 3,
				use_biases=self.version < 3,
				use_glu=self.version >= 3,

				chunkwise_recurrent=self.causal and self.causal_size > 0,
				recurrent_chunkwise_size=self.causal_size if self.causal else 0,
				no_output_layer=True,
				decoder_normalize_before=True,

				rotary_embedding_base=10000
			)

			if n_experts > 1:
				kwargs.update(dict(
					use_xmoe=True,
					moe_freq=1,
					moe_expert_count=n_experts,
					moe_gating_use_fp32=False,
				))

			self.model = RetNetDecoder(RetNetConfig(**kwargs))
		elif self.arch_type in ["mamba2"]:
			self.model = Mamba2Model(Mamba2Config(
				vocab_size=n_vocab,
				hidden_size=d_model,
				expand=2,
				num_hidden_layers=n_layers * 2,
				residual_in_fp32=True,
			))
			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
					use_reentrant=False
				))
		elif self.arch_type in ["mamba"]:
			self.model = MambaModel(MambaConfig(
				vocab_size=n_vocab,
				hidden_size=d_model,
				expand=2,
				num_hidden_layers=n_layers * 2,
				residual_in_fp32=True,
			))
			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
					use_reentrant=False
				))
		else:
			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
		if hasattr(self.model, "embeddings"):
			del self.model.embeddings

		if not split_classifiers:
			self.classifier = nn.Linear(d_model, n_vocab, bias=classifiers_bias)
			self.classifiers = None
			self.metrics = None
		else:
			self.classifier = None
			self.classifiers = Classifiers(classifier_l_tokens, d_model, l_names=classifier_l_names, bias=classifiers_bias)
			self.metrics = Metrics(classifier_l_tokens)

		"""
		if tie_classifier_to_embedding:
			for i, proj in enumerate(self.classifiers.proj):
				self.classifiers.proj[i].weight = self.resps_emb.embeddings[i].weight
		"""
	def _forward(
		self,
		inputs,
		mask = None,
		is_causal = None,
		position_ids = None,

		state = None,

		layer_skip_lambda = None,

		output_attentions = False,
		output_hidden_states = False,
	):
		x = inputs
		m = mask  #.squeeze(-1).int()

		aux_loss = None
		attentions = None
		hidden_states = None
		# HF transformer derived model
		if self.arch_type in ["llama", "mistral", "mixtral"]:
			kwargs = dict(
				inputs_embeds=x,
				attention_mask=m,
				past_key_values=state,
				position_ids=position_ids,
				use_cache=False,  # not self.training,
				output_attentions=output_attentions,
				output_hidden_states=output_hidden_states,
				return_dict=True,
				is_causal=is_causal,
			)

			if self.n_experts > 1 and self.training:
				kwargs["output_router_logits"] = True

			if self.layerskip and layer_skip_lambda is not None:
				kwargs["layer_skip_lambda"] = layer_skip_lambda

			output = self.model(**kwargs)
			x = output["last_hidden_state"]

			# to-do: figure out why KV caching doesn't work
			#if not self.training:
			if state is not None:
				state = output["past_key_values"]

			if output_attentions:
				attentions = output["attentions"]

			if output_hidden_states:
				hidden_states = output["hidden_states"]

			if self.n_experts > 1 and self.training:
				router_logits = output["aux_loss"]
				aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func(router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok)
		elif self.arch_type == "transformer":
			# ensures we specify a quant_level for the transformer implementation's AdaLN
			l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
			l = l.to(device)
			# inject position information
			x = self.sin_emb.add_pe(x)
			# pass our inputs through the transformer
			for block in self.blocks:
				x = block(x, m, l)
		elif self.arch_type == "retnet":
			# pass our inputs through the RetNet
			x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
			if _ is not None and "l_aux" in _ and self.n_experts > 1:
				aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None ])) * 0.001
		elif self.arch_type in ["mamba", "mamba2"]:
			kwargs = dict(
				inputs_embeds=x,
				attention_mask=m,
				#cache_params=state,
				use_cache=False,  # not self.training,
				#position_ids=position_ids,
				#output_attentions=output_attentions,
				output_hidden_states=output_hidden_states,
				return_dict=True,
			)

			output = self.model(**kwargs)
			x = output["last_hidden_state"]

			if state is not None:
				state = output["cache_params"]

			if output_attentions:
				attentions = output["attentions"]

			if output_hidden_states:
				hidden_states = output["hidden_states"]

		# process it into a format that I like
		if output_hidden_states:
			# hidden_states is actually layers + 1, as hidden_states[0] == embedding
			hidden_states = [ state for state in hidden_states[1:] ]
			# apply normalization to these states (to-do: check if this matters)
			# but skip the last state, as it already is normalized
			hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate(hidden_states) ]

		return Logits(x, state, inputs, aux_loss, attentions, hidden_states, None)
	# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
	def inputs(
		self,
		text_list: list[Tensor],
		proms_list: list[Tensor],
		resps_list: list[Tensor],

		lang_list: list[Tensor] | None = None,
		tone_list: list[Tensor] | None = None,
		len_list: list[Tensor] | None = None,
		task_list: list[str] | None = None,
		time_list: list[Tensor] | None = None,

		quant_levels: int | list[int] | Tensor | None = None
	):
		device = text_list[0].device
		batch_size = len(text_list)

		inputs = [ [] for _ in range(batch_size) ]
		for i in range(batch_size):
			quant_level = quant_levels[i] if quant_levels is not None else 0
			task_type = task_list[i] if task_list is not None else "tts"
			timestep = time_list[i] if time_list is not None else None
			classifier_level = None

			# insert task type as a string
			inputs[i].append(("task", task_type))
			# to-do: maybe not split the below blocks up
			# might be beneficial in the event I need to use a different sequence, such as STT tasks

			# Base-line TTS task
			# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
			# prom /may/ include <task> tokens inside to help guide things, per SpeechX
			if f'<{task_type}>' in get_task_symmap() and task_type not in special_tasks:
				# insert the text prompt
				if text_list is not None and text_list[i] is not None:
					inputs[i].append(("text", text_list[i]))
				# insert lang token if we're trained for it
				if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
					inputs[i].append(("lang", lang_list[i]))
				# insert RVQ level guidance token if the model is versioned for it
				if self.rvq_l_emb is not None and not self.interleave:
					inputs[i].append(("quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16)))

					classifier_level = "AR:0:0" if quant_level == 0 else f'NAR:{quant_level-1}:{quant_level}'
				# insert input audio prompt
				if proms_list is not None and proms_list[i] is not None:
					inputs[i].append(("prom", proms_list[i]))
				# insert tone token if we're trained for it
				if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
					inputs[i].append(("tone", tone_list[i]))
				# insert timestep token
				if timestep is not None:
					# force set to use this classifier level
					classifier_level = "NAR:0:0"
					# store timestep information
					if self.masking_ratio in ["random", "rand"]:
						# cosine scheduled timestep => masking ratio
						p = math.cos(timestep * math.pi * 0.5)

						# I don't think this is necessary, as the timestep is encoded in the sequence by the number of masked tokens, probably.
						if self.inject_timestep_embedding:
							inputs[i].append(("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype)))
					else:
						# a paper said to use a fixed masking ratio of 0.8 for training
						# ...but I want to make it user adjustable
						p = self.masking_ratio

					# store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
					if self.training:
						dropout_mask = _dropout_mask(resps_list[i], p)
						inputs[i].append(("dropout_mask", dropout_mask))
				# insert the current output response
				if resps_list is not None and resps_list[i] is not None:
					inputs[i].append(("resp", resps_list[i]))

				inputs[i].append(("classifier_level", classifier_level))
			# Audio length prediction task
			# Sequence: <text><sep><rvq lvl><prom><sep><len>
			elif task_type == "len":
				# throw an error so we don't silently train without this
				if self.len_emb is None:
					raise Exception(f"Requesting task `{task_type}` but corresponding embedding is not defined.")

				# insert the text prompt
				if text_list is not None and text_list[i] is not None:
					inputs[i].append(("text", text_list[i]))
				# insert lang token if we're trained for it
				if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
					inputs[i].append(("lang", lang_list[i]))
				# technically will always be level 0 but for the sake of keeping the input formatting coherent...
				if self.rvq_l_emb is not None:
					inputs[i].append(("quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16)))
				# insert input audio prompt
				if proms_list is not None and proms_list[i] is not None:
					inputs[i].append(("prom", proms_list[i]))
				# insert tone token if we're trained for it
				if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
					inputs[i].append(("tone", tone_list[i]))

				# insert output length tokens (if it exists)
				if len_list is not None and len_list[i] is not None:
					inputs[i].append(("len", len_list[i]))
				# "encode" length to tokens for 0-9 + stop
				elif resps_list is not None and resps_list[i] is not None:
					# yes this could be encoded better
					inputs[i].append(("len", torch.tensor([0] + [ int(i) for i in str(resps_list[i].shape[0]) ] + [10], device=device, dtype=torch.int16)))

				inputs[i].append(("classifier_level", "len"))
			# Speech-to-Text prediction task
			# Sequence: <resp><sep><rvq lvl><sep><text>
			elif task_type == "stt":
				# insert the input response
				if resps_list is not None and resps_list[i] is not None:
					inputs[i].append(("resp", resps_list[i]))
				# insert lang token if we're trained for it
				if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
					inputs[i].append(("lang", lang_list[i]))
				# insert RVQ level guidance token if the model is versioned for it
				if self.rvq_l_emb is not None and not self.interleave:
					inputs[i].append(("quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16)))
				# insert the output text prompt
				if text_list is not None and text_list[i] is not None:
					inputs[i].append(("text", text_list[i]))

				inputs[i].append(("classifier_level", "stt"))
			else:
				raise Exception(f'Unrecognized task: {task_type}')

		return inputs
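	# Sketch of the produced structure (assumed tensors, quant_level 0, "tts" task): each batch
	# entry becomes an ordered list of (name, value) tuples, roughly
	#   [("task", "tts"), ("text", <text ids>), ("quant_level", tensor([0])),
	#    ("prom", <prom codes>), ("resp", <resp codes>), ("classifier_level", "AR:0:0")]
	# which inputs_to_embeddings() later walks in order to build the actual input sequence.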
	def offset_inputs(
		self,
		inputs: list,
		direction: int = 1,  # -1 to de-offset
	):
		offsets = _get_offsets()

		for batch_index, batch_input in enumerate(inputs):
			quant_level = None
			classifier_level = None

			# pre-iterate
			for name, input in batch_input:
				if name == "quant_level":
					quant_level = input
				elif name == "classifier_level":
					classifier_level = input

			for name, input in batch_input:
				if not isinstance(input, torch.Tensor):
					continue

				k = name
				if name == "prom":
					k = f'prom|{quant_level}'
				elif name == "resp":
					k = f'resps|{classifier_level}'

				if k not in offsets:
					continue

				start, end = offsets[k]

				for i, t in enumerate(input):
					input[i] += start * direction

		return inputs
	def inputs_to_embeddings(
		self,
		inputs: list,
		quant_levels: int | list[int] | Tensor | None = None
	):
		# handles tasks where the prompt has task tokens injected in the middle
		def prompt_input_to_embedding(input, quant_level):
			if isinstance(input, str):
				return self.tasks_emb(torch.tensor([ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16))

			# get RVQ level 0, or up to targeted RVQ level inference
			if self.version <= 4:
				return self.proms_emb(
					input if quant_level == 0 else input[:, :quant_level]
				)

			return self.proms_emb(
				input if input.dim() == 1 else input[:, :1 if quant_level == 0 else quant_level],
				quant_level=0 if quant_level == 0 else quant_level - 1,  # input is one below the target quant level
				offset=0,
			)

		# yuck
		token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
		token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels if self.config else None

		if self.dropout_token is None or not self.training:
			token_dropout_rate = 0.0

		if not token_dropout_rvq_levels:
			token_dropout_rvq_levels = [1, self.resp_levels]
		x_list = []
		for batch_index, batch_input in enumerate(inputs):
			batch = []
			quant_level = quant_levels[batch_index] if quant_levels is not None else 0

			task_type = "tts"
			input_prom = None
			classifier_level = None
			dropout_mask = None
			timestep = None

			# pre-iterate
			for name, input in batch_input:
				if name == "classifier_level":
					classifier_level = input
				elif name == "dropout_mask":
					dropout_mask = input
				elif name == "timestep":
					timestep = input
2024-05-29 00:29:54 +00:00
for name , input in batch_input :
2024-06-08 20:42:02 +00:00
# technically can provide a map for input_name => embedding, but some embeddings require additional processing
2024-04-17 02:04:48 +00:00
embedding = None
2024-11-15 04:17:47 +00:00
2024-06-30 00:46:11 +00:00
# is already an embedding
2024-06-08 20:42:02 +00:00
if name == " task " :
# noop
# *maybe* inject a token for specifying task type
2024-07-18 21:16:14 +00:00
task_type = input
2024-06-08 20:42:02 +00:00
continue
elif name == " text " :
2024-04-17 02:04:48 +00:00
embedding = self . text_emb ( input )
2024-07-19 04:25:32 +00:00
device = embedding . device
2024-06-08 20:42:02 +00:00
elif name == " quant_level " and self . rvq_l_emb is not None :
embedding = self . rvq_l_emb ( input )
2024-06-05 15:30:04 +00:00
elif name == " lang " and self . langs_emb is not None :
2024-04-17 02:04:48 +00:00
embedding = self . langs_emb ( input )
elif name == " prom " :
2024-07-18 21:16:14 +00:00
proms = [ input ] if isinstance ( input , torch . Tensor ) else input
2024-11-10 00:04:59 +00:00
"""
if proms is None :
continue
"""
2024-11-07 17:32:11 +00:00
# to-do: probably insert separators if task requires it?
2024-07-19 04:25:32 +00:00
embedding = torch . cat ( [ prompt_input_to_embedding ( input , quant_level ) for input in proms if input is not None ] )
2024-06-05 15:30:04 +00:00
elif name == " tone " and self . tones_emb is not None :
2024-04-17 02:04:48 +00:00
embedding = self . tones_emb ( input )
elif name == " resp " :
2024-08-05 00:56:21 +00:00
if self . interleave :
embeddings = [ self . resps_emb (
input [ : , : l + 1 ] ,
2024-11-13 19:31:17 +00:00
#offset = 0,
#quant_level = l,
name = ' AR:0:0 ' if l == 0 else f ' NAR: { l - 1 } : { l } ' ,
2024-08-05 00:56:21 +00:00
) for l in range ( input . shape [ - 1 ] ) ]
embedding = _interleave_sequence_reshape ( embeddings )
2024-11-08 03:19:14 +00:00
# if training NAR-len RVQ level 0
2024-11-10 18:19:48 +00:00
elif dropout_mask is not None :
2024-11-08 03:19:14 +00:00
embedding = self . resps_emb (
2024-11-10 00:04:59 +00:00
# if masked, use the masked token, else the original token
2024-11-08 03:19:14 +00:00
torch . where ( dropout_mask , self . stop_token , input if input . dim ( ) == 1 else input [ : , 0 ] ) ,
2024-11-13 19:31:17 +00:00
#quant_level = 0,
name = classifier_level ,
2024-11-07 01:17:12 +00:00
)
2024-11-14 00:04:04 +00:00
# NAR-len
elif classifier_level == " NAR:0:0 " :
embedding = self . resps_emb (
input if input . dim ( ) == 1 else input [ : , 0 ] ,
#quant_level = 0,
name = classifier_level ,
)
2024-09-06 01:43:20 +00:00
# cheat-y way to handle performing STT across all levels
elif task_type in summed_embeddings_task :
2024-09-08 03:13:49 +00:00
# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
2024-09-06 01:43:20 +00:00
embedding = sum ( [ self . resps_emb (
input [ : , : l + 1 ] ,
2024-11-13 20:27:15 +00:00
offset = 0 if l == 0 else 1 , # or maybe set to 1
quant_level = l ,
#name = 'AR:0:0' if l == 0 else f'NAR:{l-1}:{l}',
2024-09-08 03:13:49 +00:00
sums = False
2024-09-06 01:43:20 +00:00
) for l in range ( input . shape [ - 1 ] - 1 ) ] )
2024-06-08 20:42:02 +00:00
else :
# get RVQ level 0, or up to targeted RVQ level inference
2024-06-09 01:30:15 +00:00
if self . version < = 4 :
2024-07-18 19:18:34 +00:00
embedding = self . resps_emb (
input if quant_level == 0 else input [ : , : quant_level ] ,
quant_level
)
2024-06-09 01:30:15 +00:00
else :
2024-11-13 19:31:17 +00:00
"""
2024-08-05 00:56:21 +00:00
offset = 0
2024-11-07 15:10:18 +00:00
if " nar " not in self . capabilities :
2024-08-05 00:56:21 +00:00
offset = 0
elif quant_level > 0 :
offset = 1
2024-06-09 01:30:15 +00:00
embedding = self . resps_emb (
input if input . dim ( ) == 1 or quant_level == 0 else input [ : , : quant_level ] ,
2024-08-05 00:56:21 +00:00
offset = offset ,
2024-08-04 01:23:36 +00:00
quant_level = 0 if quant_level == 0 else quant_level - 1 , # input is one below the target quant level
2024-06-09 01:30:15 +00:00
)
2024-11-13 19:31:17 +00:00
"""
embedding = self . resps_emb (
input if input . dim ( ) == 1 or quant_level == 0 else input [ : , : quant_level ] ,
2024-11-15 04:17:47 +00:00
#offset = 0 if classifier_level.startswith("AR:") else 1,
name = classifier_level ,
2024-11-13 19:31:17 +00:00
quant_level = 0 if quant_level == 0 else quant_level - 1 , # input is one below the target quant level
)
2024-07-25 00:35:17 +00:00
# apply token dropout
if token_dropout_rate > 0.0 and ( token_dropout_rvq_levels [ 0 ] < = quant_level and quant_level < = token_dropout_rvq_levels [ 1 ] ) :
steps = embedding . shape [ 0 ] - ( 1 if quant_level == 0 else 0 ) # do not mess with stop token
for i in range ( steps ) :
if random . random ( ) > token_dropout_rate :
continue
embedding [ i ] = self . dropout_token
2024-11-14 13:34:22 +00:00
elif name == " timestep " and self . time_emb is not None :
2024-11-10 00:04:59 +00:00
embedding = self . time_emb ( input )
2024-06-08 20:42:02 +00:00
elif name == " len " and self . len_emb is not None :
embedding = self . len_emb ( input )
2024-04-17 02:04:48 +00:00
else :
2024-06-08 20:42:02 +00:00
# should probably raise an exception so things aren't processed silently
2024-04-17 02:04:48 +00:00
continue
2024-11-10 00:04:59 +00:00
2024-04-17 02:04:48 +00:00
batch . append ( embedding )
2024-07-18 22:16:32 +00:00
2024-04-17 02:04:48 +00:00
x_list . append ( _join ( batch , self . sep ) )
2024-04-16 00:54:32 +00:00
2024-04-17 02:04:48 +00:00
return x_list
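# illustrative sketch, not part of the model: a minimal stand-in for what _join does here,
# assuming self.sep is a single (d_model,) embedding inserted between segments
"""
def _join_sketch( segments: list[torch.Tensor], sep: torch.Tensor ) -> torch.Tensor:
	out = [ segments[0] ]
	for segment in segments[1:]:
		out.append( sep.unsqueeze(0) )  # separator between segments, but not after the last
		out.append( segment )
	return torch.cat( out, dim=0 )      # (total_seq_len, d_model)
"""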
2024-11-10 00:04:59 +00:00
# get an attribute from a given input list
def get_input (
self ,
inputs ,
name ,
at = None ,
) :
2024-11-10 18:19:48 +00:00
find_all = at is None
res = [ ] if at is None else None
2024-11-10 00:04:59 +00:00
for batch_index , batch_input in enumerate ( inputs ) :
2024-11-10 18:19:48 +00:00
if not find_all and batch_index != at :
2024-11-10 00:04:59 +00:00
continue
for n , input in batch_input :
2024-11-10 18:19:48 +00:00
if n != name :
continue
if not find_all :
2024-11-10 00:04:59 +00:00
return input
2024-11-10 18:19:48 +00:00
res . append ( input )
return res
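# illustrative usage, not part of the model: fetch an attribute across the whole batch,
# or for one batch index
"""
classifier_levels = self.get_input( inputs, "classifier_level" )          # one entry per batch
dropout_mask = self.get_input( inputs, "dropout_mask", at=batch_index )   # single entry (or None)
"""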
2024-11-10 00:04:59 +00:00
2024-07-18 19:18:34 +00:00
# creates position ids from a given input list
# if not unified_position_ids, then each input segment will have its own sequence
2024-07-17 00:52:41 +00:00
def inputs_to_position_ids (
self ,
inputs : list ,
mask : Tensor ,
) :
2024-08-04 03:10:21 +00:00
device = mask . device
2024-07-17 00:52:41 +00:00
# shamelessly grabbed from modeling_llama.py
ids = mask . long ( ) . cumsum ( - 1 ) - 1
ids . masked_fill_ ( mask == 0 , 1 )
2024-08-01 01:35:09 +00:00
2024-07-17 00:52:41 +00:00
# there's a better way
if not self . unified_position_ids :
x_list = [ ]
2024-07-19 04:25:32 +00:00
2024-11-15 04:17:47 +00:00
def get_input_token_length ( name , input , task ) :
2024-07-19 04:25:32 +00:00
# task token
if isinstance ( input , str ) :
return 1
# list of tokens
if not isinstance ( input , torch . Tensor ) :
2024-11-15 04:17:47 +00:00
return sum ( [ i . shape [ 0 ] for i in input if isinstance ( i , torch . Tensor ) ] )
2024-07-19 04:25:32 +00:00
2024-08-05 00:56:21 +00:00
# interleaved model
if self . interleave and name == " resp " :
return input . shape [ 0 ] * input . shape [ 1 ]
2024-08-03 03:25:49 +00:00
# ending input will not have a separator later
2024-11-15 04:17:47 +00:00
return input . shape [ 0 ]
2024-07-19 04:25:32 +00:00
2024-07-17 00:52:41 +00:00
for batch_index , batch_input in enumerate ( inputs ) :
2024-11-15 04:17:47 +00:00
# pre-iterate
task = " tts "
for name , input in batch_input :
if name == " task " :
task = input
2024-07-19 04:25:32 +00:00
batch = torch . cat ( [
2024-11-15 04:17:47 +00:00
torch . tensor ( [ * range ( get_input_token_length ( name , input , task ) + ( 1 if name != task_outputs . get ( task , name ) else 0 ) ) ] , device = device , dtype = torch . int32 )
for name , input in batch_input if name not in non_tokened_names
2024-07-19 04:25:32 +00:00
] )
2024-07-17 00:52:41 +00:00
delta = ids [ batch_index ] . shape [ 0 ] - batch . shape [ 0 ]
if delta > 0 :
2024-08-04 13:18:57 +00:00
batch = torch . cat ( [ batch , torch . tensor ( [ 1 ] * delta , device = device , dtype = torch . int32 ) ] )
2024-07-17 00:52:41 +00:00
x_list . append ( batch )
ids = torch . stack ( x_list )
2024-08-04 03:10:21 +00:00
return ids . to ( device = device , dtype = torch . int32 )
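# illustrative sketch, not part of the model: without unified position ids, each segment
# restarts its own count, and every non-final segment gets +1 for its trailing separator
"""
# segments:      text (3 tokens)   prom (4 tokens)   resp (2 tokens, the task's output)
# position ids:  0 1 2 3           0 1 2 3 4         0 1
"""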
2024-07-17 00:52:41 +00:00
2024-05-29 00:29:54 +00:00
def calc_loss (
2024-04-17 02:04:48 +00:00
self ,
inputs : list ,
2024-05-29 00:29:54 +00:00
logits ,
2024-11-13 19:31:17 +00:00
quant_levels : list [ int ] | None = None ,
2024-12-06 05:05:52 +00:00
compute_hard_loss = True ,
compute_acc = True ,
2024-04-17 02:04:48 +00:00
) :
2024-11-15 04:17:47 +00:00
loss = { }
stats = { }
2024-10-31 01:05:45 +00:00
2024-07-19 04:25:32 +00:00
device = logits [ 0 ] . device
2024-11-10 01:04:13 +00:00
batch_size = len ( logits )
2024-11-13 19:31:17 +00:00
classifier_levels = self . get_input ( inputs , " classifier_level " )
2024-07-19 20:33:31 +00:00
2024-07-18 21:16:14 +00:00
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token ( input , quant_level ) :
2024-11-15 04:17:47 +00:00
"""
if isinstance ( input , str ) :
return torch . tensor ( [ self . ignore_index ] , device = device , dtype = torch . int16 )
return torch . tensor ( [ self . ignore_index ] * input . shape [ 0 ] , device = device , dtype = torch . int16 )
"""
2024-07-19 04:25:32 +00:00
if isinstance ( input , str ) :
2024-08-04 13:18:57 +00:00
return torch . tensor ( [ get_task_symmap ( ) [ f ' < { input } > ' ] ] , device = device , dtype = torch . int16 )
2024-07-18 21:16:14 +00:00
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
if self . version < 4 or ( self . version > = 5 and self . config and self . config . experimental . audio_embedding_sums ) :
return torch . full_like ( input [ . . . , 0 ] , self . ignore_index )
return input if input . dim ( ) == 1 else input [ : , quant_level ]
2024-11-15 04:17:47 +00:00
for batch_index , batch in enumerate ( inputs ) :
quant_level = quant_levels [ batch_index ]
target = [ ]
causal = True
task_type = " tts "
dropout_mask = None
classifier_level = None
2024-12-06 05:05:52 +00:00
output_len = 0
2024-07-18 21:16:14 +00:00
2024-11-15 04:17:47 +00:00
for name , input in batch :
if name == " task " :
task_type = input
elif name == " dropout_mask " :
dropout_mask = input
elif name == " classifier_level " :
classifier_level = input
2024-11-13 15:54:20 +00:00
2024-11-15 04:17:47 +00:00
# autoregressive, causal
if classifier_level . startswith ( " AR: " ) :
causal = True
# nonautoregressive, parallel
elif classifier_level . startswith ( " NAR: " ) :
causal = False
2024-06-06 02:02:05 +00:00
2024-11-13 16:17:52 +00:00
it = 0
for name , input in batch :
2024-11-15 04:17:47 +00:00
token = None
ignored = False
2024-08-03 13:40:39 +00:00
2024-11-15 04:17:47 +00:00
# non-tokened tasks
if name in non_tokened_names :
2024-11-13 16:17:52 +00:00
continue
2024-11-15 04:17:47 +00:00
# prom can either be a tensor itself or a list of tensors and strings
if name == " prom " :
# expand to list if not a list
proms = [ input ] if isinstance ( input , torch . Tensor ) else input
# iterate over the list to inject their tokens
token = torch . cat ( [ prompt_input_to_token ( input , quant_level ) for input in proms if input is not None ] )
elif name == " resp " :
# mask found, apply it
2024-11-13 16:17:52 +00:00
if dropout_mask is not None :
# if masked, use the original token as the target, else ignore
2024-11-15 04:17:47 +00:00
token = torch . where ( dropout_mask , input if input . dim ( ) == 1 else input [ : , 0 ] , self . ignore_index )
# flatten
2024-11-13 16:17:52 +00:00
elif self . interleave :
2024-11-15 04:17:47 +00:00
token = _interleave_sequence_flatten ( [ input [ : , l ] for l in range ( input . shape [ - 1 ] ) ] )
# use resps as-is
2024-09-06 01:43:20 +00:00
else :
2024-11-15 04:17:47 +00:00
token = input if input . dim ( ) == 1 else input [ : , quant_level ]
# not a special input, inject as-is
else :
token = input
2024-05-29 00:29:54 +00:00
2024-11-15 04:17:47 +00:00
if not isinstance ( token , torch . Tensor ) :
continue
2024-12-22 04:52:10 +00:00
# offset to flattened vocab ranges
2024-12-22 22:15:24 +00:00
"""
2024-12-22 04:52:10 +00:00
if self . classifier is not None :
offsets = _get_offsets ( )
2024-12-22 21:05:45 +00:00
k = name
if name == " stt " :
k = " text "
if name == " prom " :
k = f ' prom| { quant_level } '
elif name == " resp " :
k = f ' resps| { classifier_level } '
if k in offsets :
start , end = offsets [ k ]
2024-12-22 04:52:10 +00:00
for i , t in enumerate ( token ) :
if t == self . ignore_index :
continue
2024-12-22 21:05:45 +00:00
token [ i ] + = start
2024-12-22 22:15:24 +00:00
"""
2024-12-22 04:52:10 +00:00
2024-11-15 04:17:47 +00:00
if token . is_floating_point ( ) :
ignored = True
2024-06-05 04:23:31 +00:00
2024-11-15 04:17:47 +00:00
# grab range of our logits for later
seq_len = token . shape [ 0 ]
start , end = it , it + seq_len
2024-05-29 00:29:54 +00:00
it + = seq_len + 1 # +1 to incorporate the separator
2024-11-15 04:17:47 +00:00
# deduce if a name for a task is an input or output
2024-12-06 05:05:52 +00:00
if name != task_outputs . get ( task_type , name ) :
if self . ignore_inputs_for_loss :
ignored = True
else :
output_len = seq_len
2024-05-29 00:29:54 +00:00
2024-11-15 04:17:47 +00:00
if ignored :
# pruned
if self . config . loss_factors :
continue
# fill with an ignore-index tensor
token = torch . tensor ( [ self . ignore_index ] * input . shape [ 0 ] , device = device , dtype = torch . int16 )
# perform loss calculation on the individual piece
if self . config . loss_factors :
loss_factor = self . loss_factor ( name )
if loss_factor == 0.0 :
continue
2024-05-29 00:29:54 +00:00
2024-11-15 04:17:47 +00:00
logit = logits [ batch_index ] [ start : end ]
if causal and seq_len > 1 :
l = self . causal_size
logit = logit [ . . . , : - l , : ]
token = token [ . . . , l : ] # shift sequence to the right by one (or causal chunk size)
2024-12-06 05:05:52 +00:00
if compute_hard_loss :
nll = F . cross_entropy ( logit , token . long ( ) , ignore_index = self . ignore_index ) * loss_factor
if f ' { name } .nll ' not in loss :
loss [ f ' { name } .nll ' ] = [ ]
loss [ f ' { name } .nll ' ] . append ( nll )
2024-11-15 04:17:47 +00:00
2024-12-06 05:05:52 +00:00
if compute_acc :
if self . metrics is not None :
metrics = self . metrics . calc_accuracy ( [ logit ] , [ token ] , self . classifiers . indices ( [ classifier_level ] ) )
else :
2024-12-23 02:11:31 +00:00
accuracy_metric = MulticlassAccuracy (
logit . shape [ - 1 ] ,
top_k = 10 ,
average = " micro " ,
multidim_average = " global " ,
ignore_index = - 100
) . to ( logit . device )
metrics = accuracy_metric ( logit , token )
2024-12-06 05:05:52 +00:00
if f ' { name } .acc ' not in stats :
stats [ f ' { name } .acc ' ] = [ ]
stats [ f ' { name } .acc ' ] . append ( metrics )
2024-11-15 04:17:47 +00:00
# add to list
else :
target . append ( token )
# perform loss calculation on the entire sequence
if not self . config . loss_factors :
2024-12-22 22:15:24 +00:00
target = _join ( target , torch . tensor ( self . ignore_index , device = target [ - 1 ] . device ) )
2024-11-15 04:17:47 +00:00
logit = logits [ batch_index ]
2024-06-05 04:48:51 +00:00
2024-11-15 04:17:47 +00:00
# shift if causal
if causal :
l = self . causal_size
logit = logit [ . . . , : - l , : ] # shift the target so that token n...
target = target [ . . . , l : ] # ...predicts token n + 1
2024-12-06 05:05:52 +00:00
if compute_hard_loss :
nll = F . cross_entropy ( logit , target , ignore_index = self . ignore_index )
if ' nll ' not in loss :
loss [ ' nll ' ] = [ ]
loss [ " nll " ] . append ( nll )
2024-05-27 13:43:00 +00:00
2024-12-06 05:05:52 +00:00
if compute_acc :
if self . metrics is not None :
metrics = self . metrics . calc_accuracy ( [ logit ] , [ target ] , self . classifiers . indices ( [ classifier_level ] ) )
else :
2024-12-23 02:11:31 +00:00
accuracy_metric = MulticlassAccuracy (
logit . shape [ - 1 ] ,
top_k = 10 ,
average = " micro " ,
multidim_average = " global " ,
ignore_index = - 100
) . to ( logit . device )
metrics = accuracy_metric ( logit , target )
2024-11-15 04:17:47 +00:00
2024-12-06 05:05:52 +00:00
if ' acc ' not in stats :
stats [ ' acc ' ] = [ ]
stats [ " acc " ] . append ( metrics )
2024-11-15 04:17:47 +00:00
# average
loss = { name : sum ( loss [ name ] ) / len ( loss [ name ] ) for name in loss . keys ( ) }
stats = { name : sum ( stats [ name ] ) / len ( stats [ name ] ) for name in stats . keys ( ) }
2024-10-31 01:05:45 +00:00
return LossStats ( loss , stats )
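# illustrative sketch, not part of the model: the causal shift used above, so that the
# logit at position t is scored against the token at position t + causal_size
"""
l = self.causal_size
logit = logit[..., :-l, :]   # drop the last l positions' logits
target = target[..., l:]     # drop the first l target tokens
nll = F.cross_entropy( logit, target, ignore_index=self.ignore_index )
"""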
2024-05-19 16:23:56 +00:00
2024-04-17 02:04:48 +00:00
def forward (
self ,
inputs : list ,
2024-11-13 19:31:17 +00:00
quant_levels : list [ int ] | None = None ,
2024-04-17 02:04:48 +00:00
state : dict | list | None = None ,
2024-11-02 16:49:05 +00:00
layer_skip_variables : dict | None = None ,
2024-11-02 02:30:06 +00:00
output_attentions : bool = False ,
output_hidden_states : bool = False ,
2024-04-17 02:04:48 +00:00
) :
2024-11-02 16:49:05 +00:00
# return early if it's "good" enough
2024-11-13 19:31:17 +00:00
# lambda because we need to capture the classifier_levels and mask
2024-11-04 00:31:28 +00:00
exited_layer = self . n_layers
2024-11-02 16:49:05 +00:00
def layer_skip_lambda ( layer , logits ) :
2024-11-04 00:31:28 +00:00
nonlocal exited_layer
2024-11-02 16:49:05 +00:00
kwargs = {
2024-11-04 00:31:28 +00:00
" entropy_threshold " : 0.05 ,
" varentropy_threshold " : 0.05 ,
" min_layer " : self . n_layers / / 2 ,
2024-11-02 16:49:05 +00:00
" max_layer " : self . n_layers ,
}
kwargs . update ( layer_skip_variables )
# don't bother on early layers
if layer < kwargs [ " min_layer " ] :
return False
# force an early exit if we're past the requested maximum layer
if kwargs [ " max_layer " ] < layer :
return True
# hidden states aren't normalized
x = self . model . norm ( logits )
# output projection layer with masking
if self . classifier is not None :
2024-11-03 15:58:29 +00:00
x = self . classifier ( x ) # * m
2024-11-02 16:49:05 +00:00
elif self . classifiers is not None :
2024-11-13 19:31:17 +00:00
logits = self . classifiers ( logits , levels = classifier_levels ) # * m
2024-11-02 16:49:05 +00:00
# calculate metrics
metrics = calculate_entropix_metrics ( logits )
# exit early if "good enough"
2024-11-04 00:31:28 +00:00
early = metrics [ " logits_entropy " ] < = kwargs [ " entropy_threshold " ] and metrics [ " logits_varentropy " ] < = kwargs [ " varentropy_threshold " ]
if early :
exited_layer = layer
return early
2024-11-02 16:49:05 +00:00
2024-11-12 22:41:58 +00:00
# derive quant levels from inputs if not provided
if quant_levels is None :
2024-12-07 05:53:46 +00:00
quant_levels = [ x . item ( ) for x in self . get_input ( inputs , " quant_level " ) ]
2024-11-12 22:41:58 +00:00
2024-04-17 02:04:48 +00:00
x_list = self . inputs_to_embeddings ( inputs , quant_levels )
2024-11-03 02:00:21 +00:00
x , mask = list_to_tensor ( x_list )
2024-04-17 02:04:48 +00:00
2024-06-06 23:52:41 +00:00
training = self . training
2024-12-06 05:05:52 +00:00
teaching = self . teaching
2024-04-17 02:04:48 +00:00
device = x . device
batch_size = len ( x_list )
2024-06-12 03:28:59 +00:00
# pure AR
if quant_levels is None :
quant_levels = [ 0 for _ in range ( batch_size ) ]
2024-10-31 01:05:45 +00:00
2024-11-02 16:49:05 +00:00
# we only need hidden states if we're training with layerskip
if self . layerskip and training :
2024-10-31 01:05:45 +00:00
output_hidden_states = True
2024-04-16 00:54:32 +00:00
# pad our input and mask, but retain the original length by doing it after
if self . l_padding and x . shape [ 1 ] % self . l_padding != 0 :
# pad input
shape = list ( x . shape )
shape [ 1 ] = self . l_padding - shape [ 1 ] % self . l_padding
padding = torch . zeros ( shape , dtype = x . dtype , device = x . device )
x = torch . cat ( [ x , padding ] , dim = 1 )
# pad mask
shape [ 2 ] = 1
2024-11-22 05:08:43 +00:00
padding = torch . zeros ( shape [ : 2 ] , dtype = x . dtype , device = x . device )
2024-11-03 02:00:21 +00:00
mask = torch . cat ( [ mask , padding ] , dim = 1 )
2024-11-22 05:08:43 +00:00
m = mask . unsqueeze ( dim = - 1 )
2024-04-16 00:54:32 +00:00
2024-07-17 00:52:41 +00:00
# needs to be done here as we still have our raw inputs
2024-11-03 02:00:21 +00:00
position_ids = self . inputs_to_position_ids ( inputs , mask = mask ) if not self . unified_position_ids else None
2024-11-13 19:31:17 +00:00
classifier_levels = self . get_input ( inputs , name = " classifier_level " )
2024-11-22 19:44:43 +00:00
causal_levels = [ " AR:0:0 " , " stt " , " len " ]
# right now limit to new versions because I need to retrain the model for noncausal masks...
2024-12-12 04:45:38 +00:00
is_causal = [ l in causal_levels for l in classifier_levels ] if self . noncausal_masks else [ True for l in classifier_levels ]
2024-11-13 17:38:58 +00:00
2024-10-12 02:18:26 +00:00
output = self . _forward (
2024-04-16 00:54:32 +00:00
inputs = x ,
2024-11-03 02:00:21 +00:00
mask = mask ,
2024-04-16 00:54:32 +00:00
state = state ,
2024-11-22 19:44:43 +00:00
is_causal = is_causal ,
2024-07-17 00:52:41 +00:00
position_ids = position_ids ,
2024-10-12 02:18:26 +00:00
output_attentions = output_attentions ,
2024-10-31 01:05:45 +00:00
output_hidden_states = output_hidden_states ,
2024-11-02 16:49:05 +00:00
layer_skip_lambda = layer_skip_lambda if self . layerskip and layer_skip_variables else None ,
2024-04-16 00:54:32 +00:00
)
2023-08-02 21:53:35 +00:00
2024-10-12 02:18:26 +00:00
logits = output . logits
2024-10-31 01:05:45 +00:00
hidden_states = output . hidden_states
2024-10-12 02:18:26 +00:00
2024-11-04 00:31:28 +00:00
# output projection layer
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
2024-11-02 16:49:05 +00:00
if self . classifier is not None :
2024-11-03 15:58:29 +00:00
logits = self . classifier ( logits ) # * m
2024-11-02 16:49:05 +00:00
if output . hidden_states :
for i , state in enumerate ( hidden_states ) :
2024-11-03 15:58:29 +00:00
hidden_states [ i ] = self . classifier ( hidden_states [ i ] ) # * m
2024-09-06 20:13:04 +00:00
# to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
2024-11-03 02:00:21 +00:00
elif self . classifiers is not None :
2024-11-13 19:31:17 +00:00
logits = self . classifiers ( logits , levels = classifier_levels ) # * m
2024-06-12 03:28:59 +00:00
2024-10-31 01:05:45 +00:00
if hidden_states is not None :
2024-11-01 17:50:37 +00:00
for i , state in enumerate ( hidden_states ) :
2024-11-13 19:31:17 +00:00
hidden_states [ i ] = self . classifiers ( hidden_states [ i ] , levels = classifier_levels ) # * m
2024-10-31 01:05:45 +00:00
2023-08-02 21:53:35 +00:00
# Remove padding
2024-10-12 02:18:26 +00:00
logits = [ hi [ : li ] for hi , li in zip ( logits , map ( len , x_list ) ) ]
2024-11-01 17:50:37 +00:00
if hidden_states is not None :
for i , state in enumerate ( hidden_states ) :
hidden_states [ i ] = [ hi [ : li ] for hi , li in zip ( hidden_states [ i ] , map ( len , x_list ) ) ]
2024-04-14 18:12:50 +00:00
2024-12-22 22:15:24 +00:00
# de-offset if needed
if self . classifier is not None :
offsets = _get_offsets ( )
for batch_index , classifier_level in enumerate ( classifier_levels ) :
if classifier_level == " stt " :
k = " text "
elif classifier_level == " len " :
k = " len "
else :
k = f ' resps| { classifier_level } '
if k not in offsets :
continue
start , end = offsets [ k ]
logits [ batch_index ] = logits [ batch_index ] [ : , start : end ]
2024-12-06 05:05:52 +00:00
if not training :
loss = None
stats = None
self . loss = None
self . stats = None
2024-12-22 21:05:45 +00:00
# compute loss if the target is given
2024-12-06 05:05:52 +00:00
else :
2024-12-07 03:55:20 +00:00
loss , stats = self . calc_loss ( inputs = inputs , logits = logits , quant_levels = quant_levels )
2024-11-01 17:50:37 +00:00
# compute it as an aux-loss
if self . layerskip :
early_exit_loss = { }
if not hasattr ( self , " training_steps " ) :
self . training_steps = 0
2024-10-31 01:05:45 +00:00
2024-11-01 17:50:37 +00:00
for i , state in enumerate ( hidden_states ) :
2024-10-31 01:05:45 +00:00
loss , stats = self . calc_loss ( inputs = inputs , logits = hidden_states [ i ] , quant_levels = quant_levels )
for k , v in loss . items ( ) :
2024-11-01 17:50:37 +00:00
K = f ' early_exit. { k } '
if K not in early_exit_loss :
early_exit_loss [ K ] = [ ]
early_exit_loss [ K ] . append ( v )
for k , v in early_exit_loss . items ( ) :
loss [ k ] = self . model . early_exit_loss ( losses = v , t = self . training_steps )
# to-do: instead make the curriculum rely on samples processed instead of steps
self . training_steps + = 1 # batch_size
2024-10-31 01:05:45 +00:00
2024-05-19 16:23:56 +00:00
# include any additional losses (for example: MoE router)
2024-12-06 05:05:52 +00:00
if output . loss is not None :
loss [ " aux_loss " ] = output . loss
2024-10-31 01:05:45 +00:00
self . loss = loss
self . stats = stats
2023-09-09 01:30:54 +00:00
2024-10-12 02:18:26 +00:00
# rewrap, because we're modifying the logits here
2024-12-07 05:53:46 +00:00
return Logits ( logits , output . state , inputs , loss , output . attentions , hidden_states , exited_layer )
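# illustrative usage, not part of the model: the caller reads everything back out of the
# Logits namedtuple returned above (stats are stashed on the model itself)
"""
output = model.forward( inputs=inputs, quant_levels=quant_levels )
logits = output.logits    # list of (seq_len, vocab) tensors, padding already removed
loss = output.loss        # dict of losses, only populated while training
stats = model.stats       # accuracy metrics, likewise only populated while training
"""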
2023-09-13 02:28:07 +00:00
def sample (
self ,
2024-07-30 00:15:07 +00:00
logits : list [ Tensor ] , # logit scores
2024-10-05 03:18:20 +00:00
prev_list : list [ Tensor ] | None = None , # previous tokens
2024-11-13 19:31:17 +00:00
quant_levels : list [ int ] | None = None , # to-do: derive this from the prev_list
2024-11-12 02:21:16 +00:00
* * sampling_kwargs ,
) :
# yikes
temperature = sampling_kwargs . get ( " temperature " , 1.0 )
min_temperature = sampling_kwargs . get ( " min_temperature " , - 1.0 )
top_k = sampling_kwargs . get ( " top_k " , - 100 )
top_p = sampling_kwargs . get ( " top_p " , 1.0 )
min_p = sampling_kwargs . get ( " min_p " , 0.0 )
2024-07-30 00:15:07 +00:00
# repetition penalty parameters
2024-11-12 02:21:16 +00:00
repetition_penalty = sampling_kwargs . get ( " repetition_penalty " , 1.0 )
repetition_penalty_decay = sampling_kwargs . get ( " repetition_penalty_decay " , 0.0 )
2024-07-30 00:15:07 +00:00
# length penalty parameters
2024-11-12 02:21:16 +00:00
length_penalty = sampling_kwargs . get ( " length_penalty " , 0.0 )
2024-07-30 00:15:07 +00:00
# beam sampling parameters
2024-11-12 02:21:16 +00:00
beam_width = sampling_kwargs . get ( " beam_width " , 0 )
2024-07-30 00:15:07 +00:00
# mirostat sampling parameters
2024-11-12 02:21:16 +00:00
mirostat = sampling_kwargs . get ( " mirostat " , None )
2024-07-30 00:15:07 +00:00
# DRY sampling parameters
2024-11-12 02:21:16 +00:00
dry_multiplier = sampling_kwargs . get ( " dry_multiplier " , 0.0 )
dry_base = sampling_kwargs . get ( " dry_base " , 1.75 )
dry_allowed_length = sampling_kwargs . get ( " dry_allowed_length " , 2 )
2024-11-13 04:30:09 +00:00
#
top_no = sampling_kwargs . get ( " top_no " , 1.0 )
#
2024-11-12 02:21:16 +00:00
attentions = sampling_kwargs . get ( " attentions " , None )
2024-11-02 16:49:05 +00:00
batch_size = len ( logits )
2023-10-10 22:02:33 +00:00
if min_temperature < 0 :
min_temperature = temperature
2024-06-06 00:50:06 +00:00
2024-11-12 03:40:19 +00:00
# pick last RVQ level
if prev_list is not None :
prev_list = [ prevs if prevs . dim ( ) == 1 else prevs [ : , - 1 ] for prevs in prev_list ]
2024-10-12 02:18:26 +00:00
scores = None
entropy = None
2024-11-03 15:58:29 +00:00
#logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
#logits = [ logit.to(device="cpu") for logit in logits ]
2024-10-12 02:18:26 +00:00
2024-10-13 17:01:12 +00:00
# (AR) entropix sampling
# we do it before everything to retain logits for the entire sequence (even though it's still better to pass only the last token)
if attentions is not None and quant_levels is None :
# move to CPU for speedups
seq_lens = [ logit . shape [ 0 ] for logit in logits ]
attentions = torch . stack ( attentions , dim = 1 ) . to ( device = " cpu " ) # ( batch, layer, heads, seq_len, seq_len )
res = [ sample_entropix (
logit [ : seq_lens [ batch ] , : ] , # ( seq_len, vocab )
attentions [ batch , : , : , : seq_lens [ batch ] , : seq_lens [ batch ] ] , # (layer, heads, seq_len, seq_len )
temperature ,
top_k ,
top_p ,
min_p ,
) for batch , logit in enumerate ( logits ) ]
if res :
2024-11-08 19:34:39 +00:00
return Sampled ( [ r [ 0 ] for r in res ] , logits , scores , [ r [ 1 ] for r in res ] )
2024-11-02 16:49:05 +00:00
"""
elif quant_levels is None :
seq_lens = [ logit . shape [ 0 ] for logit in logits ]
entropy = [ calculate_entropix_metrics (
logit [ : seq_lens [ batch ] , : ] , # ( seq_len, vocab )
#attentions[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # (layer, heads, seq_len, seq_len )
) for batch , logit in enumerate ( logits ) ]
"""
2024-10-13 17:01:12 +00:00
2023-09-09 01:30:54 +00:00
# (NAR) return the entire generated response
2024-07-30 00:15:07 +00:00
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
if quant_levels is not None : # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
2024-10-13 04:53:13 +00:00
seq_lens = map ( len , prev_list )
logits = [ logit [ - l : ] for logit , l in zip ( logits , seq_lens ) ]
2023-09-09 01:30:54 +00:00
# (AR chunkwise) return the last chunkwise piece
2024-06-07 01:51:31 +00:00
elif self . causal :
2024-10-13 04:53:13 +00:00
seq_lens = [ logit . shape [ 0 ] - self . causal_size for logit in logits ]
2024-06-07 01:51:31 +00:00
logits = [ logit [ - self . causal_size : ] for logit in logits ]
2023-09-09 01:30:54 +00:00
2024-08-03 13:40:39 +00:00
# (NAR) disable stop token
if quant_levels is not None and " ar " in self . capabilities :
2024-09-06 01:43:20 +00:00
logits = [ ban_tokens ( logit , tokens = [ self . stop_token ] ) for logit , l in zip ( logits , map ( len , prev_list ) ) ]
2024-08-03 13:40:39 +00:00
# (AR-len) disable extraneous tokens
2024-11-07 15:10:18 +00:00
"""
2024-08-03 13:40:39 +00:00
if quant_levels is None and " len " in self . capabilities :
2024-09-06 01:43:20 +00:00
logits = [ ban_tokens ( logit , tokens = [ * range ( 11 , logit . shape [ - 1 ] ) ] ) for logit , l in zip ( logits , map ( len , prev_list ) ) ]
2024-11-07 15:10:18 +00:00
"""
2023-10-09 19:46:17 +00:00
2024-07-31 01:53:51 +00:00
# apply the repetition penalty
2024-11-09 04:05:41 +00:00
if prev_list is not None and repetition_penalty != 1.0 :
2024-11-12 22:41:58 +00:00
logits = [ reptition_penalize ( logit , previous = prevs , factor = repetition_penalty , decay = repetition_penalty_decay ) for logit , prevs in zip ( logits , prev_list ) ]
2024-07-31 01:53:51 +00:00
2023-09-09 01:43:36 +00:00
# (AR) apply the length penalty
2024-10-05 03:30:47 +00:00
if quant_levels is None and self . causal and prev_list is not None and length_penalty != 0.0 :
2024-09-06 01:43:20 +00:00
logits = [ length_penalize ( logit , length = l + 1 , factor = length_penalty , token = self . stop_token ) for logit , l in zip ( logits , map ( len , prev_list ) ) ]
2023-09-09 01:30:54 +00:00
2024-10-12 03:36:06 +00:00
# perform min_p filtering of our logits
if min_p > 0.0 :
logits = [ min_p_filtering ( logit , min_p = min_p ) for logit in logits ]
2023-09-09 01:30:54 +00:00
# perform top_k/top_p filtering of our logits
2023-09-13 02:28:07 +00:00
if top_k > 0 or top_p < 1.0 :
logits = [ top_k_top_p_filtering ( logit , top_k = top_k , top_p = top_p ) for logit in logits ]
2023-10-10 22:02:33 +00:00
# trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
2024-01-27 01:41:12 +00:00
# epsilon float comparison because I don't trust Python
2023-10-10 22:02:33 +00:00
if abs ( temperature - min_temperature ) > = 0.001 :
logits = [ dynamic_temperature ( logit , temperature = temperature , min_temperature = min_temperature ) for logit in logits ]
2024-10-18 21:55:00 +00:00
elif temperature > 0.0 :
2023-10-09 18:01:40 +00:00
logits = [ logit / temperature for logit in logits ]
2024-11-13 04:30:09 +00:00
# do top-no logit processing
if top_no > 0.0 :
logits = [ top_no_logits_processing ( logit ) for logit in logits ]
2024-07-30 00:15:07 +00:00
# do DRY sampling
2024-10-22 23:12:39 +00:00
if dry_multiplier > 0.0 and prev_list is not None :
2024-11-12 03:40:19 +00:00
logits = [ dry_sampling ( logit , previous = prevs , factor = dry_multiplier , base = dry_base , allowed_length = dry_allowed_length ) for logit , prevs in zip ( logits , prev_list ) ]
2024-07-30 00:15:07 +00:00
2023-09-18 23:55:41 +00:00
# do mirostat sampling
# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
if mirostat is not None :
# mirostat sampling
2024-10-12 02:18:26 +00:00
scores = [ mirostat_sample ( logit , state = state ) for logit , state in zip ( logits , mirostat ) ]
res = [ state [ " token " ] for state in scores ]
2023-09-13 02:28:07 +00:00
# do beam search (naive implementation)
# picks the top-k across all batches, and re-batches those resultant tokens
2023-09-13 18:19:11 +00:00
# returns the logit scores as well to be P-concatted with the previous scores
2023-09-13 02:28:07 +00:00
# to-do: not naively implement beam searching
2024-10-12 02:18:26 +00:00
elif beam_width > 1 :
2023-09-13 18:19:11 +00:00
candidates = top_k_logits_list ( logits , beam_width )
2023-10-11 17:25:31 +00:00
res = [ torch . tensor ( token , dtype = torch . int16 ) . unsqueeze ( dim = - 1 ) for batch , token in candidates ]
scores = [ logits [ batch ] . flatten ( ) [ token ] for batch , token in candidates ]
2024-10-12 02:18:26 +00:00
# basic sampling
else :
2024-10-18 21:55:00 +00:00
# argmax instead
if temperature < = 0.0 :
2024-11-10 04:57:34 +00:00
res = [ logit . argmax ( dim = - 1 ) for logit in logits ]
2024-10-18 21:55:00 +00:00
else :
res = [ Categorical ( logits = logit ) . sample ( ) for logit in logits ]
2023-09-13 02:28:07 +00:00
2024-11-02 16:49:05 +00:00
# calculate token probabilities
2024-11-10 18:19:48 +00:00
scores = [
[ F . softmax ( logit [ i , : ] , dim = - 1 ) [ token ] . item ( ) for i , token in enumerate ( tokens ) ]
for logit , tokens in zip ( logits , res )
]
2024-11-02 16:49:05 +00:00
2024-12-25 05:14:32 +00:00
return Sampled ( res , logits , scores , entropy )
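# illustrative summary, not part of the model: the basic (non-mirostat, non-beam) path
# above reduces to temperature scaling, sampling, and recording per-token probabilities
"""
logit = logit / temperature if temperature > 0.0 else logit
token = logit.argmax( dim=-1 ) if temperature <= 0.0 else Categorical( logits=logit ).sample()
prob = F.softmax( logit, dim=-1 )[ -1, token[-1] ].item()   # confidence of the last sampled token
"""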
# this is a VERY basic implementation to test if a HF-ified model works (it sort of does)
if __name__ == " __main__ " :
from transformers import LlamaForCausalLM , LlamaTokenizer
from . . models import download_model , DEFAULT_MODEL_PATH
from . . emb . qnt import decode_to_file
from . . utils . io import torch_load
# hack in a non-causal mask
def _update_noncausal_mask (
attention_mask ,
inputs_embeds ,
cache_positions ,
past_key_values_length ,
output_attentions ,
) :
# create noncausal mask
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
bsz , seq_len , _ = inputs_embeds . size ( )
# generate default mask based on input
if attention_mask is None :
attention_mask = torch . ones ( ( bsz , seq_len ) , dtype = torch . bool , device = inputs_embeds . device )
# make square
expanded_mask = attention_mask [ : , None , None , : ] . expand ( bsz , 1 , seq_len , seq_len ) . to ( dtype = inputs_embeds . dtype )
# invert from 1.0 = attend, 0.0 = masked to 0.0 = valid, -inf = masked
inverted_mask = 1.0 - expanded_mask
return inverted_mask . masked_fill ( inverted_mask . to ( dtype = torch . bool ) , torch . finfo ( inputs_embeds . dtype ) . min )
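# illustrative sketch, not part of this hack: for attention_mask = [[1, 1, 1, 0]] the
# resulting 4x4 mask only blocks the padded position, every valid token attends everywhere
"""
# inverted_mask[0, 0] =
#   [[ 0,  0,  0, -inf],
#    [ 0,  0,  0, -inf],
#    [ 0,  0,  0, -inf],
#    [ 0,  0,  0, -inf]]
"""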
device = " cuda "
dtype = torch . bfloat16
is_from_pretrained = True
if is_from_pretrained :
# tokenizer = LlamaTokenizer.from_pretrained("ecker/vall-e", revision="hf")
hf_model = LlamaForCausalLM . from_pretrained ( " ecker/vall-e " , revision = " hf " )
hf_model . to ( device = device , dtype = dtype )
hf_model . eval ( )
model = hf_model . model
else :
download_model ( )
model = LlamaModel ( LlamaConfig (
vocab_size = 1024 ,
hidden_size = 1024 ,
max_position_embeddings = 75 * 60 * 5 , # max length of 5 minutes at 75 tokens per second
intermediate_size = 1024 * 4 ,
num_hidden_layers = 12 ,
num_attention_heads = 16 ,
attention_dropout = 0.0 ,
num_key_value_heads = 16 ,
sliding_window = 75 * 12 , # 12 second context window
hidden_act = " gelu " ,
is_encoder_decoder = False ,
is_decoder = True ,
) )
state_dict = torch_load ( DEFAULT_MODEL_PATH ) [ ' module ' ]
state_dict_model = { }
for k , v in state_dict . items ( ) :
if not k . startswith ( ' model. ' ) :
continue
state_dict_model [ k . replace ( " model. " , " " ) ] = v
model . load_state_dict ( state_dict_model , strict = False )
model . to ( device = device , dtype = dtype )
model . eval ( )
model . _original_update_causal_mask = model . _update_causal_mask
model . _update_noncausal_mask = _update_noncausal_mask
phn = [ 1 , 22 , 111 , 100 , 4 , 37 , 115 , 169 , 11 , 2 ]
prom = [
[ 62 , 835 , 835 , 835 , 339 , 395 , 798 , 537 , 537 , 537 , 537 , 222 , 76 , 989 , 548 , 65 , 705 , 375 , 261 , 375 , 297 , 503 , 529 , 571 , 707 , 346 , 266 , 862 , 148 , 496 , 574 , 115 , 115 , 438 , 934 , 339 , 865 , 876 , 63 , 40 , 779 , 461 , 602 , 794 , 10 , 220 , 507 , 869 , 639 , 705 , 869 , 917 , 705 , 893 , 917 , 705 , 869 , 938 , 439 , 175 , 139 , 506 , 375 , 529 , 297 , 705 , 651 , 238 , 962 , 461 , 195 , 441 , 377 , 581 , 473 , 795 , 644 , 626 , 459 , 981 , 767 , 670 , 696 , 73 , 779 , 257 , 738 , 1017 , 1019 , 133 , 133 , 1017 , 835 , 604 , 699 , 626 , 67 , 92 , 707 , 92 , 179 , 179 , 772 , 869 , 441 , 799 , 630 , 238 , 745 , 904 , 904 , 904 , 106 , 133 , 133 , 1017 , 1017 , 395 , 883 , 87 , 519 , 594 , 1002 , 682 , 996 , 540 , 186 , 855 , 430 , 202 , 347 , 889 , 61 , 92 , 542 , 297 , 67 , 669 , 571 , 707 , 346 , 67 , 359 , 571 , 707 , 669 , 604 , 395 , 1008 , 810 , 35 , 621 , 67 , 600 , 333 , 123 , 284 , 568 , 817 , 243 , 778 , 464 , 638 , 610 , 359 , 538 , 464 , 975 , 321 , 700 , 377 , 484 , 179 , 284 , 284 , 621 , 538 , 464 , 745 , 171 , 171 , 159 , 744 , 744 , 287 , 461 , 69 , 15 , 529 , 67 , 92 , 669 , 464 , 515 , 605 , 24 , 822 , 865 , 293 , 865 , 172 , 638 , 359 , 562 , 138 , 839 , 846 , 775 , 556 , 775 , 1006 , 917 , 346 , 312 , 148 , 331 , 496 , 646 , 67 , 314 , 15 , 705 , 131 , 855 , 662 , 287 , 172 , 85 , 107 , 519 , 374 , 450 , 391 , 609 , 643 , 778 , 80 , 287 , 794 , 794 , 115 , 785 , 794 , 461 , 699 , 519 , 932 , 522 , 652 , 262 , 508 , 902 , 932 , 932 , 391 , 769 , 18 , 507 , 90 , 442 , 762 , 610 , 610 , 669 , 605 , 35 , 855 , 56 , 989 , 863 , 195 , 464 , 604 , 257 , 904 , 632 , 786 , 951 , 461 , 239 , 195 , 878 , 771 , 146 , 481 , 146 , 481 , 434 , 643 , 917 , 280 , 67 , 464 , 115 , 744 , 744 , 115 , 115 , 115 , 819 , 709 , 63 , 907 , 359 , 519 , 996 , 616 , 682 , 996 , 616 , 519 , 762 , 917 , 841 , 772 , 568 , 954 , 600 , 422 , 893 , 592 , 464 , 626 , 86 , 143 , 615 , 171 , 744 , 744 , 196 , 115 , 821 , 415 , 521 , 799 , 654 , 839 , 644 , 473 , 592 , 953 , 523 , 855 , 738 , 855 , 855 , 876 , 1017 , 63 , 329 ] ,
[ 913 , 859 , 740 , 740 , 937 , 601 , 961 , 961 , 877 , 747 , 747 , 559 , 474 , 618 , 20 , 316 , 58 , 316 , 180 , 112 , 290 , 869 , 610 , 869 , 869 , 943 , 127 , 153 , 236 , 794 , 282 , 857 , 984 , 196 , 875 , 648 , 993 , 913 , 860 , 616 , 38 , 833 , 620 , 133 , 123 , 992 , 247 , 367 , 252 , 50 , 298 , 27 , 27 , 631 , 163 , 784 , 271 , 20 , 843 , 514 , 869 , 258 , 180 , 66 , 803 , 281 , 123 , 493 , 831 , 102 , 556 , 992 , 385 , 122 , 31 , 251 , 990 , 827 , 26 , 347 , 460 , 43 , 43 , 460 , 228 , 43 , 841 , 913 , 302 , 544 , 544 , 277 , 859 , 404 , 646 , 775 , 315 , 848 , 726 , 185 , 203 , 314 , 203 , 174 , 252 , 174 , 378 , 954 , 214 , 993 , 924 , 809 , 277 , 765 , 363 , 544 , 363 , 518 , 791 , 185 , 454 , 193 , 193 , 193 , 193 , 193 , 573 , 977 , 924 , 76 , 434 , 56 , 193 , 962 , 610 , 24 , 954 , 459 , 396 , 112 , 903 , 137 , 398 , 474 , 506 , 791 , 839 , 399 , 102 , 25 , 205 , 792 , 459 , 474 , 526 , 817 , 869 , 192 , 792 , 593 , 878 , 506 , 24 , 410 , 539 , 788 , 522 , 667 , 566 , 584 , 588 , 992 , 444 , 24 , 869 , 925 , 635 , 393 , 903 , 742 , 320 , 1023 , 833 , 136 , 216 , 924 , 220 , 24 , 563 , 630 , 968 , 96 , 708 , 24 , 708 , 127 , 399 , 364 , 67 , 740 , 381 , 981 , 203 , 248 , 607 , 744 , 252 , 996 , 474 , 582 , 248 , 527 , 423 , 25 , 387 , 94 , 229 , 775 , 122 , 474 , 792 , 367 , 650 , 371 , 413 , 448 , 448 , 784 , 506 , 795 , 848 , 298 , 27 , 526 , 96 , 905 , 70 , 693 , 956 , 1002 , 1002 , 37 , 747 , 857 , 993 , 124 , 193 , 193 , 193 , 193 , 732 , 732 , 732 , 992 , 447 , 792 , 929 , 291 , 289 , 524 , 451 , 27 , 27 , 524 , 202 , 693 , 374 , 1002 , 125 , 732 , 585 , 367 , 317 , 679 , 395 , 413 , 189 , 493 , 386 , 650 , 110 , 912 , 505 , 384 , 399 , 851 , 367 , 367 , 27 , 230 , 988 , 810 , 975 , 842 , 956 , 1002 , 4 , 551 , 729 , 956 , 1002 , 750 , 648 , 231 , 950 , 193 , 96 , 912 , 410 , 732 , 539 , 103 , 193 , 904 , 491 , 213 , 792 , 792 , 998 , 193 , 399 , 151 , 410 , 96 , 673 , 497 , 1002 , 241 , 833 , 956 , 630 , 43 , 399 , 775 , 732 , 792 , 792 , 792 , 792 , 917 , 750 , 185 , 812 , 812 , 700 , 859 , 841 , 363 , 833 , 630 ] ,
[ 786 , 36 , 821 , 937 , 1000 , 705 , 1016 , 345 , 345 , 470 , 165 , 581 , 95 , 404 , 95 , 95 , 1006 , 477 , 95 , 95 , 691 , 254 , 997 , 657 , 176 , 124 , 95 , 673 , 489 , 326 , 218 , 437 , 907 , 590 , 752 , 541 , 1016 , 821 , 445 , 563 , 181 , 555 , 181 , 345 , 576 , 190 , 987 , 0 , 265 , 997 , 488 , 12 , 598 , 687 , 152 , 108 , 52 , 95 , 95 , 71 , 87 , 945 , 95 , 997 , 754 , 488 , 955 , 694 , 925 , 82 , 18 , 1020 , 1006 , 542 , 788 , 441 , 325 , 532 , 246 , 132 , 560 , 532 , 947 , 655 , 653 , 842 , 732 , 36 , 36 , 829 , 36 , 937 , 989 , 989 , 752 , 651 , 87 , 489 , 677 , 260 , 789 , 462 , 95 , 227 , 986 , 955 , 95 , 810 , 624 , 435 , 280 , 868 , 832 , 879 , 863 , 821 , 829 , 937 , 168 , 270 , 489 , 544 , 909 , 562 , 957 , 0 , 593 , 714 , 675 , 690 , 626 , 227 , 794 , 489 , 489 , 563 , 489 , 298 , 269 , 741 , 249 , 516 , 360 , 240 , 516 , 336 , 93 , 808 , 1022 , 682 , 555 , 737 , 147 , 405 , 476 , 895 , 323 , 694 , 412 , 689 , 963 , 72 , 193 , 298 , 181 , 521 , 741 , 193 , 93 , 153 , 773 , 677 , 689 , 495 , 30 , 564 , 719 , 1020 , 559 , 940 , 53 , 53 , 53 , 929 , 360 , 971 , 403 , 1012 , 997 , 919 , 957 , 433 , 919 , 787 , 401 , 401 , 355 , 276 , 370 , 414 , 690 , 697 , 330 , 629 , 552 , 930 , 720 , 259 , 579 , 221 , 62 , 945 , 135 , 1020 , 626 , 663 , 401 , 153 , 997 , 381 , 830 , 185 , 587 , 853 , 207 , 126 , 66 , 529 , 410 , 113 , 997 , 488 , 431 , 563 , 488 , 488 , 719 , 746 , 790 , 296 , 843 , 752 , 790 , 23 , 984 , 292 , 41 , 27 , 120 , 249 , 124 , 900 , 358 , 801 , 227 , 978 , 95 , 997 , 997 , 997 , 371 , 561 , 86 , 388 , 52 , 667 , 601 , 894 , 545 , 997 , 498 , 900 , 494 , 365 , 852 , 986 , 95 , 841 , 664 , 256 , 18 , 1020 , 963 , 901 , 447 , 498 , 262 , 388 , 691 , 997 , 646 , 651 , 757 , 468 , 114 , 601 , 437 , 940 , 212 , 655 , 541 , 970 , 870 , 521 , 237 , 957 , 563 , 794 , 563 , 564 , 620 , 489 , 351 , 489 , 489 , 257 , 733 , 629 , 489 , 227 , 622 , 962 , 7 , 598 , 374 , 470 , 114 , 159 , 211 , 298 , 363 , 843 , 818 , 153 , 59 , 452 , 529 , 258 , 419 , 605 , 689 , 526 , 39 , 982 , 829 , 982 , 752 , 678 , 1005 , 312 ] ,
[ 673 , 673 , 919 , 866 , 762 , 961 , 52 , 674 , 528 , 528 , 675 , 526 , 12 , 753 , 297 , 967 , 661 , 845 , 482 , 303 , 338 , 1021 , 506 , 445 , 247 , 214 , 206 , 94 , 434 , 799 , 210 , 885 , 552 , 695 , 853 , 1022 , 916 , 762 , 764 , 721 , 445 , 434 , 529 , 999 , 771 , 708 , 767 , 498 , 282 , 736 , 227 , 150 , 299 , 12 , 536 , 767 , 321 , 561 , 12 , 530 , 147 , 530 , 262 , 325 , 196 , 990 , 874 , 997 , 944 , 875 , 426 , 12 , 282 , 571 , 571 , 282 , 365 , 534 , 365 , 424 , 89 , 388 , 563 , 222 , 31 , 1019 , 624 , 74 , 215 , 651 , 1018 , 74 , 956 , 1022 , 74 , 18 , 633 , 350 , 72 , 448 , 454 , 769 , 267 , 938 , 12 , 534 , 929 , 723 , 829 , 614 , 505 , 364 , 1018 , 1014 , 838 , 673 , 919 , 74 , 223 , 761 , 266 , 78 , 177 , 736 , 20 , 718 , 425 , 1001 , 366 , 58 , 874 , 58 , 153 , 627 , 312 , 197 , 801 , 530 , 767 , 674 , 196 , 633 , 327 , 425 , 376 , 413 , 1019 , 209 , 594 , 383 , 744 , 458 , 468 , 711 , 282 , 885 , 640 , 435 , 655 , 571 , 556 , 1020 , 310 , 116 , 273 , 116 , 504 , 633 , 15 , 736 , 633 , 448 , 662 , 612 , 487 , 345 , 19 , 612 , 665 , 556 , 198 , 778 , 705 , 403 , 706 , 31 , 196 , 197 , 536 , 805 , 427 , 339 , 161 , 241 , 116 , 504 , 58 , 945 , 853 , 734 , 670 , 424 , 807 , 19 , 397 , 175 , 144 , 419 , 19 , 221 , 697 , 68 , 321 , 800 , 210 , 824 , 972 , 712 , 911 , 362 , 427 , 694 , 182 , 651 , 972 , 863 , 684 , 887 , 548 , 806 , 27 , 627 , 639 , 432 , 193 , 103 , 198 , 436 , 837 , 366 , 212 , 125 , 1001 , 493 , 874 , 808 , 17 , 17 , 127 , 204 , 530 , 300 , 345 , 425 , 246 , 240 , 640 , 906 , 340 , 310 , 633 , 246 , 774 , 114 , 633 , 522 , 777 , 874 , 494 , 577 , 353 , 939 , 571 , 693 , 857 , 722 , 530 , 521 , 354 , 492 , 735 , 214 , 806 , 483 , 736 , 530 , 118 , 234 , 536 , 177 , 132 , 522 , 349 , 259 , 436 , 973 , 528 , 414 , 224 , 762 , 212 , 854 , 744 , 271 , 568 , 127 , 323 , 736 , 304 , 499 , 499 , 78 , 536 , 736 , 805 , 232 , 126 , 468 , 566 , 611 , 52 , 339 , 450 , 258 , 157 , 602 , 594 , 854 , 602 , 599 , 82 , 124 , 472 , 563 , 666 , 174 , 936 , 818 , 66 , 758 , 627 , 52 , 350 , 999 , 734 , 215 , 919 , 1018 , 874 , 885 ] ,
[ 528 , 448 , 646 , 190 , 222 , 884 , 939 , 907 , 907 , 673 , 413 , 786 , 527 , 517 , 710 , 449 , 119 , 531 , 565 , 762 , 531 , 501 , 522 , 246 , 162 , 871 , 8 , 594 , 206 , 937 , 462 , 712 , 862 , 151 , 103 , 261 , 882 , 990 , 1007 , 314 , 683 , 864 , 693 , 812 , 319 , 786 , 107 , 531 , 31 , 342 , 632 , 460 , 269 , 429 , 531 , 531 , 717 , 417 , 321 , 671 , 1015 , 152 , 467 , 863 , 285 , 875 , 941 , 417 , 475 , 825 , 596 , 957 , 117 , 460 , 162 , 162 , 117 , 630 , 735 , 527 , 272 , 558 , 38 , 39 , 605 , 375 , 39 , 900 , 862 , 646 , 712 , 804 , 622 , 963 , 407 , 93 , 828 , 796 , 306 , 415 , 70 , 667 , 371 , 531 , 1000 , 411 , 710 , 162 , 812 , 381 , 673 , 498 , 691 , 884 , 928 , 712 , 528 , 48 , 630 , 24 , 593 , 901 , 973 , 579 , 722 , 75 , 139 , 909 , 919 , 328 , 764 , 393 , 777 , 753 , 512 , 577 , 175 , 577 , 512 , 922 , 834 , 863 , 30 , 69 , 94 , 68 , 616 , 691 , 835 , 335 , 486 , 345 , 306 , 374 , 732 , 938 , 580 , 311 , 715 , 495 , 527 , 1008 , 306 , 369 , 663 , 512 , 369 , 320 , 360 , 80 , 42 , 1021 , 1021 , 1021 , 175 , 568 , 526 , 362 , 320 , 317 , 488 , 613 , 937 , 548 , 966 , 545 , 596 , 177 , 306 , 480 , 522 , 577 , 512 , 512 , 638 , 1008 , 82 , 100 , 696 , 89 , 714 , 531 , 639 , 460 , 679 , 718 , 492 , 509 , 492 , 624 , 460 , 572 , 531 , 306 , 19 , 473 , 915 , 558 , 285 , 319 , 713 , 1018 , 381 , 877 , 667 , 425 , 905 , 43 , 437 , 632 , 634 , 324 , 306 , 207 , 324 , 303 , 48 , 69 , 467 , 39 , 902 , 599 , 3 , 617 , 465 , 78 , 918 , 459 , 1009 , 427 , 751 , 145 , 531 , 349 , 356 , 1021 , 157 , 507 , 780 , 624 , 165 , 507 , 144 , 270 , 94 , 414 , 899 , 379 , 947 , 994 , 853 , 107 , 586 , 652 , 877 , 92 , 19 , 91 , 188 , 544 , 624 , 470 , 503 , 513 , 13 , 192 , 563 , 145 , 531 , 618 , 743 , 470 , 62 , 701 , 499 , 436 , 679 , 505 , 198 , 959 , 3 , 766 , 839 , 437 , 491 , 395 , 1021 , 512 , 306 , 512 , 356 , 851 , 1021 , 1021 , 78 , 690 , 856 , 735 , 286 , 280 , 4 , 1008 , 369 , 359 , 309 , 651 , 864 , 561 , 170 , 692 , 952 , 877 , 520 , 959 , 306 , 37 , 1021 , 31 , 236 , 162 , 773 , 522 , 254 , 446 , 606 , 691 , 804 , 882 , 58 , 974 ] ,
[ 1011 , 939 , 881 , 881 , 140 , 937 , 724 , 724 , 937 , 1011 , 381 , 229 , 965 , 251 , 745 , 69 , 305 , 206 , 566 , 813 , 503 , 116 , 940 , 127 , 353 , 621 , 57 , 779 , 595 , 744 , 755 , 530 , 701 , 862 , 760 , 443 , 293 , 768 , 156 , 281 , 960 , 504 , 327 , 979 , 55 , 790 , 545 , 953 , 830 , 759 , 667 , 485 , 861 , 63 , 485 , 55 , 898 , 581 , 520 , 49 , 99 , 651 , 940 , 945 , 685 , 621 , 728 , 487 , 650 , 530 , 934 , 378 , 522 , 522 , 522 , 996 , 534 , 522 , 739 , 534 , 378 , 543 , 94 , 602 , 390 , 948 , 692 , 692 , 41 , 41 , 768 , 412 , 982 , 692 , 692 , 774 , 176 , 791 , 526 , 497 , 57 , 940 , 542 , 685 , 694 , 916 , 813 , 890 , 357 , 193 , 430 , 863 , 929 , 412 , 412 , 903 , 140 , 763 , 465 , 707 , 569 , 925 , 859 , 985 , 24 , 411 , 835 , 298 , 293 , 791 , 837 , 460 , 182 , 296 , 137 , 474 , 809 , 111 , 376 , 1021 , 111 , 490 , 111 , 938 , 542 , 578 , 477 , 506 , 57 , 385 , 300 , 873 , 240 , 104 , 667 , 204 , 515 , 834 , 24 , 125 , 113 , 980 , 111 , 997 , 859 , 997 , 376 , 193 , 490 , 824 , 511 , 799 , 719 , 575 , 451 , 575 , 251 , 222 , 630 , 429 , 920 , 788 , 300 , 993 , 641 , 154 , 816 , 940 , 618 , 130 , 940 , 462 , 823 , 955 , 1001 , 569 , 508 , 632 , 2 , 903 , 399 , 333 , 709 , 489 , 726 , 932 , 725 , 777 , 970 , 843 , 717 , 940 , 211 , 534 , 274 , 161 , 392 , 103 , 31 , 462 , 813 , 985 , 638 , 213 , 352 , 219 , 236 , 381 , 287 , 111 , 87 , 818 , 953 , 112 , 336 , 980 , 1016 , 72 , 960 , 426 , 238 , 60 , 9 , 487 , 665 , 129 , 24 , 24 , 162 , 312 , 411 , 111 , 157 , 473 , 466 , 222 , 940 , 341 , 55 , 457 , 712 , 179 , 451 , 111 , 831 , 918 , 826 , 814 , 940 , 30 , 468 , 240 , 207 , 389 , 923 , 186 , 95 , 300 , 876 , 679 , 576 , 543 , 582 , 111 , 227 , 312 , 112 , 545 , 747 , 378 , 165 , 158 , 610 , 601 , 425 , 238 , 704 , 630 , 124 , 644 , 949 , 982 , 297 , 868 , 569 , 24 , 57 , 465 , 24 , 859 , 111 , 24 , 752 , 775 , 24 , 647 , 465 , 495 , 57 , 24 , 57 , 227 , 907 , 296 , 581 , 843 , 1013 , 514 , 555 , 319 , 937 , 347 , 478 , 186 , 684 , 15 , 241 , 534 , 369 , 381 , 846 , 578 , 314 , 711 , 814 , 435 , 41 , 986 , 673 , 991 ] ,
[ 485 , 748 , 562 , 562 , 485 , 380 , 834 , 997 , 78 , 963 , 755 , 142 , 978 , 135 , 362 , 421 , 217 , 79 , 530 , 1012 , 972 , 946 , 127 , 587 , 838 , 818 , 456 , 548 , 424 , 479 , 944 , 650 , 694 , 447 , 391 , 616 , 938 , 908 , 206 , 259 , 998 , 292 , 818 , 128 , 353 , 273 , 566 , 796 , 333 , 146 , 110 , 986 , 571 , 451 , 166 , 229 , 421 , 300 , 911 , 689 , 329 , 145 , 287 , 273 , 542 , 808 , 301 , 491 , 0 , 278 , 825 , 442 , 0 , 100 , 818 , 826 , 66 , 904 , 642 , 566 , 135 , 305 , 999 , 993 , 905 , 485 , 755 , 782 , 365 , 977 , 485 , 1015 , 570 , 1002 , 755 , 169 , 967 , 36 , 721 , 1019 , 273 , 931 , 273 , 166 , 216 , 31 , 346 , 946 , 32 , 290 , 362 , 828 , 464 , 748 , 782 , 1002 , 1015 , 755 , 1014 , 100 , 315 , 777 , 549 , 177 , 882 , 110 , 603 , 975 , 531 , 608 , 67 , 1011 , 950 , 465 , 368 , 416 , 798 , 941 , 635 , 602 , 553 , 300 , 200 , 644 , 498 , 325 , 786 , 734 , 342 , 222 , 403 , 1 , 716 , 175 , 899 , 273 , 40 , 333 , 999 , 74 , 54 , 644 , 408 , 976 , 407 , 631 , 577 , 338 , 435 , 612 , 333 , 273 , 162 , 709 , 882 , 555 , 384 , 995 , 173 , 459 , 442 , 72 , 72 , 200 , 72 , 711 , 219 , 282 , 716 , 442 , 431 , 801 , 976 , 130 , 622 , 72 , 582 , 384 , 516 , 772 , 0 , 440 , 1001 , 249 , 1 , 953 , 65 , 945 , 438 , 249 , 511 , 561 , 205 , 507 , 821 , 998 , 427 , 746 , 290 , 544 , 426 , 693 , 999 , 190 , 214 , 167 , 219 , 534 , 166 , 325 , 975 , 414 , 326 , 326 , 268 , 679 , 991 , 418 , 868 , 445 , 632 , 160 , 380 , 890 , 346 , 315 , 806 , 258 , 806 , 486 , 326 , 797 , 471 , 18 , 790 , 33 , 66 , 63 , 66 , 224 , 38 , 599 , 599 , 110 , 801 , 761 , 18 , 936 , 230 , 253 , 171 , 393 , 774 , 887 , 887 , 403 , 466 , 495 , 524 , 261 , 666 , 256 , 687 , 759 , 263 , 713 , 185 , 454 , 242 , 988 , 185 , 161 , 911 , 430 , 86 , 550 , 439 , 327 , 527 , 671 , 782 , 383 , 916 , 590 , 315 , 806 , 583 , 465 , 785 , 321 , 315 , 421 , 856 , 66 , 352 , 0 , 634 , 540 , 362 , 948 , 185 , 16 , 224 , 372 , 694 , 259 , 648 , 87 , 733 , 659 , 603 , 67 , 269 , 901 , 66 , 566 , 173 , 705 , 746 , 566 , 911 , 10 , 743 , 860 , 78 , 782 , 1002 , 755 , 389 , 175 ] ,
[ 948 , 948 , 975 , 975 , 948 , 322 , 672 , 639 , 902 , 55 , 916 , 439 , 498 , 389 , 407 , 682 , 451 , 401 , 386 , 440 , 499 , 348 , 736 , 891 , 603 , 762 , 783 , 407 , 886 , 76 , 543 , 699 , 137 , 458 , 639 , 253 , 63 , 475 , 55 , 436 , 502 , 888 , 542 , 131 , 524 , 167 , 738 , 131 , 907 , 29 , 378 , 545 , 227 , 382 , 478 , 399 , 218 , 872 , 917 , 202 , 330 , 2 , 371 , 264 , 667 , 355 , 1016 , 768 , 590 , 408 , 463 , 542 , 214 , 202 , 715 , 891 , 840 , 297 , 509 , 689 , 290 , 439 , 672 , 714 , 528 , 940 , 1019 , 534 , 975 , 475 , 1019 , 835 , 975 , 558 , 975 , 981 , 330 , 635 , 96 , 858 , 606 , 627 , 367 , 191 , 191 , 669 , 40 , 873 , 359 , 267 , 701 , 426 , 210 , 1012 , 899 , 975 , 475 , 1012 , 610 , 6 , 300 , 749 , 231 , 616 , 877 , 631 , 720 , 574 , 551 , 398 , 503 , 789 , 684 , 664 , 390 , 277 , 150 , 990 , 823 , 190 , 971 , 903 , 175 , 863 , 316 , 965 , 988 , 988 , 800 , 612 , 336 , 506 , 242 , 847 , 389 , 939 , 415 , 202 , 83 , 317 , 2 , 153 , 365 , 363 , 57 , 2 , 891 , 965 , 300 , 754 , 763 , 426 , 555 , 621 , 303 , 415 , 367 , 902 , 829 , 741 , 119 , 380 , 902 , 25 , 884 , 439 , 822 , 49 , 76 , 760 , 566 , 316 , 249 , 555 , 774 , 955 , 834 , 309 , 859 , 173 , 935 , 812 , 682 , 586 , 141 , 606 , 197 , 131 , 644 , 631 , 913 , 586 , 202 , 117 , 810 , 884 , 76 , 592 , 754 , 531 , 586 , 925 , 649 , 583 , 145 , 816 , 821 , 283 , 871 , 1017 , 316 , 377 , 646 , 339 , 201 , 76 , 780 , 76 , 976 , 217 , 38 , 598 , 977 , 617 , 825 , 833 , 49 , 231 , 749 , 749 , 633 , 205 , 231 , 271 , 50 , 249 , 684 , 555 , 982 , 526 , 895 , 288 , 22 , 57 , 722 , 996 , 260 , 1018 , 110 , 833 , 644 , 738 , 648 , 468 , 798 , 297 , 769 , 282 , 197 , 402 , 465 , 510 , 194 , 930 , 182 , 909 , 749 , 986 , 187 , 187 , 917 , 38 , 38 , 985 , 985 , 988 , 815 , 878 , 814 , 459 , 237 , 768 , 781 , 649 , 683 , 749 , 934 , 729 , 463 , 181 , 625 , 231 , 917 , 96 , 499 , 839 , 720 , 439 , 842 , 205 , 808 , 338 , 617 , 681 , 326 , 446 , 905 , 346 , 647 , 533 , 49 , 728 , 147 , 432 , 846 , 536 , 586 , 611 , 49 , 879 , 872 , 893 , 859 , 859 , 961 , 989 , 975 , 701 , 495 , 65 ] ,
]
resp = [ ]
"""
resp = [
[ 922 , 738 , 461 , 341 , 341 , 10 , 416 , 416 , 416 , 416 , 346 , 346 , 346 , 346 , 346 , 484 , 484 , 484 , 484 , 484 , 484 , 333 , 442 , 442 , 359 , 359 , 359 , 459 , 459 , 975 , 975 , 626 , 626 , 626 , 626 , 626 , 610 , 359 , 359 , 359 , 359 , 359 , 359 , 359 , 359 , 359 , 610 , 610 , 442 , 90 , 90 , 90 , 90 , 90 , 90 , 90 , 90 , 90 , 90 , 90 , 90 , 90 , 90 , 90 , 90 , 638 , 638 , 638 , 638 , 975 , 975 , 672 , 875 , 63 , 144 ] ,
[ 993 , 700 , 384 , 213 , 794 , 10 , 305 , 778 , 58 , 225 , 118 , 260 , 768 , 768 , 260 , 474 , 903 , 732 , 70 , 992 , 447 , 70 , 1000 , 665 , 848 , 379 , 485 , 934 , 181 , 795 , 438 , 298 , 688 , 324 , 934 , 756 , 395 , 795 , 110 , 328 , 343 , 172 , 768 , 871 , 593 , 355 , 396 , 783 , 24 , 24 , 911 , 20 , 27 , 562 , 697 , 616 , 668 , 27 , 27 , 755 , 20 , 505 , 248 , 79 , 822 , 461 , 197 , 156 , 27 , 492 , 151 , 1013 , 669 , 669 , 562 ] ,
[ 626 , 989 , 936 , 488 , 511 , 624 , 997 , 112 , 112 , 648 , 210 , 650 , 563 , 650 , 41 , 41 , 490 , 920 , 977 , 986 , 920 , 927 , 131 , 167 , 167 , 968 , 346 , 168 , 167 , 168 , 120 , 355 , 766 , 599 , 712 , 390 , 558 , 810 , 948 , 332 , 332 , 867 , 994 , 346 , 955 , 392 , 920 , 452 , 576 , 346 , 52 , 254 , 52 , 307 , 897 , 307 , 968 , 920 , 167 , 563 , 167 , 167 , 167 , 968 , 167 , 488 , 968 , 488 , 1001 , 938 , 563 , 741 , 432 , 566 , 758 ] ,
[ 916 , 874 , 798 , 212 , 496 , 751 , 620 , 616 , 982 , 745 , 975 , 890 , 890 , 141 , 141 , 321 , 321 , 214 , 899 , 42 , 151 , 722 , 310 , 971 , 774 , 35 , 627 , 995 , 27 , 43 , 248 , 248 , 595 , 774 , 942 , 352 , 810 , 35 , 384 , 340 , 654 , 639 , 89 , 214 , 737 , 197 , 657 , 45 , 622 , 321 , 337 , 19 , 483 , 679 , 938 , 938 , 682 , 938 , 938 , 141 , 938 , 310 , 114 , 724 , 116 , 327 , 372 , 607 , 607 , 310 , 204 , 713 , 762 , 853 , 853 ] ,
[ 528 , 222 , 992 , 727 , 536 , 191 , 202 , 483 , 306 , 568 , 533 , 577 , 398 , 533 , 202 , 24 , 753 , 753 , 739 , 739 , 643 , 513 , 4 , 324 , 369 , 66 , 447 , 201 , 66 , 802 , 66 , 957 , 665 , 526 , 602 , 749 , 483 , 447 , 193 , 853 , 531 , 201 , 201 , 71 , 888 , 202 , 66 , 66 , 650 , 228 , 533 , 102 , 639 , 513 , 533 , 531 , 533 , 471 , 344 , 566 , 201 , 639 , 471 , 639 , 732 , 594 , 464 , 308 , 116 , 533 , 116 , 174 , 959 , 621 , 539 ] ,
[ 692 , 632 , 478 , 375 , 910 , 857 , 775 , 503 , 503 , 193 , 717 , 548 , 344 , 717 , 55 , 808 , 162 , 112 , 112 , 112 , 543 , 582 , 847 , 712 , 691 , 679 , 427 , 940 , 369 , 475 , 153 , 526 , 729 , 269 , 323 , 721 , 526 , 211 , 191 , 192 , 685 , 844 , 731 , 813 , 914 , 545 , 582 , 712 , 925 , 916 , 375 , 111 , 340 , 162 , 844 , 940 , 844 , 162 , 844 , 990 , 111 , 491 , 232 , 582 , 491 , 582 , 618 , 121 , 1020 , 664 , 670 , 254 , 315 , 438 , 723 ] ,
[ 365 , 908 , 896 , 819 , 206 , 153 , 515 , 471 , 75 , 79 , 664 , 145 , 145 , 801 , 135 , 321 , 79 , 216 , 233 , 223 , 79 , 66 , 724 , 517 , 135 , 474 , 818 , 818 , 105 , 892 , 971 , 337 , 818 , 19 , 932 , 981 , 469 , 135 , 163 , 75 , 135 , 818 , 999 , 555 , 135 , 710 , 256 , 105 , 590 , 31 , 539 , 1003 , 517 , 130 , 445 , 40 , 549 , 130 , 859 , 385 , 1003 , 1003 , 549 , 33 , 286 , 932 , 329 , 774 , 321 , 664 , 686 , 16 , 834 , 703 , 290 ] ,
[ 899 , 237 , 832 , 748 , 425 , 121 , 460 , 872 , 391 , 586 , 857 , 215 , 306 , 76 , 306 , 554 , 187 , 57 , 482 , 406 , 802 , 555 , 710 , 895 , 448 , 517 , 506 , 316 , 18 , 772 , 779 , 697 , 855 , 1005 , 792 , 96 , 402 , 96 , 517 , 775 , 506 , 938 , 114 , 986 , 986 , 503 , 749 , 984 , 524 , 527 , 506 , 749 , 463 , 490 , 188 , 374 , 506 , 49 , 537 , 188 , 494 , 900 , 526 , 524 , 524 , 500 , 500 , 345 , 630 , 338 , 982 , 761 , 700 , 598 , 749 ] ,
]
"""
# io_map: name -> [ ( start, end ) slice of the unified vocab, classifier index ( -1 = input-only ), source weight name ]
io_map = {
' text ' : [ ( 0 , 256 ) , 9 , " text_emb.weight " ] ,
' rvq_l ' : [ ( 256 , 264 ) , - 1 , " rvq_l_emb.weight " ] ,
' lang ' : [ ( 264 , 270 ) , - 1 , " langs_emb.weight " ] ,
' task ' : [ ( 270 , 279 ) , - 1 , " tasks_emb.weight " ] ,
' len ' : [ ( 279 , 290 ) , 10 , " len_emb.weight " ] ,
' tone ' : [ ( 290 , 291 ) , - 1 , " tones_emb.weight " ] ,
' sep ' : [ ( 291 , 292 ) , - 1 , " sep " ] ,
' prom|0 ' : [ ( 292 , 1316 ) , - 1 , " proms_emb.embeddings.0.weight " ] ,
' prom|1 ' : [ ( 1316 , 2340 ) , - 1 , " proms_emb.embeddings.1.weight " ] ,
' prom|2 ' : [ ( 2340 , 3364 ) , - 1 , " proms_emb.embeddings.2.weight " ] ,
' prom|3 ' : [ ( 3364 , 4388 ) , - 1 , " proms_emb.embeddings.3.weight " ] ,
' prom|4 ' : [ ( 4388 , 5412 ) , - 1 , " proms_emb.embeddings.4.weight " ] ,
' prom|5 ' : [ ( 5412 , 6436 ) , - 1 , " proms_emb.embeddings.5.weight " ] ,
' prom|6 ' : [ ( 6436 , 7460 ) , - 1 , " proms_emb.embeddings.6.weight " ] ,
' prom|7 ' : [ ( 7460 , 8484 ) , - 1 , " proms_emb.embeddings.7.weight " ] ,
' resp|AR:0:0 ' : [ ( 8484 , 9509 ) , 0 , " resps_emb.embeddings.0.weight " ] ,
' resp|NAR:0:1 ' : [ ( 9509 , 10533 ) , 1 , " resps_emb.embeddings.1.weight " ] ,
' resp|NAR:1:2 ' : [ ( 10533 , 11557 ) , 2 , " resps_emb.embeddings.2.weight " ] ,
' resp|NAR:2:3 ' : [ ( 11557 , 12581 ) , 3 , " resps_emb.embeddings.3.weight " ] ,
' resp|NAR:3:4 ' : [ ( 12581 , 13605 ) , 4 , " resps_emb.embeddings.4.weight " ] ,
' resp|NAR:4:5 ' : [ ( 13605 , 14629 ) , 5 , " resps_emb.embeddings.5.weight " ] ,
' resp|NAR:5:6 ' : [ ( 14629 , 15653 ) , 6 , " resps_emb.embeddings.6.weight " ] ,
' resp|NAR:6:7 ' : [ ( 15653 , 16677 ) , 7 , " resps_emb.embeddings.7.weight " ] ,
' resp|NAR:0:0 ' : [ ( 16677 , 17702 ) , 8 , " resps_emb.embeddings.8.weight " ] ,
}
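# illustrative sanity check (hedged: assumes the ranges above are authoritative); the slices
# tile one flat vocabulary of 17702 rows shared by the exported embed_tokens / lm_head, and a
# classifier index of -1 marks input-only streams that get no lm_head slice, e.g.:
# for name , ( rng , classifier_idx , src ) in io_map . items ( ) :
#     assert 0 <= rng [ 0 ] < rng [ 1 ] <= 17702 , name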
mode_lvl_map = {
' AR:0:0 ' : 0 ,
' NAR:0:1 ' : 1 ,
' NAR:1:2 ' : 2 ,
' NAR:2:3 ' : 3 ,
' NAR:3:4 ' : 4 ,
' NAR:4:5 ' : 5 ,
' NAR:5:6 ' : 6 ,
' NAR:6:7 ' : 7 ,
' NAR:0:0 ' : 0 ,
' len ' : 0 ,
}
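# mode_lvl_map resolves a decoding mode to the RVQ level embedded as rvq_l in create_inputs,
# which also bounds how many prom / resp levels get summed there; ' NAR:0:0 ' and ' len ' both sit at level 0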
embds = { }
heads = { }
n_embd = 1024
with torch . no_grad ( ) :
for k , v in io_map . items ( ) :
start , end = v [ 0 ]
classifier_idx = v [ 1 ]
embd_name = v [ 2 ]
if is_from_pretrained :
n_vocab = end - start
embds [ k ] = torch . nn . Embedding ( n_vocab , n_embd ) . to ( model . embed_tokens . weight )
embds [ k ] . weight [ : ] = model . embed_tokens . weight [ start : end , : ]
if classifier_idx > = 0 :
# NAR:0:0 does not have a masked token output
if k == " resp|NAR:0:0 " :
end - = 1
n_vocab - = 1
heads [ k ] = torch . nn . Linear ( n_embd , n_vocab , bias = False ) . to ( hf_model . lm_head . weight )
heads [ k ] . weight [ : ] = hf_model . lm_head . weight [ start : end , : ]
else :
embd_weight = state_dict [ embd_name ] . unsqueeze ( 0 ) if state_dict [ embd_name ] . dim ( ) == 1 else state_dict [ embd_name ]
embds [ k ] = torch . nn . Embedding ( embd_weight . shape [ 0 ] , embd_weight . shape [ 1 ] ) . to ( device = device , dtype = dtype )
embds [ k ] . load_state_dict ( { " weight " : embd_weight } )
if classifier_idx > = 0 :
head_weight = state_dict [ f ' classifiers.proj. { classifier_idx } .weight ' ]
heads [ k ] = torch . nn . Linear ( head_weight . shape [ 1 ] , head_weight . shape [ 0 ] , bias = False ) . to ( device = device , dtype = dtype )
heads [ k ] . load_state_dict ( { " weight " : head_weight } )
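# at this point embds / heads hold one input embedding (and, where the classifier index >= 0, one
# output head) per stream, sliced either from the exported HF model ( embed_tokens / lm_head ) or
# rebuilt from the original state_dict ( per-level embeddings plus classifiers.proj.* );
# e.g. embds [ " text " ] should come out as a 256 x 1024 Embedding under either path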
def create_inputs ( phn , prom , lang = 0 , seq = None , mode = " AR:0:0 " ) :
rvq_l = mode_lvl_map [ mode ]
inputs = torch . tensor ( [ ] )
pos_ids = torch . tensor ( [ ] )
attn_mask = torch . tensor ( [ ] )
seqs = [ ]
phn = torch . tensor ( phn , device = device , dtype = torch . int32 )
prom = torch . tensor ( prom , device = device , dtype = torch . int32 )
lang = torch . tensor ( [ lang ] , device = device , dtype = torch . int32 )
rvq_l = torch . tensor ( [ rvq_l ] , device = device , dtype = torch . int32 )
zero = torch . tensor ( [ 0 ] , device = device , dtype = torch . int32 )
if mode == " len " :
seq = zero if not seq else torch . concat ( [ zero , torch . tensor ( seq , device = device , dtype = torch . int32 ) ] )
elif seq :
seq = torch . tensor ( seq , device = device , dtype = torch . int32 )
seq = seq [ : rvq_l , : ] if rvq_l > 0 else seq
sep_embd = embds [ " sep " ] ( zero )
phn_embd = embds [ " text " ] ( phn )
rvq_l_embd = embds [ " rvq_l " ] ( rvq_l )
lang_embd = embds [ " lang " ] ( lang )
prom_embd = torch . zeros ( prom . shape [ - 1 ] , n_embd , device = device , dtype = dtype )
seq_embd = None
for i , p in enumerate ( prom ) :
if i > rvq_l :
break
prom_embd + = embds [ f " prom| { i } " ] ( p )
# an empty sequence (the first AR step) contributes no resp embedding
if seq is not None and len ( seq ) :
if mode == " len " :
seq_embd = embds [ " len " ] ( seq )
elif mode == " AR:0:0 " :
seq_embd = embds [ " resp|AR:0:0 " ] ( seq )
else :
seq_embd = torch . zeros ( seq . shape [ - 1 ] , n_embd , device = device , dtype = dtype )
for i , r in enumerate ( seq ) :
seq_embd + = embds [ f " resp|NAR: { i } : { i + 1 } " ] ( r )
seqs . append ( torch . concat ( [ phn_embd , sep_embd ] ) )
seqs . append ( torch . concat ( [ lang_embd , sep_embd ] ) )
seqs . append ( torch . concat ( [ rvq_l_embd , sep_embd ] ) )
seqs . append ( torch . concat ( [ prom_embd , sep_embd ] ) )
if seq_embd is not None :
seqs . append ( seq_embd )
inputs = torch . concat ( seqs )
pos_ids = torch . tensor ( [ i for seq in seqs for i , _ in enumerate ( seq ) ] , device = device , dtype = torch . int32 )
attn_mask = torch . tensor ( [ True for seq in seqs for i , _ in enumerate ( seq ) ] , device = device , dtype = torch . bool )
return inputs , pos_ids , attn_mask
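# create_inputs lays the embeddings out as: text + sep, lang + sep, rvq_l + sep, summed prom
# levels + sep, then the optional response; position ids restart at 0 for every segment and
# the attention mask is all ones, e.g.:
# inputs , pos_ids , attn_mask = create_inputs ( phn , prom , mode = " AR:0:0 " )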
def generate ( phn , prom , sequence = None , mode = " resp|AR:0:0 " , max_tokens = 75 * 4 , temperature = 1.0 ) :
# use None instead of a mutable default so the digits emitted by the "len" call don't leak into the later AR call
sequence = [ ] if sequence is None else sequence
lm_head = heads [ mode ]
model . _update_causal_mask = model . _original_update_causal_mask
n_outputs = 1
stop_token = 1024
if mode == " len " :
temperature = 0.0
max_tokens = 5
stop_token = 10
elif mode != " resp|AR:0:0 " :
temperature = 0.0
max_tokens = len ( sequence ) + 1
n_outputs = len ( sequence [ 0 ] )
model . _update_causal_mask = model . _update_noncausal_mask
while len ( sequence ) < max_tokens :
inputs , pos_ids , attn_mask = create_inputs ( phn , prom , seq = sequence , mode = mode . split ( " | " ) [ - 1 ] )
out = model ( inputs_embeds = inputs . unsqueeze ( 0 ) , position_ids = pos_ids . unsqueeze ( 0 ) , attention_mask = attn_mask . unsqueeze ( 0 ) )
logits = lm_head ( out [ 0 ] ) . float ( )
logits = logits [ 0 , - n_outputs : , : ]
t = Categorical ( logits = logits / temperature ) . sample ( ) if temperature > 0 else logits . argmax ( dim = - 1 )
if n_outputs > 1 :
sequence . append ( [ _ . item ( ) for _ in t ] )
break
else :
t = t [ 0 ]
if stop_token in t :
break
sequence . append ( t . item ( ) )
return sequence
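# generate covers three paths: " len " greedily decodes up to 5 digits ( stop token 10 ),
# " resp|AR:0:0 " samples one frame at a time until the stop token 1024, and the NAR modes
# swap in the non-causal mask and emit every timestep of the next level in a single pass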
# debug: dump the per-position embedding sums (with their position ids) for manual inspection
if False :
inputs , pos_ids , attn_mask = create_inputs ( phn , prom , mode = " len " )
flattened = [ sum ( embd ) . item ( ) for embd in inputs ]
for i , embd in enumerate ( flattened ) :
print ( f ' { i } : ' , pos_ids [ i ] . item ( ) , " \t " , embd )
# test len inference
print ( " len: " , generate ( phn , prom , mode = " len " ) )
# test AR output
if resp :
resp = [ resp [ 0 ] ]
else :
resp = [ generate ( phn , prom ) ]
print ( " AR: " , resp )
# test NAR output
for i in range ( 1 , 8 ) :
resp = generate ( phn , prom , sequence = resp , mode = f " resp|NAR: { i - 1 } : { i } " )
print ( f " NAR: { i - 1 } : { i } : " , resp [ - 1 ] )
decode_to_file ( torch . tensor ( resp , dtype = torch . int16 , device = device ) . t ( ) , " ./data/test.wav " )
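# the final resp is [ levels, frames ]; the transpose above assumes decode_to_file wants
# codes laid out as [ frames, levels ] before writing ./data/test.wav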