2024-06-09 01:30:15 +00:00
"""
# an AR + NAR model that handles:
* inferencing the primary RVQ level in an autoregressive manner ( AR )
* inferencing the remaining RVQ levels in parallel ( NAR )
This model can fully handle being trained as a unified model ( AR + NAR ) or separate models ( AR | NAR ) .
It's recommended to train as a unified model, then "distill" knowledge of each task separately, just in case.
"""
2023-09-06 23:58:35 +00:00
from . base import Base , list_to_tensor , Categorical
2023-09-07 01:33:16 +00:00
from . . config import cfg
2023-09-06 23:58:35 +00:00
import torch
from torch . nn . utils . rnn import pad_sequence
import random
2023-09-13 18:19:11 +00:00
import math
2023-09-06 23:58:35 +00:00
from einops import rearrange
from torch import Tensor
from tqdm import trange
2024-10-05 03:18:20 +00:00
from time import perf_counter
2024-08-29 18:27:16 +00:00
import logging
_logger = logging . getLogger ( __name__ )
2023-09-06 23:58:35 +00:00
2024-10-18 22:19:52 +00:00
from . . emb . qnt import trim , encode_as_embedding , get_silence
2024-10-05 03:18:20 +00:00
from . . utils import get_devices , setup_logging , timer
2023-10-10 03:03:58 +00:00
2024-06-18 02:45:03 +00:00
from . lora import enable_lora
2024-07-25 00:35:17 +00:00
def clamp(n, lo, hi):
    """Limit ``n`` to ``[lo, hi]`` (upper bound applied first, then lower).

    Equivalent to ``max(lo, min(n, hi))``: if ``lo > hi`` the lower bound wins.
    """
    capped = hi if hi < n else n
    return capped if capped > lo else lo
2023-09-06 23:58:35 +00:00
class AR_NAR(Base):
    """AR + NAR model built on ``Base``.

    ``forward`` picks one of three paths from its inputs:

    * training -- every sample has text and responses, and either
      ``training=True`` is passed or the responses carry exactly
      ``self.n_resp_levels`` RVQ levels;
    * NAR inference -- every sample has text and responses but we are not
      training: the missing RVQ levels are sampled one level per step;
    * AR inference -- some sample is missing its text or response: the missing
      sequence (RVQ level 0 for "tts", text tokens for "stt") is sampled token
      by token.
    """

    def forward(
        self,
        text_list: list[Tensor],
        proms_list: list[Tensor],
        resps_list: list[Tensor] | None = None,

        task_list: list[Tensor] | None = None,
        lang_list: list[Tensor] | None = None,
        tone_list: list[Tensor] | None = None,
        len_list: list[Tensor] | None = None,

        training: bool | None = None,

        max_steps: int = 1000,
        max_levels: int = 0,

        input_prompt_prefix: bool = False,
        prefix_silence: float = 1.0,

        sampling_temperature: float = 1.0,
        sampling_min_temperature: float = -1.0,
        sampling_top_k: int = -100,
        sampling_top_p: float = 1.0,
        sampling_min_p: float = 0.0,
        sampling_repetition_penalty: float = 1.0,
        sampling_repetition_penalty_decay: float = 0.0,
        sampling_length_penalty: float = 0.0,
        sampling_beam_width: int = 0,
        sampling_mirostat_tau: float = 0.0,
        sampling_mirostat_eta: float = 0.1,
        sampling_dry_multiplier: float = 0.0,
        sampling_dry_base: float = 1.75,
        sampling_dry_allowed_length: int = 2,
        sampling_entropix: bool = False,

        disable_tqdm: bool = False,
        use_lora=None,
    ):
        """Train on, or run NAR/AR inference for, a batch.

        Args:
            text_list: per-sample text token tensors. May be ``None`` (or hold
                ``None`` entries) when text is what should be generated ("stt").
            proms_list: per-sample prompts -- a Tensor, a list mixing task names
                and Tensors, or ``None``.
            resps_list: per-sample response codes. On the training path, 2D
                tensors are indexed as ``[time, rvq_level]``. ``None`` requests
                AR generation.
            task_list: per-sample task names. Defaults to "tts" when text is
                given, otherwise "stt".
            lang_list, tone_list, len_list: optional per-sample conditioning,
                forwarded to ``self.inputs`` (``len_list`` only on the AR path).
            training: force (or forbid) the training path. When ``None`` it is
                inferred from whether responses carry ``self.n_resp_levels`` levels.
            max_steps: AR step budget (divided by ``self.causal_size``).
            max_levels: highest RVQ level the NAR fills in; 0 means
                ``self.n_max_levels - 1``.
            input_prompt_prefix: AR only -- seed the generated sequence with the
                prompt's first RVQ level instead of passing it as a prompt.
            prefix_silence: AR only -- duration handed to ``get_silence`` to seed
                the sequence (presumably seconds -- TODO confirm).
            sampling_*: forwarded to ``Base.sample``. The NAR path only uses
                temperature / min_temperature / top_k / top_p / min_p.
            disable_tqdm: hide progress bars.
            use_lora: value passed to ``enable_lora``. When ``None``,
                ``cfg.lora.active_level(level)`` is used instead.

        Returns:
            Training: the result of ``Base.forward``.
            NAR: the response tensors with the sampled levels appended on the
            last dim.
            AR: the generated sequences, with stop tokens pruned and the prefix
            recorded in ``start_slice`` (``<bos>`` / prompt) sliced off.
        """
        text_task = ["stt"]
        if text_list is not None:
            default_task = "tts"
            device = text_list[0].device
            batch_size = len(text_list)
        else:
            default_task = "stt"
            device = resps_list[0].device
            batch_size = len(resps_list)

        # generate task list if not provided
        if task_list is None:
            task_list = [default_task for _ in range(batch_size)]

        # any missing text/resps means something must be generated autoregressively
        has_none = resps_list is None or text_list is None
        if not has_none:
            has_none = any(resps_list[i] is None or text_list[i] is None for i in range(len(task_list)))

        # is training or NAR
        if not has_none:
            n_levels_set = {r.shape[-1] for r in resps_list}
            n_levels = next(iter(n_levels_set))

            if training is None:
                training = n_levels == self.n_resp_levels

            # is training
            if training:
                experimental = self.config.experimental if self.config is not None else None

                # specifies how to sample probabilities of which RVQ levels to train against
                rvq_levels_p = experimental.rvq_levels_p if experimental is not None else "equal"
                # determines which RVQ level to target per batch
                if experimental is not None and experimental.rvq_level_range:
                    quant_level_range = experimental.rvq_level_range
                else:
                    quant_level_range = [0 if self.causal else 1, self.n_resp_levels - 1]
                # rate to perform token dropout errors
                token_dropout_error = experimental.token_dropout_error if experimental is not None else 0.0
                # RVQ levels to apply token dropout on; implicitly all levels when unset
                token_dropout_rvq_levels = experimental.token_dropout_rvq_levels if experimental is not None else None
                if not token_dropout_rvq_levels:
                    token_dropout_rvq_levels = [0, self.n_resp_levels - 1]

                # An explicit non-empty list is used as the sampling pool directly.
                # Otherwise the pool is derived from quant_level_range according to
                # the mode string. Keep the mode string separate so "equal" is still
                # recognized (it was previously overwritten with [] before the check).
                if isinstance(rvq_levels_p, list) and rvq_levels_p:
                    rvq_level_pool = rvq_levels_p
                else:
                    lo, hi = quant_level_range[0], quant_level_range[1] + 1
                    # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
                    if rvq_levels_p == "equal":
                        rvq_level_pool = list(range(lo, hi))
                    else:
                        # level i appears (hi - i) times, so lower levels are drawn more often
                        rvq_level_pool = [i for i in range(lo, hi) for _ in range(hi - i)]

                # input RVQ levels
                quant_levels = [random.choice(rvq_level_pool) for _ in range(batch_size)]
                for i, task in enumerate(task_list):
                    if task in text_task:
                        quant_levels[i] = 0  # self.n_resp_levels - 1

                # trim resps to only contain all levels below the target level
                resps_list = [r if t in text_task else r[..., :l + 1] for r, l, t in zip(resps_list, quant_levels, task_list)]

                # tensor to cat for RVQ level 0 (text stop tokens are not appended here)
                audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
                dropout_lo, dropout_hi = token_dropout_rvq_levels[0], token_dropout_rvq_levels[1]

                for i, (quant_level, resps, proms, task) in enumerate(zip(quant_levels, resps_list, proms_list, task_list)):
                    # Cap quant_level so it never exceeds the levels available in the
                    # resp or in any prom. The local is updated too, so the dropout and
                    # stop-token logic below see the capped value.
                    quant_level = min(quant_level, resps.shape[-1] - 1)
                    # proms could be a Tensor, list[Tensor], or None
                    if isinstance(proms, torch.Tensor):
                        quant_level = min(quant_level, proms.shape[-1] - 1)
                    elif isinstance(proms, list):
                        for prom in proms:
                            if isinstance(prom, torch.Tensor):
                                quant_level = min(quant_level, prom.shape[-1] - 1)
                    quant_levels[i] = quant_level

                    # apply token dropout error compensation: randomly nudge tokens of the
                    # levels below the target by +-1
                    if token_dropout_error > 0 and quant_level > 0 and dropout_lo <= quant_level <= dropout_hi:
                        # clone first: resps is a view of the caller's tensor, and editing it
                        # in place would corrupt the source data for later steps
                        resps = resps.clone()
                        lower = resps[:, :quant_level]
                        dropped = torch.rand(lower.shape, device=lower.device) < token_dropout_error
                        offsets = torch.randint(0, 2, lower.shape, device=lower.device) * 2 - 1
                        # 1..1022 matches the previous per-token clamp bounds
                        perturbed = (lower.to(torch.int64) + offsets).clamp(1, 1022).to(lower.dtype)
                        resps[:, :quant_level] = torch.where(dropped, perturbed, lower)

                    # only apply stop token for RVQ level 0; append stop tokens for AR
                    # (text stop tokens for STT are not appended here)
                    if quant_level <= 0 and task not in text_task:
                        resps = torch.cat([resps, audio_stop_sequence])

                    resps_list[i] = resps

                inputs = self.inputs(
                    text_list=text_list,
                    proms_list=proms_list,
                    resps_list=resps_list,
                    lang_list=lang_list,
                    tone_list=tone_list,
                    task_list=task_list,
                    quant_levels=quant_levels,
                )

                return super().forward(
                    inputs=inputs,
                    quant_levels=quant_levels,  # could technically just grab this from the above inputs since they're included as an RVQ level token
                )

            # is NAR
            if max_levels == 0:
                max_levels = self.n_max_levels - 1

            # expand if given a raw 1D tensor (new list, so the caller's list is left untouched)
            prev_list = [resp.unsqueeze(-1) if resp.dim() == 1 else resp for resp in resps_list]

            for _ in trange(max_levels, desc="NAR", disable=disable_tqdm):
                level = prev_list[0].shape[-1]
                if level >= max_levels + 1:  # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
                    break

                if cfg.lora is not None:
                    enable_lora(self, cfg.lora.active_level(level) if use_lora is None else use_lora)

                quant_levels = [level for _ in range(batch_size)]  # torch.full((len(text_list),), level)

                inputs = self.inputs(
                    text_list=text_list,
                    proms_list=proms_list,
                    resps_list=prev_list,
                    lang_list=lang_list,
                    tone_list=tone_list,
                    quant_levels=quant_levels,
                )

                output = super().forward(
                    inputs=inputs,
                    quant_levels=quant_levels,
                )
                logits = output.logits

                sampled = super().sample(
                    logits=logits,
                    prev_list=prev_list,
                    quant_levels=quant_levels,

                    temperature=sampling_temperature,
                    min_temperature=sampling_min_temperature,
                    top_p=sampling_top_p,
                    top_k=sampling_top_k,
                    min_p=sampling_min_p,
                    #repetition_penalty=sampling_repetition_penalty,
                    #repetition_penalty_decay=sampling_repetition_penalty_decay,
                    #length_penalty=sampling_length_penalty,
                    #beam_width=sampling_beam_width,
                    #mirostat=mirostat,
                )

                resps_list = sampled[0]
                # append the newly sampled level as a new last-dim column
                prev_list = [torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list)]

            return prev_list

        # is AR
        if cfg.lora is not None:
            enable_lora(self, cfg.lora.active_level(0) if use_lora is None else use_lora)

        # STT
        start_slice = [0 for _ in range(batch_size)]
        sequence_list = [torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size)]
        stopped = torch.zeros(batch_size, device=device).bool()

        audio_stop_token = self.stop_token
        text_stop_token = 2

        state = None
        # one independent dict per batch entry; `[{...}] * batch_size` would alias a single dict
        mirostat = [
            {"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
            for _ in range(batch_size)
        ] if sampling_mirostat_tau > 0.0 else None

        scores = [1.0] * sampling_beam_width
        entropies = []

        # ick
        low_temperature = sampling_repetition_penalty == 1.0 and sampling_temperature < 0.5
        low_temperature_range = cfg.dataset.frames_per_second * 3

        original_sampling_temperature = sampling_temperature
        original_sampling_repetition_penalty = sampling_repetition_penalty
        original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay

        for i in range(batch_size):
            # add <bos> to text for STT
            if task_list[i] in text_task:
                start_slice[i] = 1
                sequence_list[i] = torch.cat([sequence_list[i], torch.tensor([1], dtype=torch.int16, device=device)])
            # treat input prompt as initial resp (by prefixing with the prompt instead)
            elif input_prompt_prefix:
                start_slice[i] = proms_list[i].shape[0]
                sequence_list[i], proms_list[i] = proms_list[i][:, 0], sequence_list[i]
            elif prefix_silence > 0:
                sequence_list[i] = get_silence(prefix_silence, device=sequence_list[i].device)
                sequence_list[i] = sequence_list[i][:, 0]
                # start_slice[i] = sequence_list[i].shape[0]

        # get next in sequence
        for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
            # it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
            text_list = [sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list)]
            resps_list = [sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list)]

            # greedy sampling in the AR *does* work, but requires some quasi-exotic sampling to work around the initial burst of garbage from polluting the rest of the sequence
            # naturally, rep pen wrangles this initial burst of noise, but naively relying on rep_pen is no good, as it fails after ~6 seconds of audio
            # however, switching to a default sampling temperature with "clean greedy sampled codes" will make the rest of sequence sound as if it were greedy sampled
            # to-do: tune these values, maybe have it factor based on confidence scores or something
            if low_temperature:
                enabled = n < low_temperature_range
                sampling_repetition_penalty = 1.35 if enabled else original_sampling_repetition_penalty
                sampling_repetition_penalty_decay = 0.5 if enabled else original_sampling_repetition_penalty_decay
                sampling_temperature = original_sampling_temperature if enabled else 1.0

            inputs = self.inputs(
                text_list=text_list,
                proms_list=proms_list,
                resps_list=resps_list,
                lang_list=lang_list,
                tone_list=tone_list,
                len_list=len_list,
                task_list=task_list,
                quant_levels=[0 for _ in range(max(batch_size, sampling_beam_width))],
            )

            # to-do: find an elegant way to write this
            output = super().forward(
                inputs=inputs,
                state=state,
                output_attentions=sampling_entropix,
            )
            logits, state = output.logits, output.state

            sampled = super().sample(
                logits=logits,
                prev_list=None if sampling_repetition_penalty == 1.0 and sampling_length_penalty == 0.0 else [resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate(task_list)],

                temperature=sampling_temperature,
                min_temperature=sampling_min_temperature,
                top_p=sampling_top_p,
                top_k=sampling_top_k,
                min_p=sampling_min_p,
                repetition_penalty=sampling_repetition_penalty,
                repetition_penalty_decay=sampling_repetition_penalty_decay,
                length_penalty=sampling_length_penalty,
                beam_width=sampling_beam_width,

                mirostat=mirostat,

                dry_multiplier=sampling_dry_multiplier,
                dry_base=sampling_dry_base,
                dry_allowed_length=sampling_dry_allowed_length,

                attentions=output.attentions if sampling_entropix else None,
            )

            r = sampled[0]

            if sampled.entropy:
                entropies.append(sampled.entropy)

            if mirostat is not None:
                mirostat = sampled.scores
            elif sampling_beam_width > 0:
                # expand tuple
                beam_scores = sampled.scores
                # first step, expand batch: every per-sample list must grow together,
                # otherwise later `task_list[i]` / `resps_list[i]` lookups go out of range
                if batch_size == 1:
                    batch_size = sampling_beam_width
                    text_list = text_list * sampling_beam_width
                    proms_list = proms_list * sampling_beam_width
                    resps_list = resps_list * sampling_beam_width
                    sequence_list = sequence_list * sampling_beam_width
                    task_list = task_list * sampling_beam_width
                    start_slice = start_slice * sampling_beam_width
                    lang_list = lang_list * sampling_beam_width if isinstance(lang_list, list) else lang_list
                    tone_list = tone_list * sampling_beam_width if isinstance(tone_list, list) else tone_list
                    len_list = len_list * sampling_beam_width if isinstance(len_list, list) else len_list
                    stopped = torch.zeros(batch_size, device=device).bool()

                # accumulate the new step's scores onto the running totals
                scores = [prev + score for prev, score in zip(scores, beam_scores)]

            # append tokens
            for i, ri in enumerate(r):
                task = task_list[i]
                stop_token = audio_stop_token if task not in text_task else text_stop_token
                if stop_token in ri:
                    stopped[i] = True
                sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])

            # stop token found
            # stopped |= r == stop_token
            if stopped.all().item():
                break

        if entropies:
            from ..plot import plot_entropies
            plot_entropies(entropies)

        # pick the best scoring candidate
        # desu this is always going to be candidate 0
        if sampling_beam_width:
            sequence_list = [sequence_list[0]]

        # remove stop token
        sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
        # remove <bos> / prompt prefix (iterate the sequences themselves: after beam
        # search they can be shorter than task_list)
        sequence_list = [seq[start_slice[i]:] for i, seq in enumerate(sequence_list)]

        return sequence_list
2023-09-06 23:58:35 +00:00
def example_usage():
    """Demo: train a small AR_NAR on one bundled utterance, then sample from it.

    Reads ``./data/qnt.{dac,enc}`` and ``./data/noise.{dac,enc}`` (numpy pickles
    holding a ``"codes"`` array), runs on ``"cuda"``, and writes decoded ``.wav``
    files under ``data/``.
    """
    cfg.trainer.backend = "local"
    cfg.hyperparameters.gradient_accumulation_steps = 1
    if cfg.audio_backend == "dac":
        cfg.sample_rate = 44_100

    from functools import partial
    from einops import repeat
    from tqdm import tqdm

    from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
    from ..engines import Engine, Engines
    from ..utils import wrapper as ml
    from ..utils import setup_logging

    import numpy as np
    import re

    setup_logging()

    device = "cuda"

    # mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
    """
    if "mamba" in cfg.model.arch_type:
        cfg.model.resp_levels = 1
    """
    # cfg.model.loss_factors = {}

    def tokenize(content):
        return torch.tensor(cfg.tokenizer.encode(content))

    def _load_quants(path) -> Tensor:
        # NOTE: allow_pickle=True executes pickled data -- only use on trusted, local demo files
        qnt = np.load(path, allow_pickle=True)[()]
        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)

    qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
    noise = _load_quants(f"./data/noise.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")

    text_list = [
        tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
        #tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
    ]
    proms_list = [
        qnt[:cfg.dataset.frames_per_second, :].to(device),
        #qnt[:cfg.dataset.frames_per_second, :].to(device),
    ]
    resps_list = [
        qnt[:, :].to(device),
        #qnt[:cfg.dataset.frames_per_second, :].to(device),
    ]

    text_list = text_list[:1]
    proms_list = proms_list[:1]
    resps_list = resps_list[:1]

    batch_size = len(text_list)

    # rentet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
    kwargs = {
        'n_text_tokens': 256,
        'n_audio_tokens': 1024,

        'd_model': 1024, # 256, # 1024, # 1536
        'n_heads': 16, # 4, # 16, # 24
        'n_layers': 12, # 32
        'n_experts': 1 if not cfg.model else cfg.model.experts,

        'p_dropout': 0.1,

        'l_padding': 8 if cfg.optimizations.fp8 else 0,

        'config': cfg.model
    }

    """
    try:
        kwargs['config'] = cfg.model
    except Exception as e:
        pass
    """

    bos_id, space_id, eos_id = cfg.tokenizer.encode(" ")
    #available_tasks = cfg.dataset.tasks_list
    available_tasks = ["tts", "stt"]

    model = AR_NAR(**kwargs).to(device)
    steps = 150 * len(available_tasks) # * cfg.model.experimental.causal_size

    # Swap layers BEFORE building the optimizer: the optimizer captures
    # model.parameters() at construction, so replacing modules afterwards could
    # leave it holding stale parameters.
    if cfg.optimizations.replace and cfg.optimizations.linear:
        model = ml.replace_linear(model)

    if cfg.optimizations.replace and cfg.optimizations.embedding:
        model = ml.replace_embedding(model)

    optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
    scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
    learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None

    if cfg.optimizations.dadaptation:
        # do not combine the two
        if scheduler == "schedulefree":
            scheduler = ""

        learning_rate = 1.0

    if optimizer == "prodigy":
        if learning_rate is None:
            learning_rate = 1.0
        optimizer = ml.Prodigy
    elif optimizer == "adagrad":
        if learning_rate is None:
            learning_rate = 1.0e-2
        optimizer = ml.Adagrad
    elif optimizer == "adamw":
        if learning_rate is None:
            learning_rate = 1.0e-4
        optimizer = ml.AdamW
    elif optimizer in ("sgd", "sdg"):  # "sdg" kept for configs written against the old typo
        if learning_rate is None:
            learning_rate = 1.0e-4
        optimizer = ml.SGD
    else:
        raise ValueError(f"Unrecognized optimizer: {optimizer}")

    _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")

    optimizer = optimizer(model.parameters(), lr=learning_rate)

    # only "schedulefree" maps to a scheduler class; any other string is ignored,
    # so a plain string is never called as a constructor
    if scheduler == "schedulefree":
        if isinstance(optimizer, ml.AdamW):
            scheduler = ml.schedulefree.AdamWScheduleFree
        elif isinstance(optimizer, ml.SGD):
            scheduler = ml.schedulefree.SGDScheduleFree
        else:
            scheduler = None

        if scheduler is not None:
            _logger.info(f"Scheduler: {scheduler}")
            optimizer = scheduler(model.parameters(), lr=learning_rate)

    """
    cfg.optimizations.model_offloading = {
        "devices": ["cuda:0", "cpu"],
    #	"limits": [ 0.9, -1 ],
        "assign": [[ f'layers.{i}.' for i in range(0,10) ], [ f'layers.{i}.' for i in range(11,12) ] + [ "model.norm" ]],
    #	"limits": [ 256 * (1024 ** 2), -1 ]
    }
    """

    engine = Engine(model=model, optimizer=optimizer)
    engines = Engines({"ar+nar": engine})
    engines.setup()

    """
    if cfg.optimizations.model_offloading:
        model = ml.offload_model( model, policy=cfg.optimizations.model_offloading )
    """

    """
    torch.save( {
        'module': model.state_dict()
    }, f"./data/{cfg.model.arch_type}.pth" )
    """

    _logger.info(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    @torch.no_grad()
    def sample_data(t=None):
        """Build (texts, proms, resps, tasks) for one batch.

        ``t`` may be a list of tasks (one sample each), a single task name, or
        ``None`` to pick a random task per sample.
        """
        if isinstance(t, list):
            tasks = t
            texts = [text_list[0].to(device) if task != "stt" else None for task in tasks]
            proms = [proms_list[0].to(device) if task != "stt" else ["stt"] for task in tasks]
            resps = [None if task != "stt" else resps_list[0].to(device) for task in tasks]

            return texts, proms, resps, tasks

        texts = []
        proms = []
        resps = []
        tasks = []

        for i in range(batch_size):
            task = random.choice(available_tasks) if t is None else t

            text = text_list[i].to(device)
            prom = proms_list[i].to(device)
            resp = resps_list[i].to(device)

            # do nothing
            if task == "tts":
                ...
            elif task == "stt":
                prom = [
                    task
                ]
            # to-do: reimplement this from data.py
            """
            elif task == "tts-c":
                trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)

                prom = resp[:trim_length]
                resp = resp[trim_length:]

                prom = prom.to(device)
            elif task == "ns" or task == "sr":
                # extend the noise to fill the target audio
                noise_ext = repeat_extend_audio( noise, resp.shape[0] )
                # create the input prompt by merging the target audio with the noise
                prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )

                prom = prom.to(device)

                # set the target to just be the noise if <sr>
                if task == "sr":
                    resp = noise_ext

                # set the text prompt to empty to train without a guided text prompt
                if random.random() < 0.5:
                    text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8)

                prom = [
                    task,
                    prom,
                ]
            """

            texts.append(text)
            proms.append(prom)
            resps.append(resp)
            tasks.append(task)

        return texts, proms, resps, tasks

    @torch.inference_mode()
    def sample(name, steps=500, task=None):
        """Run AR (and then NAR, if enabled) inference and write one wav per audio output."""
        engine.eval()

        texts, proms, resps, tasks = sample_data(task)

        if "ar" in cfg.model.capabilities:
            output = engine(texts, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95)

            text = [cfg.tokenizer.decode(output[i]) for i, task in enumerate(tasks) if task == "stt"]

            # keep only the audio-producing samples for the NAR / decoding stage
            texts = [texts[i] for i, task in enumerate(tasks) if task != "stt"]
            proms = [proms[i] for i, task in enumerate(tasks) if task != "stt"]
            resps = [output[i] for i, task in enumerate(tasks) if task != "stt"]
            tasks = [tasks[i] for i, task in enumerate(tasks) if task != "stt"]

            print("STT:", text)
        else:
            resps = [resp[:, 0] for resp in resps]

        if "nar" in cfg.model.capabilities:
            resps = engine(texts, proms, resps, task_list=tasks, sampling_temperature=0.2)

        for i, o in enumerate(resps):
            # name the file after this sample's task (the `task` argument may be a whole list)
            _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{tasks[i]}.{name}.wav", device=device)

        unload_model()

    def train():
        """Run `steps` training steps on freshly sampled batches, printing stats."""
        engine.train()
        t = trange(steps)
        for i in t:
            texts, proms, resps, tasks = sample_data()

            stats = {"step": i}
            stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps, task_list=tasks)
            stats |= {"grad_norm": engine.get_global_grad_norm()}

            tqdm.write(f"{stats}")

        """
        torch.save( {
            'module': model.state_dict()
        }, f"./data/{cfg.model.arch_type}.pth" )
        """

    #sample("init", 5)
    train()

    """
    if cfg.optimizations.compile:
        model = ml.compile_model(model, backend=cfg.optimizations.compile)
    """

    """
    for task in available_tasks:
        sample("final", task=task)
    """
    sample("final", task=available_tasks)

    engines.quit()
2023-09-06 23:58:35 +00:00
# Script entry point: running this module directly runs the example_usage() demo.
if __name__ == "__main__":
    example_usage()