2023-08-02 21:53:35 +00:00
import math
import torch
import torch . nn . functional as F
import traceback
2023-09-13 02:28:07 +00:00
import numpy as np
2024-04-05 00:11:49 +00:00
import re
2023-08-02 21:53:35 +00:00
from typing import Literal , overload
from functools import partial
from einops import rearrange
from torch import Tensor , einsum , nn
2024-04-17 02:04:48 +00:00
from torch . nn import Embedding
2023-08-02 21:53:35 +00:00
from torch . distributions import Categorical
from torch . nn . utils . rnn import pad_sequence
from torch . utils . checkpoint import checkpoint
from torchmetrics . classification import BinaryAccuracy , MulticlassAccuracy , MulticlassPrecision
2023-10-11 17:25:31 +00:00
from . . samplers import reptition_penalize , length_penalize , top_k_top_p_filtering , dynamic_temperature , top_k_logits_list , mirostat_sample
2023-08-02 21:53:35 +00:00
2023-12-23 01:27:36 +00:00
try :
from . transformer import SinusoidalEmbedding , Block as TransformerBlock
except Exception as e :
print ( " Error importing `transformer` arch: " , e )
pass
try :
2024-04-16 15:02:31 +00:00
#from .retnet import RetNetDecoder, RetNetConfig
from . retnet_ts import RetNetDecoder , RetNetConfig
2023-12-23 01:27:36 +00:00
except Exception as e :
print ( " Error importing `retnet` arch: " , e )
pass
2024-04-09 01:14:51 +00:00
# Optional architecture: guard the import the same way as every other optional
# backend in this module, so that a missing or broken `retnet_hf` only disables
# the "retnet-hf" arch instead of aborting the import of the whole module.
try:
    from .retnet_hf import RetNetDecoder as RetNetDecoder_HF, RetNetConfig as RetNetConfig_HF
except Exception as e:
    print(" Error importing `retnet-hf` arch: ", e)
2023-12-23 01:27:36 +00:00
try :
from transformers import LlamaModel , LlamaConfig
except Exception as e :
print ( " Error importing `llama` arch: " , e )
pass
2024-02-01 03:48:36 +00:00
try :
from transformers import MistralModel , MistralConfig
except Exception as e :
print ( " Error importing `mistral` arch: " , e )
pass
2024-03-01 02:29:17 +00:00
try:
    from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm

    class BitNetTransformer(nn.Module):
        """BitNet transformer stack followed by an RMS norm.

        Local override: bitnet's own `BitNetTransformer` includes an embedding
        input layer and classifier output layer, which is not wanted here.
        (An earlier approach monkey-patched the upstream class's `forward`.)
        """

        def __init__(
            self,
            dim: int,
            depth: int,
            num_tokens: int,
            heads=8,
            ff_mult=4,
        ):
            # `num_tokens` is accepted but not used by this implementation.
            super().__init__()
            self.transformer = BitNetTransformerBlock(
                dim=dim,
                depth=depth,
                heads=heads,
                ff_mult=ff_mult,
            )
            self.norm = BitNetRMSNorm(dim)

        def forward(self, x):
            """Run `x` through the transformer stack, then normalize the result."""
            hidden = self.transformer(x)
            return self.norm(hidden)

except Exception as e:
    print(" Error importing `bitnet` arch: ", e)
2023-12-23 01:27:36 +00:00
try:
    from transformers import MixtralModel, MixtralConfig
    from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock

    def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Replacement `MixtralSparseMoeBlock.forward`.

        Installed because the stock implementation throws errors for batch
        sizes > 1. Returns a `(final_hidden_states, router_logits)` pair.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # Flatten batch and time; reshape() is used where the original used view().
        hidden_states = hidden_states.reshape(-1, hidden_dim)
        in_dtype, in_device = hidden_states.dtype, hidden_states.device

        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # Cast back to the input dtype.
        routing_weights = routing_weights.to(in_dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=in_dtype,
            device=in_device,
        )

        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)

        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            # No token was routed to this expert.
            if top_x.shape[0] == 0:
                continue

            token_rows = top_x.tolist()
            slot_cols = idx.tolist()

            current_state = hidden_states[None, token_rows].reshape(-1, hidden_dim)
            weights = routing_weights[token_rows, slot_cols, None]
            current_hidden_states = expert_layer(current_state) * weights

            # Accumulate this expert's weighted output into the rows it served.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(in_dtype))

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

    # Keep a handle on the stock implementation, then install the replacement.
    Original_MixtralSparseMoeBlock_forward = MixtralSparseMoeBlock.forward
    MixtralSparseMoeBlock.forward = Fixed_MixtralSparseMoeBlock_forward

except Exception as e:
    print(" Error importing `mixtral` arch: ", e)
2023-08-02 21:53:35 +00:00
def _create_mask ( l , device ) :
""" 1 is valid region and 0 is invalid. """
seq = torch . arange ( max ( l ) , device = device ) . unsqueeze ( 0 ) # (1 t)
stop = torch . tensor ( l , device = device ) . unsqueeze ( 1 ) # (b 1)
return ( seq < stop ) . float ( ) # (b t)
def _join ( x : tuple [ Tensor ] , sep : Tensor ) :
"""
Args :
x : ( k t d )
sep : ( d )
"""
ret = x [ 0 ]
for i in range ( 1 , len ( x ) ) :
ret = torch . cat ( ( ret , sep [ None ] , x [ i ] ) , dim = 0 )
return ret
def list_to_tensor(x_list: list[Tensor], pattern=" t b c -> b t c "):
    """Pad a list of variable-length tensors into one batch and build its mask.

    Args:
        x_list: [(t d)] tensors; lengths are taken along the first dim.
        pattern: einops pattern applied to both the padded batch and the mask.

    Returns:
        x: the padded batch, rearranged by `pattern`.
        m: mask rearranged the same way and cast to `x`'s dtype/device;
           1 marks valid positions, 0 marks padding.
    """
    lengths = [len(item) for item in x_list]

    padded = pad_sequence(x_list)
    x = rearrange(padded, pattern)

    mask = _create_mask(lengths, x_list[0].device)  # (b t)
    mask = mask.t().unsqueeze(-1)  # (t b 1)
    mask = rearrange(mask, pattern).to(x)

    return x, mask
2023-09-09 01:30:54 +00:00
# automagically parses a batch-list and returns it as a list
2024-04-17 02:04:48 +00:00
"""
2023-08-02 21:53:35 +00:00
class Embedding ( nn . Embedding ) :
def forward ( self , x_list : list [ Tensor ] ) - > list [ Tensor ] :
if len ( x_list ) == 0 :
return [ ]
return super ( ) . forward ( torch . cat ( x_list ) ) . split ( [ * map ( len , x_list ) ] )
2024-04-17 02:04:48 +00:00
"""
2023-08-02 21:53:35 +00:00
2023-09-13 18:19:11 +00:00
class MultiEmbedding ( nn . Module ) :
2023-08-02 21:53:35 +00:00
"""
This embedding sums embeddings on different levels .
"""
2023-09-08 06:03:24 +00:00
def __init__ ( self , max_n_levels , n_tokens , token_dim , monolithic = False ) :
2023-09-16 05:26:13 +00:00
super ( ) . __init__ ( )
2023-09-08 20:36:26 +00:00
self . monolithic = monolithic
2023-08-02 21:53:35 +00:00
self . max_n_levels = max_n_levels
self . n_tokens = n_tokens
2023-09-08 20:36:26 +00:00
self . weight = nn . Parameter ( torch . randn ( max_n_levels , n_tokens , token_dim ) )
2023-08-02 21:53:35 +00:00
2023-09-07 22:08:38 +00:00
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
# I imagine this is an oversight in the NAR.
def forward ( self , x_list : list [ Tensor ] , quant_levels : Tensor | None = None ) - > list [ Tensor ] :
2023-08-02 21:53:35 +00:00
if len ( x_list ) == 0 :
return [ ]
2023-09-08 20:36:26 +00:00
# this "strategy" will reserve the weight[0] for te AR and weight[1:] for the NAR
# the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
2023-09-08 06:03:24 +00:00
if self . monolithic :
2023-09-08 20:36:26 +00:00
w = self . weight [ : 1 ] if quant_levels is None else self . weight [ 1 : ]
2023-09-08 06:03:24 +00:00
else :
w = self . weight
2023-08-02 21:53:35 +00:00
2023-09-11 19:13:42 +00:00
padded_x_list = [ ]
2023-08-02 21:53:35 +00:00
2023-09-08 20:36:26 +00:00
for i , xi in enumerate ( x_list ) :
2023-08-02 21:53:35 +00:00
xi = F . one_hot ( xi . to ( torch . int64 ) , num_classes = self . n_tokens ) # t l' k
2023-09-11 19:13:42 +00:00
wi = w . shape [ 0 ] - xi . shape [ 1 ]
xi = F . pad ( xi , ( 0 , 0 , 0 , wi ) ) # t l k
2023-08-02 21:53:35 +00:00
padded_x_list . append ( xi . to ( w ) )
2023-09-11 19:13:42 +00:00
x = torch . cat ( padded_x_list ) # n l k
x = einsum ( " l k d, n l k -> n d " , w , x )
2023-08-02 21:53:35 +00:00
x_list = x . split ( [ * map ( len , x_list ) ] )
2023-09-07 21:48:02 +00:00
2023-09-11 19:13:42 +00:00
return x_list
2023-09-07 21:48:02 +00:00
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
2023-09-11 19:13:42 +00:00
class AudioEmbedding ( nn . Module ) :
2024-01-25 18:18:48 +00:00
def __init__ ( self , l_tokens , token_dim , levels = None ) :
2023-09-07 14:14:03 +00:00
super ( ) . __init__ ( )
2023-09-21 00:10:59 +00:00
self . embeddings = nn . ModuleList ( [ nn . Embedding ( n_tokens , token_dim ) for n_tokens in l_tokens ] )
2024-01-25 18:18:48 +00:00
self . weight = nn . ParameterList ( [ nn . Parameter ( torch . Tensor ( [ 1 ] ) ) for i in range ( levels ) ] ) if levels is not None else None
2023-09-07 14:14:03 +00:00
2024-04-17 02:04:48 +00:00
def forward ( self , xi : Tensor , quant_levels : Tensor | None = None ) - > Tensor :
# prom
if quant_levels is None and xi . shape [ - 1 ] > 1 :
x = sum ( [ self . embeddings [ k ] ( xi [ : , k ] ) * ( self . weight [ k ] if self . weight is not None else 1 ) for k in range ( xi . shape [ - 1 ] ) ] )
# AR resp
elif quant_levels is None or quant_levels == 0 :
x = self . embeddings [ 0 ] ( xi [ : , 0 ] )
# NAR resp
else :
x = sum ( [ self . embeddings [ k + 1 ] ( xi [ : , k ] ) * ( self . weight [ k + 1 ] if self . weight is not None else 1 ) for k in range ( xi . shape [ - 1 ] ) ] )
return x
2023-09-07 21:48:02 +00:00
2023-08-02 21:53:35 +00:00
class Base ( nn . Module ) :
@property
def causal ( self ) - > bool :
raise NotImplementedError
@property
def arch_type ( self ) - > str :
raise NotImplementedError
@property
def norm_type ( self ) :
raise NotImplementedError
@property
def n_prom_levels ( self ) - > int :
raise NotImplementedError
2023-08-19 20:06:33 +00:00
2023-09-06 23:58:35 +00:00
@property
def n_resp_levels ( self ) - > int :
raise NotImplementedError
2023-08-19 20:06:33 +00:00
@property
def n_max_levels ( self ) - > int :
raise NotImplementedError
2023-08-19 01:58:07 +00:00
2023-09-21 00:10:59 +00:00
@property
def n_langs ( self ) - > int :
raise NotImplementedError
2024-04-16 00:54:32 +00:00
2023-08-19 01:58:07 +00:00
@property
def n_tasks ( self ) - > int :
raise NotImplementedError
2023-08-02 21:53:35 +00:00
2024-04-16 00:54:32 +00:00
@property
def n_tones ( self ) - > int :
raise NotImplementedError
2023-09-02 01:58:29 +00:00
@property
def recurrent_chunk_size ( self ) - > int :
raise NotImplementedError
2023-09-21 00:10:59 +00:00
@property
def rotary_embedding_base ( self ) - > float :
return 10000
2023-09-04 03:46:08 +00:00
@property
def interleave ( self ) - > bool :
return False
2023-09-02 01:58:29 +00:00
2023-09-07 00:33:39 +00:00
@property
2023-09-07 21:48:02 +00:00
def monolithic ( self ) - > bool :
2023-09-07 00:33:39 +00:00
return False
2023-09-07 14:14:03 +00:00
@property
2023-09-11 19:13:42 +00:00
def version ( self ) - > int :
return 1
2023-09-07 14:14:03 +00:00
2023-09-06 23:58:35 +00:00
@property
def stop_token ( self ) :
if not self . causal :
raise ValueError ( " Not using stop token! " )
return self . n_tokens
@property
def ignore_index ( self ) :
return - 100
2023-08-02 21:53:35 +00:00
def __init__(
    self,
    n_tokens: int = 1024,
    d_model: int = 512,
    n_heads: int = 8,
    n_layers: int = 12,
    p_dropout: float = 0.1,

    n_experts: int = 1,

    l_padding: int = 0,

    training = True,
    config = None,
):
    """Build the embeddings, the backbone selected by `arch_type`, and the output head.

    Args:
        n_tokens: size of the token vocabulary (before the stop token).
        d_model: model width.
        n_heads: number of attention / retention heads.
        n_layers: number of backbone layers.
        p_dropout: dropout rate; forced to 0.0 when `training` is false.
        n_experts: number of experts; values > 1 select the MoE variants.
        l_padding: if non-zero, `forward` pads sequences to a multiple of it.
        training: initial value of the module's `training` flag.
        config: optional config object; `activation_checkpointing` and
            `attention` are read from it when it is given.

    Raises:
        RuntimeError: if `arch_type` names an unknown architecture.
    """
    super().__init__()
    self.training = training
    self.config = config
    self.activation_checkpointing = (
        True if self.config is None else self.config.activation_checkpointing
    )

    self.n_tokens = n_tokens
    self.d_model = d_model
    self.n_heads = n_heads
    self.n_layers = n_layers
    self.n_experts = n_experts

    self.l_padding = l_padding

    # Dropout is only active when building for training.
    dropout = p_dropout if training else 0.0

    # TODO: task tokens should be delegated to their own embedding.
    n_prom_tokens = n_tokens
    # +1 for the stop token: the AR needs one to know when to stop.
    n_resp_tokens = n_tokens + (1 if self.causal else 0)

    self.text_emb = Embedding(n_tokens, d_model)
    self.langs_emb = None
    self.tones_emb = None
    self.tasks_emb = None

    if self.version == 1:  # legacy
        # Old models keep the task tokens inside the prom vocabulary.
        n_prom_tokens += (self.n_tasks - 1)
        self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
        self.resps_emb = MultiEmbedding(
            self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic
        )
    else:
        # [1024] * 8
        self.proms_emb = AudioEmbedding(
            [n_prom_tokens] * self.n_prom_levels,
            d_model,
            self.n_prom_levels if self.version > 3 else None,
        )
        # [1025] + [1024] * 8
        self.resps_emb = AudioEmbedding(
            [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1),
            d_model,
            self.n_resp_levels if self.version > 3 else None,
        )

    if self.version >= 3:
        self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None
        self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None

    if self.version >= 4:
        self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None

    self.sep = nn.Parameter(torch.randn(d_model))

    arch_type = self.arch_type

    if arch_type == " transformer ":
        self.sin_emb = SinusoidalEmbedding(d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                n_heads=n_heads,
                p_dropout=dropout,
                causal=self.causal,
                norm_type=self.norm_type,
                n_levels=self.n_resp_levels,
            )
            for _ in range(n_layers)
        ])
    elif arch_type in (" mistral ", " mixtral ", " llama "):
        # Settings shared by every HF-derived backbone.
        hf_kwargs = dict(
            vocab_size=n_resp_tokens,
            hidden_size=d_model,
            max_position_embeddings=75 * 60,  # max-length of 60 seconds
            intermediate_size=d_model * 4,
            num_hidden_layers=n_layers,
            num_attention_heads=n_heads,
            attention_dropout=dropout,
            num_key_value_heads=n_heads,
            hidden_act=" gelu ",
            is_encoder_decoder=False,
            is_decoder=True,
            attn_implementation=self.config.attention if self.config is not None else None,  # "flash_attention_2",
        )
        if n_experts > 1:
            # Any of these archs with several experts becomes a Mixtral MoE.
            hf_kwargs.update(
                sliding_window=75 * 12,  # 12 second context window
                output_router_logits=training,
                num_local_experts=n_experts,
                num_experts_per_tok=min(2, n_experts),
            )
            self.model = MixtralModel(MixtralConfig(**hf_kwargs))
        elif arch_type == " llama ":
            hf_kwargs["sliding_window"] = 75 * 12  # 12 second context window
            self.model = LlamaModel(LlamaConfig(**hf_kwargs))
        else:
            self.model = MistralModel(MistralConfig(**hf_kwargs))
    elif arch_type == " retnet ":
        retnet_kwargs = dict(
            vocab_size=n_resp_tokens,
            decoder_embed_dim=d_model,
            decoder_value_embed_dim=d_model * 2,
            decoder_retention_heads=n_heads,
            decoder_ffn_embed_dim=d_model * 4,
            decoder_layers=n_layers,
            dropout=dropout,
            checkpoint_activations=self.activation_checkpointing,
            activation_fn=" gelu ",
            use_layernorm=True,  # self.version < 3,
            use_biases=True,  # self.version < 3,
            use_glu=False,  # self.version >= 3,

            chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
            recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
            no_output_layer=True,
            decoder_normalize_before=True,

            rotary_embedding_base=self.rotary_embedding_base,  # 10000
        )
        if n_experts > 1:
            retnet_kwargs.update(
                use_xmoe=True,
                moe_freq=1,
                moe_expert_count=n_experts,
                moe_gating_use_fp32=False,
            )
        self.model = RetNetDecoder(RetNetConfig(**retnet_kwargs))
    elif arch_type == " retnet-hf ":
        retnet_hf_kwargs = dict(
            vocab_size=n_resp_tokens,
            decoder_embed_dim=d_model,
            decoder_value_embed_dim=d_model * 2,
            decoder_retention_heads=n_heads,
            decoder_ffn_embed_dim=d_model * 4,
            decoder_layers=n_layers,
            dropout=dropout,
            checkpoint_activations=self.activation_checkpointing,
            activation_fn=" gelu ",
            use_glu=False,  # self.version >= 3,

            recurrent_chunk_size=self.recurrent_chunk_size if self.causal else 0,
            decoder_normalize_before=True,

            deepnorm=False,
            subln=True,
        )
        self.model = RetNetDecoder_HF(RetNetConfig_HF(**retnet_hf_kwargs))
    elif arch_type == " bitnet ":
        self.model = BitNetTransformer(
            num_tokens=n_resp_tokens,
            dim=d_model,
            depth=n_layers,
            heads=n_heads,
            ff_mult=4,
        )
    else:
        raise RuntimeError(f' Unknown arch specified: {arch_type} ')

    # Output projection onto the response vocabulary.
    self.classifier = nn.Linear(d_model, n_resp_tokens)

    # Both metrics are configured identically.
    metric_kwargs = dict(
        top_k=10,
        average=" micro ",
        multidim_average=" global ",
        ignore_index=self.ignore_index,
    )
    self.accuracy_metric = MulticlassAccuracy(n_resp_tokens, **metric_kwargs)
    self.precision_metric = MulticlassPrecision(n_resp_tokens, **metric_kwargs)
2024-04-16 00:54:32 +00:00
def _forward (
2023-08-02 21:53:35 +00:00
self ,
2024-04-16 00:54:32 +00:00
inputs ,
mask = None ,
state = None ,
2023-08-02 21:53:35 +00:00
) :
2024-04-16 00:54:32 +00:00
x = inputs
m = mask . squeeze ( - 1 ) . int ( )
2023-12-23 01:27:36 +00:00
aux_loss = None
2024-04-14 18:12:50 +00:00
"""
# Broken
if state is not None and ( self . arch_type == " retnet " or self . arch_type == " retnet-hf " ) :
2023-09-02 01:58:29 +00:00
# prefill
if len ( state ) == 0 :
2023-09-05 20:38:21 +00:00
prefill_size = x . shape [ 1 ]
# run the initial prompt to fill the KV cache
2024-04-09 01:14:51 +00:00
if self . arch_type == " retnet " :
for n in range ( prefill_size ) :
xi = x [ : , n , : ] . unsqueeze ( 1 )
self . model ( xi , incremental_state = state , token_embeddings = xi , features_only = True )
elif self . arch_type == " retnet-hf " :
2024-04-14 18:12:50 +00:00
state = None
2024-04-09 01:14:51 +00:00
for n in range ( prefill_size ) :
xi = x [ : , n , : ] . unsqueeze ( 1 )
kwargs = dict (
2024-04-16 00:54:32 +00:00
attention_mask = m ,
2024-04-14 18:12:50 +00:00
inputs_embeds = xi ,
past_key_values = state ,
use_cache = True ,
forward_impl = ' recurrent ' ,
2024-04-09 01:14:51 +00:00
# return_dict=True,
)
out = self . model ( * * kwargs )
2024-04-14 18:12:50 +00:00
state = out . past_key_values
2023-09-02 01:58:29 +00:00
# grab last token(s)
x = x [ : , - 1 , : ] . unsqueeze ( 1 )
2024-04-14 18:12:50 +00:00
"""
2023-12-23 01:27:36 +00:00
# HF transformer derived model
2024-04-16 00:54:32 +00:00
if self . arch_type in [ " llama " , " mistral " , " mixtral " ] :
2023-12-23 01:27:36 +00:00
kwargs = dict (
2024-04-16 00:54:32 +00:00
attention_mask = m ,
2023-12-23 01:27:36 +00:00
inputs_embeds = x ,
2024-02-01 03:48:36 +00:00
past_key_values = state ,
2024-04-14 18:12:50 +00:00
use_cache = True ,
2024-02-01 03:48:36 +00:00
# return_dict=True,
2023-12-23 01:27:36 +00:00
)
2024-02-01 03:48:36 +00:00
if self . n_experts > 1 and targ_list is not None :
2023-12-23 01:27:36 +00:00
kwargs [ " output_router_logits " ] = True
2023-09-02 01:58:29 +00:00
2023-12-23 01:27:36 +00:00
t = self . model ( * * kwargs )
2024-02-01 03:48:36 +00:00
2023-12-23 01:27:36 +00:00
x = t [ 0 ]
2024-02-01 03:48:36 +00:00
if state is not None :
state = t [ 1 ]
if self . n_experts > 1 and targ_list is not None :
2023-12-23 01:27:36 +00:00
router_logits = t [ - 1 ]
aux_loss = self . model . config . router_aux_loss_coef * load_balancing_loss_func ( router_logits , self . model . config . num_local_experts , self . model . config . num_experts_per_tok )
elif self . arch_type == " transformer " :
2023-09-12 21:04:45 +00:00
# ensures we specify a quant_level for the transformer implementation's AdaLN
2023-09-07 14:14:03 +00:00
l = torch . zeros ( ( batch_size , ) , dtype = torch . int32 ) if quant_levels is None else quant_levels
l = l . to ( device )
2023-09-12 21:04:45 +00:00
# inject position information
x = self . sin_emb . add_pe ( x )
# pass our inputs through the transformer
2023-08-02 21:53:35 +00:00
for block in self . blocks :
2024-04-16 00:54:32 +00:00
x = block ( x , m , l )
2023-08-02 21:53:35 +00:00
elif self . arch_type == " retnet " :
2023-09-12 21:04:45 +00:00
# pass our inputs through the RetNet
2023-12-23 01:27:36 +00:00
x , _ = self . model ( x , incremental_state = state , token_embeddings = x , features_only = True )
2023-12-26 03:20:32 +00:00
if _ is not None and " l_aux " in _ and self . n_experts > 1 :
2023-12-23 01:27:36 +00:00
aux_loss = torch . sum ( torch . stack ( [ t for t in _ [ " l_aux " ] if t is not None ] ) ) * 0.001
2024-04-09 01:14:51 +00:00
elif self . arch_type == " retnet-hf " :
2024-04-14 18:12:50 +00:00
first = state is None or len ( state ) == 0
2024-04-09 01:14:51 +00:00
kwargs = dict (
2024-04-16 00:54:32 +00:00
attention_mask = m ,
2024-04-14 18:12:50 +00:00
inputs_embeds = x if first else x [ : , - 1 , : ] . unsqueeze ( 1 ) ,
past_key_values = None if first else state ,
use_cache = True ,
forward_impl = ' parallel ' if first else ' recurrent ' ,
return_dict = True ,
2024-04-09 01:14:51 +00:00
)
2024-04-14 18:12:50 +00:00
out = self . model ( * * kwargs )
x = out . last_hidden_state
2024-04-09 01:14:51 +00:00
if state is not None :
2024-04-14 18:12:50 +00:00
state = out . past_key_values
2024-03-01 02:29:17 +00:00
elif self . arch_type == " bitnet " :
x = self . model ( x )
2024-04-14 18:12:50 +00:00
2023-09-12 21:04:45 +00:00
# output projection layer with masking
2024-04-16 00:54:32 +00:00
x = self . classifier ( x ) * mask
2024-04-14 18:12:50 +00:00
2024-04-16 00:54:32 +00:00
return x , state , aux_loss
2024-04-17 02:04:48 +00:00
def inputs (
2024-04-16 00:54:32 +00:00
self ,
text_list : list [ Tensor ] ,
proms_list : list [ Tensor ] ,
resps_list : list [ Tensor ] ,
targ_list : list [ Tensor ] | None = None ,
2024-04-17 02:04:48 +00:00
2024-04-16 00:54:32 +00:00
lang_list : list [ Tensor ] | None = None ,
tone_list : list [ Tensor ] | None = None ,
) :
device = text_list [ 0 ] . device
batch_size = len ( text_list )
2024-04-17 02:04:48 +00:00
inputs = [ [ ] for _ in range ( batch_size ) ]
for i in range ( batch_size ) :
if text_list is not None :
inputs [ i ] . append ( ( " text " , text_list [ i ] ) )
if proms_list is not None :
inputs [ i ] . append ( ( " prom " , proms_list [ i ] ) )
if resps_list is not None :
inputs [ i ] . append ( ( " resp " , resps_list [ i ] ) )
if targ_list is not None :
inputs [ i ] . append ( ( " targ " , targ_list [ i ] ) )
2024-04-16 00:54:32 +00:00
2024-04-17 02:04:48 +00:00
return inputs
def inputs_to_embeddings(
    self,
    inputs: list,
    quant_levels: Tensor | None = None
):
    """Embed every tagged input and join each sample's pieces with `self.sep`.

    Args:
        inputs: one list per sample of `(name, tensor)` pairs.
        quant_levels: optional per-sample quant levels; entry `b_i` is handed
            to the response embedding of sample `b_i`.

    Returns:
        One joined embedding tensor per sample. Entries whose name is not
        text / lang / prom / tone / resp are skipped.
    """
    x_list = []
    for b_i, entries in enumerate(inputs):
        embedded = []
        for name, value in entries:
            if name == " resp ":
                level = quant_levels[b_i] if quant_levels is not None else None
                embedded.append(self.resps_emb(value, level))
            elif name == " text ":
                embedded.append(self.text_emb(value))
            elif name == " lang ":
                embedded.append(self.langs_emb(value))
            elif name == " prom ":
                embedded.append(self.proms_emb(value))
            elif name == " tone ":
                embedded.append(self.tones_emb(value))
            # anything else (e.g. targets) has no embedding
        x_list.append(_join(embedded, self.sep))

    return x_list
def training_targets(
    self,
    inputs: list,
):
    """Build one target sequence per sample from the tagged inputs.

    Prompt entries become `ignore_index` filler (one value per row of the
    prompt), text / lang / tone / targ entries are used as-is, and any other
    entry is left out. Pieces are joined with an `ignore_index` separator.
    """
    passthrough = (" text ", " lang ", " tone ", " targ ")

    x_list = []
    for entries in inputs:
        pieces = []
        for name, value in entries:
            # The separator is created on the device of the last entry seen.
            device = value.device
            if name == " prom ":
                pieces.append(torch.full_like(value[..., 0], self.ignore_index))
            elif name in passthrough:
                pieces.append(value)

        separator = torch.tensor(self.ignore_index, device=device)
        x_list.append(_join(pieces, separator))

    return x_list
def forward(
    self,
    inputs: list,
    quant_levels: Tensor | None = None,
    state: dict | list | None = None,
):
    """Embed the tagged inputs, run the backbone, and optionally compute the loss.

    Args:
        inputs: one list per sample of `(name, tensor)` pairs. The presence
            of any " targ " entry switches on loss computation.
        quant_levels: optional per-sample quant levels. Samples whose level
            is > 0 are not shifted when computing the loss.
        state: optional backbone cache / incremental state.

    Returns:
        The list of per-sample logits with padding removed, or
        `(logits, state)` when a state is in use.

    Side effects:
        When targets are present, sets `self.loss` (key `nll`, including any
        auxiliary loss) and `self.stats` (key `acc`).
    """
    x_list = self.inputs_to_embeddings(inputs, quant_levels)
    x, m = list_to_tensor(x_list)

    # Targets being present means this is a training step.
    training = any(
        name == " targ "
        for entries in inputs
        for name, _ in entries
    )

    # Pad input and mask up to a multiple of l_padding; the original lengths
    # are restored afterwards when slicing the logits.
    if self.l_padding and x.shape[1] % self.l_padding != 0:
        pad_shape = list(x.shape)
        pad_shape[1] = self.l_padding - pad_shape[1] % self.l_padding

        # pad input
        x = torch.cat([x, torch.zeros(pad_shape, dtype=x.dtype, device=x.device)], dim=1)

        # pad mask
        pad_shape[2] = 1
        m = torch.cat([m, torch.zeros(pad_shape, dtype=x.dtype, device=x.device)], dim=1)

    x, state, aux_loss = self._forward(
        inputs=x,
        mask=m,
        state=state,
    )

    # Remove padding.
    logits = [hi[:li] for hi, li in zip(x, map(len, x_list))]

    # Compute the loss if targets were given.
    if training:
        target_list = self.training_targets(inputs)

        # Shift only for the AR so it behaves like a regular transformer LM.
        for i in range(len(target_list)):
            if quant_levels is not None and quant_levels[i] > 0:
                continue

            logits[i] = logits[i][..., :-1, :]  # shift so that token n...
            target_list[i] = target_list[i][..., 1:]  # ...predicts token n + 1

        target = torch.cat(target_list)
        flat_logits = torch.cat(logits)

        self.loss = dict(
            # "nll" is the name used by the original implementation.
            nll=F.cross_entropy(flat_logits, target, ignore_index=self.ignore_index)
        )
        self.stats = dict(
            acc=self.accuracy_metric(flat_logits, target),
            # precision = self.precision_metric( flat_logits, target ),
        )

        if aux_loss is not None:
            # Must be the same key the dict was built with above.
            self.loss["nll"] += aux_loss

    return (logits, state) if state is not None else logits
2023-09-13 02:28:07 +00:00
def sample(
    self,
    logits: list[Tensor],
    resps_list: list[Tensor],
    quant_levels: Tensor | None = None,

    temperature: float = 1.0,
    min_temperature: float = -1.0,
    top_k: int = -100,
    top_p: float = 1.0,

    repetition_penalty: float = 1.0,
    repetition_penalty_decay: float = 0.0,

    length_penalty: float = 0.0,

    beam_width: int = 0,

    mirostat: list[dict] | None = None,
):
    """Turn per-sequence logits into sampled tokens.

    Args:
        logits: one logit tensor per batch entry.
        resps_list: the responses generated so far, one per batch entry;
            used for slicing (NAR), repetition penalty and length penalty.
        quant_levels: when not None the call is treated as NAR and the
            whole response span is sampled; when None it is treated as AR.
        temperature: sampling temperature the logits are divided by.
        min_temperature: lower bound for dynamic temperature; a negative
            value disables it (it is set equal to `temperature`).
        top_k, top_p: filtering is applied only if `top_k > 0` or
            `top_p < 1.0`.
        repetition_penalty, repetition_penalty_decay: forwarded to
            `reptition_penalize`.
        length_penalty: forwarded to `length_penalize` (causal AR only).
        beam_width: values > 1 enable the naive beam search branch.
        mirostat: per-entry mirostat state dicts; when given, mirostat
            sampling is used and beam search is skipped.

    Returns:
        * mirostat given: a list with the result of `mirostat_sample` for
          each entry;
        * `beam_width > 1`: a tuple `(tokens, scores)` built from the
          candidates returned by `top_k_logits_list`;
        * otherwise: a list of tensors sampled from a `Categorical` over
          each entry's logits.
    """
    # a negative minimum temperature means "no dynamic temperature"
    if min_temperature < 0:
        min_temperature = temperature

    # (NAR) return the entire generated response
    if quant_levels is not None:
        logits = [logit[-l:] for logit, l in zip(logits, map(len, resps_list))]
    # (AR chunkwise) return the last chunkwise piece
    elif self.causal and self.recurrent_chunk_size > 0:
        # recurrent_chunk_size is a scalar (see the `> 0` test above), so it
        # is applied to every entry rather than zipped against the batch
        logits = [logit[-self.recurrent_chunk_size:] for logit in logits]
    # (AR) return just the last code
    else:
        logits = [logit[-1:] for logit in logits]

    # sampling is done on the CPU; half precision is upcast to float32 first
    logits = [
        logit.to(
            device="cpu",
            dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32,
        )
        for logit in logits
    ]

    # perform repetition penalizing
    logits = [
        reptition_penalize(
            logit,
            previous=resps[:, -1],
            factor=repetition_penalty,
            decay=repetition_penalty_decay,
        )
        for logit, resps in zip(logits, resps_list)
    ]

    # (AR) perform length penalizing
    if quant_levels is None and self.causal:
        logits = [
            length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token)
            for logit, l in zip(logits, map(len, resps_list))
        ]

    # perform top_k/top_p filtering of our logits
    if top_k > 0 or top_p < 1.0:
        logits = [top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits]

    # trigger dynamic temperature sampling if the minimum temperature is not
    # the same as the sampling temperature (compared with an epsilon)
    if abs(temperature - min_temperature) >= 0.001:
        logits = [
            dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature)
            for logit in logits
        ]
    else:
        logits = [logit / temperature for logit in logits]

    # do mirostat sampling
    # currently incompatible with beam searching with the way the two are implemented
    if mirostat is not None:
        return [mirostat_sample(logit, state=state) for logit, state in zip(logits, mirostat)]

    # do beam search (naive implementation)
    # picks the top-k across all batches, and re-batches those resultant tokens
    # returns the logit scores as well to be P-concatted with the previous scores
    # to-do: not naively implement beam searching
    if beam_width > 1:
        candidates = top_k_logits_list(logits, beam_width)
        res = [torch.tensor(token, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates]
        scores = [logits[batch].flatten()[token] for batch, token in candidates]
        return res, scores

    # and sample
    return [Categorical(logits=logit).sample() for logit in logits]
2023-08-02 21:53:35 +00:00
def example_usage():
    """Smoke-test the AR and NAR models on a single utterance.

    Builds one AR and one NAR model, wraps them in local engines, then either
    trains both for 500 steps on `data/qnt.pt` (when `train` is True) or
    loads `./data/{name}.pth` for each engine, and finally writes decoded
    samples to `./data/ar.{name}.wav` and `./data/ar+nar.{name}.wav`.
    """
    from ..config import cfg
    cfg.trainer.backend = "local"
    cfg.trainer.check_for_oom = False

    from ..emb.qnt import decode_to_file
    from ..engines import Engine, Engines
    from tqdm import tqdm, trange
    from ..utils import wrapper as ml

    from .ar import AR
    from .nar import NAR

    device = "cuda"

    # phoneme symbol -> token id
    # NOTE(review): '?' and 'p' both map to 7 -- confirm whether that is intended
    symmap = {
        '<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9,
        'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19,
        'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29,
        'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39,
        'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49,
        'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59,
        'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69,
        'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79,
        'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89,
        'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99,
        'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109,
        'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119,
        'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129,
        'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139,
        'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149,
        ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159,
        'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169,
        'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178,
    }

    def tokenize(content, lang_marker="en"):
        """Map a space-separated phoneme string to a tensor of token ids.

        Empty fields produced by consecutive spaces become the ' ' symbol.
        Raises KeyError (naming the symbol) for a phoneme not in `symmap`.
        """
        fields = content.split(" ")
        phones = ["<s>"] + [p if p else " " for p in fields] + ["</s>"]
        return torch.tensor([symmap[p] for p in phones])

    kwargs = {
        'n_tokens': 1024,
        'd_model': 1024,
        'n_heads': 16,
        'n_layers': 12,
    }

    models = {"ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device)}

    for name, model in models.items():
        print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    engines = Engines({
        name: Engine(model=model, optimizer=ml.AdamW(model.parameters(), lr=1e-4))
        for name, model in models.items()
    })

    train = True

    qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.model.prom_levels].to(device)

    # Phonemes are separated by single spaces; a run of three spaces marks a
    # word boundary (split(" ") yields empty fields there, which tokenize()
    # turns into the ' ' symbol). More utterances can be appended to the list.
    text_list = [
        tokenize("ˈ a ɪ   w ɪ l   nˌ ɑː t   ˈ æ s k   ɐ   sˈ ɛ k ə n d   tˈ a ɪ m").to(device),
    ]

    proms_list = [
        qnt.to(device),
    ]
    resps_list = [
        qnt.to(device),
    ]

    def sample(name, steps=600):
        """Run AR then NAR inference and write both results as wav files.

        `name` is only used as the tag in the output file names.
        """
        ar_engine = None
        nar_engine = None

        engines.eval()
        # the loop variable must not be called `name`: that would overwrite
        # the tag used for the output file names below
        for engine_name, engine in engines.items():
            if engine_name[:2] == "ar":
                ar_engine = engine
            elif engine_name[:3] == "nar":
                nar_engine = engine

        ar_resps = ar_engine(text_list, proms_list, max_steps=steps, sampling_temperature=1.0)
        ar_resps = [r.unsqueeze(-1) for r in ar_resps]
        codes = nar_engine(text_list, proms_list, resps_list=ar_resps, sampling_temperature=0.2)

        decode_to_file(ar_resps[0], f"./data/ar.{name}.wav", device=device)
        decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device)

    if train:
        sample("init", 15)

        engines.train()
        progress = trange(500)
        for _ in progress:
            stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list})
            tqdm.write(f"{stats}")
    else:
        for name, engine in engines.items():
            engine.module.load_state_dict(torch.load(f"./data/{name}.pth"))

    sample("final")
2023-08-02 21:53:35 +00:00
# Run the example only when this module is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    example_usage()