2023-09-06 23:58:35 +00:00
from . base import Base , list_to_tensor , Categorical
2023-09-07 01:33:16 +00:00
from . . config import cfg
2023-09-06 23:58:35 +00:00
import torch
from torch . nn . utils . rnn import pad_sequence
import random
2023-09-13 18:19:11 +00:00
import math
2023-09-06 23:58:35 +00:00
from einops import rearrange
from torch import Tensor
from tqdm import trange
2023-10-10 03:03:58 +00:00
from . . emb . qnt import trim
2023-09-06 23:58:35 +00:00
class AR_NAR(Base):
    """One network serving as both the AR stage and the NAR stage of the model.

    `forward` picks the mode from its arguments (see its docstring): a
    training step, NAR inference (filling in higher RVQ levels), or AR
    inference (sampling level-0 tokens until a stop token).
    """

    def _config_attr(self, name: str):
        """Return `name` from this model's own `config` if one is set and truthy, else from `cfg.model`."""
        config = getattr(self, "config", None)
        return getattr(config if config else cfg.model, name)

    @property
    def causal(self):
        return True

    @property
    def norm_type(self):
        return "ln"  # if self.n_resp_levels == 1 else "adaln"

    @property
    def arch_type(self) -> str:
        return self._config_attr("arch_type")

    @property
    def n_prom_levels(self) -> int:
        return cfg.model.prom_levels

    @property
    def n_resp_levels(self) -> int:
        return self._config_attr("resp_levels")

    @property
    def n_max_levels(self) -> int:
        return cfg.model.max_levels

    @property
    def n_tasks(self) -> int:
        return cfg.model.tasks

    @property
    def n_langs(self) -> int:
        return cfg.model.langs

    @property
    def n_tones(self) -> int:
        return cfg.model.tones

    @property
    def recurrent_chunk_size(self) -> int:
        return 0

    @property
    def interleave(self) -> bool:
        return False

    @property
    def monolithic(self) -> bool:
        return True

    @property
    def version(self) -> int:
        return self._config_attr("version")

    def _prune(self, l: Tensor):
        """Truncate `l` just before the first occurrence of `self.stop_token` (unchanged if absent)."""
        indices = (l == self.stop_token).nonzero()
        if len(indices) == 0:
            return l
        return l[: indices.min().item()]

    @staticmethod
    def _unsqueeze_list(x_list, axis=-1):
        """Add a new dimension at `axis` to every tensor in `x_list`."""
        return [x.unsqueeze(dim=axis) for x in x_list]

    def forward(
        self,
        text_list: list[Tensor],
        proms_list: list[Tensor],
        resps_list: list[Tensor] | None = None,

        lang_list: list[Tensor] | None = None,
        tone_list: list[Tensor] | None = None,

        max_steps: int = 1000,
        max_levels: int = 0,
        max_resp_context: int = -1,

        sampling_temperature: float = 1.0,
        sampling_min_temperature: float = -1.0,
        sampling_top_k: int = -100,
        sampling_top_p: float = 1.0,
        sampling_repetition_penalty: float = 1.0,
        sampling_repetition_penalty_decay: float = 0.0,
        sampling_length_penalty: float = 0.0,
        sampling_beam_width: int = 0,
        sampling_mirostat_tau: float = 0.0,
        sampling_mirostat_eta: float = 0.1,
    ):
        """Run a training step, NAR inference, or AR inference depending on the inputs.

        Mode selection:

        * ``resps_list`` given and its last dim equals ``self.n_resp_levels``
          -> training. A target RVQ level is drawn uniformly per sample
          (0 = AR, 1+ = NAR). AR samples get ``self.stop_token`` appended.
          The result of ``Base.forward`` is returned.
        * ``resps_list`` given with fewer levels -> NAR inference. One level
          is sampled per iteration and appended along the last dim until the
          responses have ``max_levels + 1`` levels.
        * ``resps_list`` is None -> AR inference. Level-0 tokens are sampled
          step by step until every sequence produced ``self.stop_token`` or
          the step budget is exhausted.

        Args:
            text_list: one text-token tensor per sample; its length is the batch size.
            proms_list: one prompt tensor per sample.
            resps_list: None for AR inference, otherwise one ``(..., levels)``
                tensor of response codes per sample. The level count is taken
                from the set of last dims; if samples disagree an arbitrary one
                is used (not checked).
            lang_list: optional per-sample conditioning, passed to ``self.inputs``.
            tone_list: optional per-sample conditioning, passed to ``self.inputs``.
            max_steps: AR step budget (multiplied by ``n_prom_levels`` when
                ``interleave`` is on).
            max_levels: NAR stops once responses have ``max_levels + 1``
                levels; 0 means ``n_resp_levels - 1``.
            max_resp_context: AR only; when > 0 only the last this-many
                generated tokens are fed back each step (experimental, untested).
            sampling_*: forwarded to ``Base.sample``. Length penalty, beam
                width and mirostat settings are only used by the AR pass.

        Returns:
            Training: whatever ``Base.forward`` returns (presumably the
            losses -- confirm in Base). NAR: list of response tensors with the
            sampled levels appended along the last dim. AR: list of 1-D
            level-0 token tensors, each truncated before its first stop token.
        """
        device = text_list[0].device
        batch_size = len(text_list)

        # ---------------- training or NAR ----------------
        if resps_list is not None:
            n_levels_set = {r.shape[-1] for r in resps_list}
            n_levels = next(iter(n_levels_set))

            # training: the responses carry every level the model is trained on
            if n_levels == self.n_resp_levels:
                # randomly select a target RVQ-bin level per sample (0 being AR, 1+ being NAR)
                # might be better to have this decided on the dataloader level
                quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,))

                # the label is only the target level
                targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)]
                # NAR level l is conditioned on levels [0, l); AR samples keep the whole
                # tensor (r[..., 0] would be technically correct, but only r[:, 0] gets
                # passed through the embedding)
                resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)]

                # AR samples learn to emit the stop token after the last frame
                for i, level in enumerate(quant_levels):
                    if level > 0:
                        continue
                    stop_frame = torch.full((1, n_levels), self.stop_token, device=device, dtype=torch.int16)
                    stop_label = torch.full((1,), self.stop_token, device=device, dtype=torch.int16)
                    resps_list[i] = torch.cat([resps_list[i], stop_frame])
                    targ_list[i] = torch.cat([targ_list[i], stop_label])

                inputs = self.inputs(
                    text_list=text_list,
                    proms_list=proms_list,
                    resps_list=resps_list,
                    targ_list=targ_list,
                    lang_list=lang_list,
                    tone_list=tone_list,
                )
                return super().forward(
                    inputs=inputs,
                    quant_levels=quant_levels,
                )

            # NAR: fill in the missing levels one at a time
            if max_levels == 0:
                max_levels = self.n_resp_levels - 1

            prev_list = resps_list

            for _ in trange(max_levels, desc="NAR"):
                level = prev_list[0].shape[-1]
                # bounded by max_levels rather than n_resp_levels to allow experimenting
                # with exceeding the trained levels
                if level >= max_levels + 1:
                    break

                quant_levels = torch.full((len(text_list),), level)

                inputs = self.inputs(
                    text_list=text_list,
                    proms_list=proms_list,
                    resps_list=prev_list,
                    lang_list=lang_list,
                    tone_list=tone_list,
                )
                logits = super().forward(
                    inputs=inputs,
                    quant_levels=quant_levels,
                )
                # length penalty, beam search and mirostat are AR-only
                resps_list = super().sample(
                    logits=logits,
                    resps_list=prev_list,
                    quant_levels=quant_levels,
                    temperature=sampling_temperature,
                    min_temperature=sampling_min_temperature,
                    top_p=sampling_top_p,
                    top_k=sampling_top_k,
                    repetition_penalty=sampling_repetition_penalty,
                    repetition_penalty_decay=sampling_repetition_penalty_decay,
                )

                prev_list = [torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list)]

            return prev_list

        # ---------------- AR ----------------
        sequence_list = [torch.zeros(0, device=device, dtype=torch.int16) for _ in text_list]
        stopped = torch.zeros(batch_size, device=device, dtype=torch.bool)

        recurrent_state = [] if cfg.inference.recurrent_forward else None
        # One state dict per sample: `[state] * batch_size` would alias a single dict
        # across the whole batch.
        mirostat = [
            {"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
            for _ in range(batch_size)
        ] if sampling_mirostat_tau > 0.0 else None

        scores = [1.0] * sampling_beam_width

        if self.interleave:
            max_steps *= self.n_prom_levels

        # get next in sequence
        for _ in trange(max_steps // max(1, self.recurrent_chunk_size), desc="AR"):
            # experimental rolling response window (UNTESTED); ideally the text would be
            # windowed too, but there is no text<->audio alignment available here
            if max_resp_context > 0:
                resps_list = self._unsqueeze_list([sequence[-max_resp_context:] for sequence in sequence_list])
            else:
                resps_list = self._unsqueeze_list(sequence_list)

            inputs = self.inputs(
                text_list=text_list,
                proms_list=proms_list,
                resps_list=resps_list,
                lang_list=lang_list,
                tone_list=tone_list,
            )

            if recurrent_state is not None:
                logits, recurrent_state = super().forward(
                    inputs=inputs,
                    state=recurrent_state,
                )
            else:
                logits = super().forward(
                    inputs=inputs,
                    state=recurrent_state,
                )

            r = super().sample(
                logits=logits,
                resps_list=resps_list,
                temperature=sampling_temperature,
                min_temperature=sampling_min_temperature,
                top_p=sampling_top_p,
                top_k=sampling_top_k,
                repetition_penalty=sampling_repetition_penalty,
                repetition_penalty_decay=sampling_repetition_penalty_decay,
                length_penalty=sampling_length_penalty,
                beam_width=sampling_beam_width,
                mirostat=mirostat,
            )

            if mirostat is not None:
                # with mirostat the sampler returns the updated per-sample states
                mirostat = r
                r = [state["token"] for state in mirostat]
            elif sampling_beam_width > 0:
                # with beam search the sampler returns (tokens, scores)
                r, s = r
                # first step: expand a batch of one into one row per beam candidate;
                # every per-sample list must grow together to keep lengths aligned
                if batch_size == 1:
                    batch_size = sampling_beam_width
                    text_list = text_list * sampling_beam_width
                    proms_list = proms_list * sampling_beam_width
                    sequence_list = sequence_list * sampling_beam_width
                    if lang_list is not None:
                        lang_list = lang_list * sampling_beam_width
                    if tone_list is not None:
                        tone_list = tone_list * sampling_beam_width
                    stopped = torch.zeros(batch_size, device=device, dtype=torch.bool)

                scores = [scores[i] + score for i, score in enumerate(s)]

            # append tokens
            for i, ri in enumerate(r):
                if self.stop_token in ri:
                    stopped[i] = True
                sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])

            # stop token found
            stopped |= r == self.stop_token
            if stopped.all().item():
                break

        # pick the best scoring candidate (in practice always candidate 0)
        if sampling_beam_width:
            sequence_list = [sequence_list[0]]

        return [self._prune(r) for r in sequence_list]
2023-09-06 23:58:35 +00:00
def example_usage():
    """Overfit an AR_NAR on a single utterance as an end-to-end smoke test.

    Loads one quantized utterance from ``./data/qnt`` (``.dac`` or ``.pt``,
    chosen by ``cfg.inference.audio_backend``). Builds a model and engine on
    ``"cuda"`` and saves an initial checkpoint to ``./data/test.pth``. Writes
    samples before and after ``steps`` training steps to ``data/*.wav``, then
    saves the trained weights to the same checkpoint path.
    """
    cfg.hyperparameters.gradient_accumulation_steps = 1

    from tqdm import tqdm

    from ..emb.qnt import decode_to_file, unload_model
    from ..engines import Engine
    from ..utils import wrapper as ml

    import numpy as np

    device = "cuda"

    def tokenize(content):
        return torch.tensor(cfg.tokenizer.encode(content))

    def _load_quants(path) -> Tensor:
        """Load the codes for `path`, keep the first `prom_levels` levels, and return them transposed as int16."""
        if cfg.inference.audio_backend == "dac":
            # NOTE(review): allow_pickle=True unpickles the file; only point this at trusted data.
            qnt = np.load(f'{path}.dac', allow_pickle=True)[()]
            return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
        return torch.load(f'{path}.pt')[0][:, :cfg.model.prom_levels].t().to(torch.int16)

    qnt = _load_quants("./data/qnt")

    # a single training/inference example; the prompt and the target response are the same clip
    text_list = [
        tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
    ]
    proms_list = [
        qnt.to(device),
    ]
    resps_list = [
        qnt.to(device),
    ]

    # retnet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
    kwargs = {
        'n_tokens': 1024,
        'd_model': 1024,  # 256, # 1024, # 1536
        'n_heads': 16,  # 4, # 16, # 24
        'n_layers': 12,  # 32
        'n_experts': 1,

        'l_padding': 8 if cfg.fp8.enabled else 0,
    }

    model = AR_NAR(**kwargs).to(device)
    steps = 750
    # alternatives tried: ml.Adagrad(lr=1.0e-2), ml.AdamW(lr=1.0e-4)
    optimizer = ml.Prodigy(model.parameters(), lr=1.0)
    engine = Engine(model=model, optimizer=optimizer)

    if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
        model.model = ml.replace_linear(model.model)

    def _save_checkpoint():
        """Write the current model weights to ./data/test.pth."""
        torch.save({
            'module': model.state_dict()
        }, "./data/test.pth")

    _save_checkpoint()

    print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    @torch.inference_mode()
    def sample(name, steps=1000):
        """Run AR then NAR inference and write the decoded audio to data/*.{name}.wav."""
        # the "init" (untrained) sample is skipped entirely for the dac backend
        if cfg.inference.audio_backend == "dac" and name == "init":
            return

        engine.eval()
        resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95)

        # AR-only audio is decoded only for non-dac backends
        if cfg.inference.audio_backend != "dac":
            for i, o in enumerate(resps_list):
                _ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)

        # feed the AR output (one level) to the NAR to fill in the rest
        resps_list = [r.unsqueeze(-1) for r in resps_list]
        resps_list = engine(text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2)

        for i, o in enumerate(resps_list):
            _ = decode_to_file(o, f"data/ar+nar.{i}.{name}.wav", device=device)

        unload_model()

    def train():
        """Train on the single example for `steps` steps, logging stats each step, then save."""
        engine.train()
        for i in trange(steps):
            stats = {"step": i}
            stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
            stats |= {"grad_norm": engine.get_global_grad_norm()}

            tqdm.write(f"{stats}")

        _save_checkpoint()

    sample("init", 5)
    train()
    sample("final")
if __name__ == "__main__":
    # Executing this module as a script runs the single-utterance overfit demo.
    example_usage()