from .base import Base, list_to_tensor, Categorical
from ..config import cfg

import torch
from torch.nn.utils.rnn import pad_sequence

import random
import math

from einops import rearrange
from torch import Tensor
from tqdm import trange

from ..emb.qnt import trim

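# AR+NAR in one model: the first RVQ level is decoded autoregressively (AR), while the
# remaining levels are predicted non-autoregressively (NAR) by the same weights,
# selected per sample through `quant_levels`.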
class AR_NAR(Base):
    @property
    def causal(self):
        return True

    @property
    def norm_type(self):
        return "ln" # if self.n_resp_levels == 1 else "adaln"

    @property
    def arch_type(self) -> str:
        if hasattr(self, "config") and self.config:
            return self.config.arch_type
        return cfg.models.ar_nar.arch_type

    @property
    def n_prom_levels(self) -> int:
        return cfg.models.prom_levels

    @property
    def n_resp_levels(self) -> int:
        if hasattr(self, "config") and self.config:
            return self.config.resp_levels
        return cfg.models.ar_nar.resp_levels

    @property
    def n_max_levels(self) -> int:
        return cfg.models.max_levels

    @property
    def n_tasks(self) -> int:
        return cfg.models.ar_nar.tasks

    @property
    def n_langs(self) -> int:
        return cfg.models.ar_nar.langs

    @property
    def recurrent_chunk_size(self) -> int:
        return 0

    """
    @property
    def rotary_embedding_base(self) -> float:
        if hasattr(self, "config") and self.config:
            return self.config.rotary_embedding_base
        return cfg.models.ar_nar.rotary_embedding_base
    """

    @property
    def interleave(self) -> bool:
        return False

    @property
    def monolithic(self) -> bool:
        return True

    @property
    def version(self) -> int:
        if hasattr(self, "config") and self.config:
            return self.config.version
        return cfg.models.ar_nar.version

    def _prune(self, l: Tensor):
        indices = (l == self.stop_token).nonzero()
        if len(indices) == 0:
            return l
        return l[: indices.min().item()]

    @staticmethod
    def _unsqueeze_list(x_list, axis=-1):
        return [x.unsqueeze(dim=axis) for x in x_list]

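    # forward() serves three roles depending on its inputs:
    #   * training: resps_list carries all RVQ levels, so a random target level is picked per sample
    #   * NAR inference: resps_list carries fewer levels than n_resp_levels, so the missing levels are filled in
    #   * AR inference: resps_list is None, so the first level is generated token by token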
    def forward(
        self,
        text_list: list[Tensor],
        proms_list: list[Tensor],
        resps_list: list[Tensor] | None = None,

        lang_list: list[Tensor] | None = None,

        max_steps: int = 1000,
        max_levels: int = 0,
        max_resp_context: int = -1,

        sampling_temperature: float = 1.0,
        sampling_min_temperature: float = -1.0,
        sampling_top_k: int = -100,
        sampling_top_p: float = 1.0,
        sampling_repetition_penalty: float = 1.0,
        sampling_repetition_penalty_decay: float = 0.0,
        sampling_length_penalty: float = 0.0,
        sampling_beam_width: int = 0,

        sampling_mirostat_tau: float = 0.0,
        sampling_mirostat_eta: float = 0.1,
    ):
        device = text_list[0].device
        batch_size = len(text_list)

        # is training or NAR
        if resps_list is not None:
            n_levels_set = {r.shape[-1] for r in resps_list}
            n_levels = next(iter(n_levels_set))

            # is training
            if n_levels == self.n_resp_levels:
                # might be better to have this decided on the dataloader level
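                # one target RVQ level is drawn per sample for this step: level 0 trains the AR objective, levels 1+ train the NAR objective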
                if cfg.experimental and False:
                    # makes higher levels less likely
                    def generate(lo=0, hi=8):
                        index = lo
                        p = random.random()
                        for i in range(lo, hi):
                            if p < 1.0 / (2 ** i):
                                index = i
                        return int(index)

                    quant_levels = torch.Tensor([ generate(0, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
                else:
                    quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)

                """
                if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
                    quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
                else:
                    quant_levels = torch.Tensor([ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ])
                """

                targ_list = [ r[..., l] for r, l in zip(resps_list, quant_levels) ] # ensures we only have 1 RVQ-bin (our target)
                resps_list = [ r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels) ] # r[..., 0] is technically correct, but only r[:, 0] gets passed through the embedding

                """
                if cfg.experimental:
                    proms_list = [ r if l == 0 else trim(r, 75 * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds
                """

                # append stop tokens for AR
                for i in range(batch_size):
                    if quant_levels[i] > 0:
                        continue

                    resps_list[i] = torch.cat([ resps_list[i], torch.Tensor([[self.stop_token] * n_levels]).to(device=device, dtype=torch.int16) ])
                    targ_list[i] = torch.cat([ targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])

                return super().forward(
                    text_list=text_list,
                    proms_list=proms_list,
                    resps_list=resps_list,
                    targ_list=targ_list,
                    lang_list=lang_list,
                    quant_levels=quant_levels,
                )

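            # NAR inference: resps_list already holds the lower RVQ level(s); each pass below predicts one additional level conditioned on everything generated so far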
            # is NAR
            if max_levels == 0:
                max_levels = self.n_resp_levels - 1

            prev_list = resps_list

            for n in trange(max_levels, desc="NAR"):
                level = prev_list[0].shape[-1]
                if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
                    break

                quant_levels = torch.full((len(text_list),), level)

                logits = super().forward(
                    text_list=text_list,
                    proms_list=proms_list,
                    resps_list=prev_list,
                    lang_list=lang_list,
                    quant_levels=quant_levels,
                )

                resps_list = super().sample(
                    logits=logits,
                    resps_list=prev_list,
                    quant_levels=quant_levels,

                    temperature=sampling_temperature,
                    min_temperature=sampling_min_temperature,
                    top_p=sampling_top_p,
                    top_k=sampling_top_k,
                    repetition_penalty=sampling_repetition_penalty,
                    repetition_penalty_decay=sampling_repetition_penalty_decay,
                    #length_penalty=sampling_length_penalty,
                    #beam_width=sampling_beam_width,
                    #mirostat=mirostat,
                )

                prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]

            return prev_list

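        # AR inference: decode the first RVQ level token by token until every sequence in the batch produces a stop token (or max_steps is reached)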
        # is AR
        sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
        stopped = torch.zeros(batch_size, device=device).bool()

        recurrent_state = {} if cfg.inference.recurrent_forward else None
        mirostat = [
            {"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
        ] * batch_size if sampling_mirostat_tau > 0.0 else None
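        # `recurrent_state` opts into the recurrent forward path (e.g. for RetNet) when cfg.inference.recurrent_forward is set
        # each `mirostat` entry tracks per-sample surprise state; when enabled, the sampler returns this state (with the chosen token inside) instead of raw tokens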

        scores = [ 1.0 ] * sampling_beam_width

        if self.interleave:
            max_steps *= self.n_prom_levels

        # get next in sequence
        for n in trange(max_steps // max(1, self.recurrent_chunk_size), desc="AR"):
            # experimental rolling response to avoid too-long perplexity hits despite RetNet allegedly fixing this.
            # UNTESTED. In theory it would be better to also adjust the text, but there's no way of correlating text to segment of audio without something like wav2vec2
            if max_resp_context > 0:
                resps_list = self._unsqueeze_list([ sequence[-max_resp_context:] for sequence in sequence_list ])
            else:
                resps_list = self._unsqueeze_list(sequence_list)

            if recurrent_state is not None:
                logits, recurrent_state = super().forward(
                    text_list=text_list,
                    proms_list=proms_list,
                    resps_list=resps_list,
                    lang_list=lang_list,
                    state=recurrent_state
                )
            else:
                logits = super().forward(
                    text_list=text_list,
                    proms_list=proms_list,
                    resps_list=resps_list,
                    lang_list=lang_list,
                    state=recurrent_state
                )

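            # `r` is the list of next tokens per batch element; with mirostat it is the updated sampler state, and with beam search it is a (tokens, scores) tuple (both unpacked below)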
            r = super().sample(
                logits=logits,
                resps_list=resps_list,

                temperature=sampling_temperature,
                min_temperature=sampling_min_temperature,
                top_p=sampling_top_p,
                top_k=sampling_top_k,
                repetition_penalty=sampling_repetition_penalty,
                repetition_penalty_decay=sampling_repetition_penalty_decay,
                length_penalty=sampling_length_penalty,
                beam_width=sampling_beam_width,

                mirostat=mirostat,
            )

            if mirostat is not None:
                # r is the state
                mirostat = r
                # extract token from state
                r = [ state["token"] for state in mirostat ]
            # we do it here because the sampler will already expand our logits list
            elif sampling_beam_width > 0:
                # expand tuple
                r, s = r
                # first step, expand batch
                if batch_size == 1:
                    batch_size = sampling_beam_width
                    text_list = text_list * sampling_beam_width
                    proms_list = proms_list * sampling_beam_width
                    sequence_list = sequence_list * sampling_beam_width
                    stopped = torch.zeros(batch_size, device=device).bool()

                scores = [ scores[i] + score for i, score in enumerate(s) ]

            # append tokens
            for i, ri in enumerate(r):
                if self.stop_token in ri:
                    stopped[i] = True
                sequence_list[i] = torch.cat([ sequence_list[i], ri.to(device) ])

            # stop token found
            stopped |= r == self.stop_token
            if stopped.all().item():
                break

        # pick the best scoring candidate
        # desu this is always going to be candidate 0
        if sampling_beam_width:
            sequence_list = [ sequence_list[0] ]

        return [ self._prune(r) for r in sequence_list ]
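
# Self-contained smoke test: build a small AR_NAR model, overfit it on a single utterance, then decode the AR and AR+NAR outputs to wav files.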
def example_usage():
    #cfg.trainer.backend = "local"
    from functools import partial
    from einops import repeat

    from ..emb.qnt import decode_to_file, unload_model
    from ..engines import Engine
    from tqdm import tqdm

    from ..utils import wrapper as ml

    import re

    device = "cuda"
    x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
    symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}

    def tokenize(content, lang_marker="en"):
        split = content.split(" ")
        phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
        return torch.tensor([*map(symmap.get, phones)])

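    # "data/qnt.pt" is assumed to hold a pre-quantized (RVQ-encoded) reference utterance; the first three seconds serve as the prompt and the full clip as the training target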
    qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device)

    cfg.hyperparameters.gradient_accumulation_steps = 1

    text_list = [
        tokenize("ˈ a ɪ w ɪ l nˌ ɑ ː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
    ]
    proms_list = [
        qnt[:75*3, :].to(device),
    ]
    resps_list = [
        qnt.to(device),
    ]

    text_list = text_list[:1]
    proms_list = proms_list[:1]
    resps_list = resps_list[:1]

    # retnet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
    kwargs = {
        'n_tokens': 1024,
        'd_model': 1024, # 256, # 1024, # 1536
        'n_heads': 16, # 4, # 16, # 24
        'n_layers': 12, # 32
        'n_experts': 1,

        'l_padding': 8,
    }
    """
    kwargs = {
        'n_tokens': 1024,
        'd_model': 256,
        'n_heads': 4,
        'n_layers': 12,
        'n_experts': 8,
    }
    """

    """
    try:
        kwargs['config'] = cfg.models.ar_nar
    except Exception as e:
        pass
    """

    model = AR_NAR(**kwargs).to(device)
    steps = 500

    optimizer = ml.Prodigy(model.parameters(), lr=1.0)
    #optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)

    engine = Engine(model=model, optimizer=optimizer)

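    # optionally warm-start from an existing checkpoint's embeddings: only keys matching `_emb.` are kept and loaded, and those parameters are frozen for the rest of training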
    # copy embeddings if requested
    if cfg.models._embeddings is not None:
        embeddings_path = cfg.relpath / cfg.models._embeddings

        if embeddings_path.exists():
            embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
            if "module" in embeddings:
                embeddings = embeddings["module"]

            frozen_params = set()

            for k in list(embeddings.keys()):
                if re.findall(r'_emb\.', k):
                    frozen_params.add(k)
                else:
                    del embeddings[k]

            engine.module.load_state_dict(embeddings, strict=False)

            for name, param in engine.module.named_parameters():
                if name not in frozen_params:
                    continue
                param.requires_grad_(False)
                engine._frozen_params.add(param)

    # if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
    model.model = ml.replace_linear(model.model)

    torch.save( {
        'module': model.state_dict()
    }, "./data/test.pth" )

    print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

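    # sample() first runs the AR pass to generate the first RVQ level and decodes it to a wav, then runs the NAR pass on top of that output and decodes the combined levels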
    @torch.inference_mode()
    def sample(name, steps=600):
        engine.eval()
        resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95)

        for i, o in enumerate(resps_list):
            _ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)

        resps_list = [ r.unsqueeze(-1) for r in resps_list ]
        resps_list = engine(text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2)

        for i, o in enumerate(resps_list):
            _ = decode_to_file(o, f"data/ar+nar.{i}.{name}.wav", device=device)

        unload_model()

    def train():
        engine.train()
        t = trange(steps)
        for i in t:
            stats = {"step": i}
            stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
            stats |= {"grad_norm": engine.get_global_grad_norm()}

            tqdm.write(f"{stats}")

        torch.save( {
            'module': model.state_dict()
        }, "./data/test.pth" )

    sample("init", 5)
    train()
    sample("final")

if __name__ == "__main__":
    example_usage()