2024-06-09 01:30:15 +00:00
"""
# an AR + NAR model that handles:
* inferencing the primary RVQ level in an autoregressive manner ( AR )
* inferencing the remaining RVQ levels in parallel ( NAR )
This model can fully handle being trained as a unified model ( AR + NAR ) or separate models ( AR | NAR ) .
It's recommended to train as a unified model, then "distill" the knowledge of each task separately, just in case.
"""
2023-09-06 23:58:35 +00:00
from . base import Base , list_to_tensor , Categorical
2023-09-07 01:33:16 +00:00
from . . config import cfg
2023-09-06 23:58:35 +00:00
import torch
from torch . nn . utils . rnn import pad_sequence
import random
2023-09-13 18:19:11 +00:00
import math
2024-11-02 16:49:05 +00:00
import time
2023-09-06 23:58:35 +00:00
from einops import rearrange
from torch import Tensor
2024-11-12 02:21:16 +00:00
from tqdm import trange , tqdm
2024-10-05 03:18:20 +00:00
2024-08-29 18:27:16 +00:00
import logging
_logger = logging . getLogger ( __name__ )
2023-09-06 23:58:35 +00:00
2024-10-18 22:19:52 +00:00
from . . emb . qnt import trim , encode_as_embedding , get_silence
2024-11-12 02:21:16 +00:00
from . . utils import get_devices , setup_logging , timer , clamp , convert_kwargs
2023-10-10 03:03:58 +00:00
2024-06-18 02:45:03 +00:00
from . lora import enable_lora
2024-11-19 16:30:05 +00:00
from . . samplers import cfg_logits
2024-06-18 02:45:03 +00:00
2024-11-10 18:19:48 +00:00
# tasks that the methods below treat as text-output tasks: forward_train pins them to RVQ level 0
# and leaves their resps untrimmed (presumably because the generated sequence is text, not audio codes)
text_task = [ " stt " ]
2023-09-06 23:58:35 +00:00
class AR_NAR ( Base ) :
2024-12-06 05:05:52 +00:00
# yikes
def forward_super(self, *args, **kwargs):
    """Call the parent class' ``forward`` directly, passing every argument through untouched."""
    parent_forward = super().forward
    return parent_forward(*args, **kwargs)
2024-11-13 04:30:09 +00:00
# parse inputs for training
# a lot of this could be delegated back to the dataloader, but it's just easier to keep the task of the dataloader to provide sufficient data, and the model to process the data for training
2024-11-10 18:19:48 +00:00
def forward_train(
    self,
    text_list: list[Tensor],
    proms_list: list[Tensor],
    resps_list: list[Tensor],
    task_list: list[Tensor] | None = None,
    lang_list: list[Tensor] | None = None,
    tone_list: list[Tensor] | None = None,
    len_list: list[Tensor] | None = None,
):
    """Prepare a training batch and run it through the parent ``forward``.

    Per sample this picks a target RVQ level, optionally marks the sample for
    masked (timestep-conditioned) training, trims the responses to the levels
    at or below the target, applies token-dropout noise, appends the stop
    token for plain level-0 samples, and applies CFG dropout of the text
    and/or prompt.

    Args:
        text_list: per-sample text tokens, or ``None`` (the batch size is then
            taken from ``resps_list`` and the default task becomes ``"stt"``).
        proms_list: per-sample prompts; an entry may be a Tensor, a list
            (non-Tensor items are skipped when capping levels), or ``None``.
        resps_list: per-sample responses; the last dimension is used as the
            RVQ-level axis.
        task_list: per-sample task names; defaults to the deduced task for
            every sample when omitted.
        lang_list, tone_list: forwarded untouched to ``self.inputs``.
        len_list: accepted for signature parity with the other ``forward_*``
            methods; not used here.

    Returns:
        Whatever the parent class' ``forward`` returns for the built inputs.
    """
    # deduce batch_size
    if text_list is not None:
        default_task = "tts"
        device = text_list[0].device
        batch_size = len(text_list)
    else:
        default_task = "stt"
        device = resps_list[0].device
        batch_size = len(resps_list)

    # the signature allows omitting the tasks, so fall back to the deduced one
    if task_list is None:
        task_list = [default_task for _ in range(batch_size)]

    # CFG dropout below replaces entries; work on shallow copies so the caller's lists stay intact
    if text_list is not None:
        text_list = list(text_list)
    proms_list = list(proms_list)

    experimental = self.config.experimental if self.config is not None else None

    # specifies how to sample probabilities of which RVQ levels to train against
    rvq_levels_p = experimental.rvq_levels_p if experimental is not None else "equal"
    # determines which RVQ level to target per batch
    quant_level_range = (
        experimental.rvq_level_range
        if experimental is not None and experimental.rvq_level_range
        else [0 if self.causal else 1, self.n_resp_levels - 1]
    )
    # rate to perform token dropout errors
    token_dropout_error = experimental.token_dropout_error if experimental is not None else 0.0
    # RVQ levels to apply token dropout on
    token_dropout_rvq_levels = experimental.token_dropout_rvq_levels if experimental is not None else None
    # RVQ levels to apply masking training on
    masking_train_rvq_levels = experimental.masking_train_rvq_levels if experimental is not None else None
    # CFG
    # NOTE(review): the config also exposes cfg_text_dropout_p, but no text-only dropout is applied in this method
    cfg_cond_dropout_p = experimental.cfg_cond_dropout_p if experimental is not None else 0.0
    cfg_prom_dropout_p = experimental.cfg_prom_dropout_p if experimental is not None else 0.0
    # rate to train RVQ level AR-ly or NAR-ly
    masking_train_p = experimental.masking_train_p if experimental is not None else 0.5
    masking_ratio = experimental.masking_ratio if experimental is not None else "random"

    # force set mask training
    if "len" not in self.capabilities:
        masking_train_p = 0.0
    elif "ar" not in self.capabilities:
        masking_train_p = 1.0

    # implicitly set token dropout to all levels, and masked training to level 0 only
    if not token_dropout_rvq_levels:
        token_dropout_rvq_levels = [0, self.n_resp_levels - 1]
    if not masking_train_rvq_levels:
        masking_train_rvq_levels = [0, 0]

    # allow passing a specific distribution of RVQ levels;
    # anything that is not a non-empty list is treated as a named strategy
    if not isinstance(rvq_levels_p, list) or not rvq_levels_p:
        lo, hi = quant_level_range[0], quant_level_range[1] + 1
        # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
        if rvq_levels_p == "equal":
            rvq_levels_p = list(range(lo, hi))
        else:
            # weight lower levels more: level i appears (hi - i) times
            rvq_levels_p = [i for i in range(lo, hi) for _ in range(hi - i)]

    # input RVQ levels
    quant_levels = [random.choice(rvq_levels_p) for _ in range(batch_size)]
    # timestep levels (for TTS NAR)
    timesteps = [None for _ in range(batch_size)]

    mask_lo, mask_hi = masking_train_rvq_levels[0], masking_train_rvq_levels[1]
    for i, task in enumerate(task_list):
        if task in text_task:
            quant_levels[i] = 0  # self.n_resp_levels - 1
        elif mask_lo <= quant_levels[i] <= mask_hi and random.random() < masking_train_p:
            # to-do: prioritize lower timesteps over later timesteps
            # ...except that the masking rate is still tied to the cosine scheduling, which does this already
            timesteps[i] = random.random()
            # instead make it between [0.2, 0.8]
            if masking_ratio == "rand":
                timesteps[i] = (timesteps[i] * 0.6) + 0.2

    # trim resps to only contain all levels below the target level
    resps_list = [
        r if t in text_task else r[..., : l + 1]
        for r, l, t in zip(resps_list, quant_levels, task_list)
    ]

    # tensors to cat for RVQ level 0
    text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
    audio_stop_sequence = torch.tensor([[self.stop_token]], device=device, dtype=torch.int16)

    # final validations and stuff
    for i, resps, proms, task in zip(range(batch_size), resps_list, proms_list, task_list):
        quant_level = quant_levels[i]

        # cap quant_level if it exceeds its corresponding resp/prom;
        # the local is updated as well so the smallest cap wins and the steps below see the capped level
        if quant_level >= resps.shape[-1]:
            quant_level = resps.shape[-1] - 1

        # proms could be a Tensor, list[Tensor], or None
        if isinstance(proms, torch.Tensor):
            if quant_level >= proms.shape[-1]:
                quant_level = proms.shape[-1] - 1
        elif isinstance(proms, list):
            for prom in proms:
                if not isinstance(prom, torch.Tensor):
                    continue
                if quant_level >= prom.shape[-1]:
                    quant_level = prom.shape[-1] - 1

        quant_levels[i] = quant_level

        # apply token dropout error compensation
        if token_dropout_error > 0 and token_dropout_rvq_levels[0] <= quant_level <= token_dropout_rvq_levels[1]:
            # the trimmed resps may be a view of the caller's tensor; clone before writing into it
            resps = resps.clone()
            resps_list[i] = resps
            steps = resps.shape[0]
            for l in range(quant_level):
                for t in range(steps):
                    if random.random() < token_dropout_error:
                        offset = 1 if random.random() < 0.5 else -1
                        resps[t, l] = clamp(resps[t, l].item() + offset, 1, 1022)  # +- 1

        # only apply stop token for RVQ level 0 when not doing masked training
        if quant_level <= 0 and timesteps[i] is None:
            # append stop tokens for AR (text tasks are left untouched)
            if task not in text_task:
                resps_list[i] = torch.cat([resps, audio_stop_sequence])

        if task == "len":
            quant_levels[i] = 0

        # apply CFG (should probably only apply to NAR quant level 0)
        if task not in text_task + ["len"]:
            drop_text = False
            drop_audio = False

            if random.random() < cfg_prom_dropout_p:
                drop_audio = True

            if random.random() < cfg_cond_dropout_p:
                drop_audio = True
                drop_text = True

            if drop_text:
                text_list[i] = text_start_stop_sequence

            if drop_audio:
                proms_list[i] = None

    inputs = self.inputs(
        text_list=text_list,
        proms_list=proms_list,
        resps_list=resps_list,
        lang_list=lang_list,
        tone_list=tone_list,
        task_list=task_list,
        time_list=timesteps,
        quant_levels=quant_levels,
    )

    return super().forward(
        inputs=inputs,
        quant_levels=quant_levels,
    )
2024-11-12 02:21:16 +00:00
def forward_nar_masked(
    self,

    text_list: list[Tensor],
    proms_list: list[Tensor],
    resps_list: list[Tensor] | None = None,

    task_list: list[Tensor] | None = None,
    lang_list: list[Tensor] | None = None,
    tone_list: list[Tensor] | None = None,
    len_list: list[Tensor] | None = None,

    disable_tqdm=False,
    use_lora=None,
    **sampling_kwargs,
):
    """Iteratively demask RVQ level 0 for sequences of known length.

    Starting from fully masked sequences (or from ``resps_list`` when
    ``denoise_start`` > 0), each step masks the worst-scoring positions per a
    cosine schedule, runs the model (plus a null-conditioned pass for CFG),
    samples the masked positions, and rescores them for the next step.

    Args:
        text_list: per-sample text tokens (must not be ``None``; the device
            and batch size are read from it).
        proms_list: per-sample prompts, forwarded to ``inputs``.
        resps_list: existing sequences to denoise from; only used when
            ``denoise_start`` > 0.
        task_list: accepted for signature parity; not used here.
        lang_list, tone_list: forwarded untouched to ``inputs``.
        len_list: target length per sample (required); clamped to
            ``[min_duration, max_duration]``.
        disable_tqdm: hides the progress bar.
        use_lora: passed to ``enable_lora`` instead of
            ``cfg.lora.active_level(0)`` when not ``None``.
        **sampling_kwargs: sampler settings; the keys read here are
            ``min_duration``, ``max_duration``, ``max_steps``,
            ``annealed_sampling``, ``temperature``, ``cfg_strength``,
            ``cfg_rescale``, ``denoise_start``, ``denoise_end``, ``remasking``,
            ``minimum_cfg_strength`` and ``prefix_context``; the remainder is
            forwarded to ``sample``.

    Returns:
        A list with one 1D tensor of sampled level-0 tokens per sample.
    """
    device = text_list[0].device
    batch_size = len(text_list)

    # special "scheduling" to inference RVQ-level 0
    level = 0
    if cfg.lora is not None:
        enable_lora(self, cfg.lora.active_level(level) if use_lora is None else use_lora)

    # to-do: check if gumbel sampling works / helps (helpers are currently unused)
    def log(t, eps=1e-10):
        return torch.log(t + eps)

    def gumbel_noise(t):
        noise = torch.zeros_like(t).uniform_(0, 1)
        return -log(-log(noise))

    def gumbel_sample(t, temperature=1.0, dim=-1):
        return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)

    # convert (N)AR specific args
    sampling_kwargs = convert_kwargs(sampling_kwargs, "ar_")

    min_length = sampling_kwargs.pop("min_duration", 1)
    max_length = sampling_kwargs.pop("max_duration", 500)
    max_steps = sampling_kwargs.get("max_steps", 25)
    annealed_sampling = sampling_kwargs.get("annealed_sampling", True)

    # greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
    temperature = sampling_kwargs.pop("temperature", 0.0)
    # this really helps keep audio coherent so far
    cfg_strength = sampling_kwargs.get("cfg_strength", 2.0)
    cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
    start_noise = sampling_kwargs.get("denoise_start", 0.0)
    end_noise = sampling_kwargs.get("denoise_end", 1.0)
    remasking = sampling_kwargs.get("remasking", True)
    max_steps = math.floor(max_steps * (end_noise - start_noise))

    len_list = [clamp(l, min_length, max_length) for l in len_list]

    # force set CFG because too low / no CFG causes issues
    minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 3.0)
    original_cfg_strength = cfg_strength
    cfg_strength = max(cfg_strength, minimum_cfg_strength)

    prefix_context = sampling_kwargs.get("prefix_context", None)
    # we can get away with just providing a list of resps to prefix later, and it will magically get removed anyways when masking and scoring
    if prefix_context is not None:
        # drops the last token of the prefix text and the first token of the text when joining them
        text_list = [torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip(prefix_context[0], text_list)]
        prefix_resps_list = [resps if resps.dim() == 1 else resps[:, 0] for resps in prefix_context[1]]

    # if we're denoising from an existing sequence
    if start_noise > 0.0 and resps_list is not None:
        # flatten if needed
        resps_list = [resps if resps.dim() == 1 else resps[:, 0] for resps in resps_list]
        # gen masking ratio
        noise_p = math.cos(start_noise * math.pi * 0.5)
        # generate scoring mask (because the above mask will get masked off per the scores, so we do not need to mask beforehand)
        scores = [
            torch.tensor(
                [1.0 if random.random() < noise_p else 0.0 for _ in range(seq_len)],
                dtype=torch.float32,
                device=device,
            )
            for seq_len in len_list
        ]
    else:
        # fill with masked tokens (even though they get masked anyways)
        resps_list = [torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list]
        # fill scores
        scores = [torch.ones((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list]

    quant_levels = [level for _ in range(batch_size)]
    null_text = [torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size)]
    null_prom = [None for _ in range(batch_size)]

    iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm)
    for timestep in iterator:
        # update previous list of tokens
        prev_list = resps_list
        # ramp down over time
        annealing = 1.0 - timestep
        # get noise level, per cosine scheduling
        noise_p = math.cos(timestep * math.pi * 0.5)
        # proportion of tokens to remask
        remask_p = 1.0 / max_steps if remasking else 0
        # pick the worst scoring tokens to mask off
        masked_indices = [
            score.topk(clamp(int(noise_p * seq_len + remask_p * seq_len), 1, seq_len), dim=-1).indices
            for score, seq_len in zip(scores, len_list)
        ]
        # mask off inputs
        resps_list = [resp.scatter(0, indices, self.stop_token) for resp, indices in zip(resps_list, masked_indices)]
        # boolean mask
        is_masked = [resps == self.stop_token for resps in resps_list]
        # timestep inputs
        time_list = [timestep for _ in range(batch_size)]

        sampling_temperature = temperature * annealing if annealed_sampling else temperature
        # ramp CFG up over time when annealing; otherwise use the configured strength as-is
        sampling_cfg = float(cfg_strength * timestep) if annealed_sampling else cfg_strength
        # avoid useless CFG sampling
        if sampling_cfg < minimum_cfg_strength * 0.5:
            sampling_cfg = 0

        if prefix_context is not None:
            input_resps_list = [torch.concat([prefix, resps]) for prefix, resps in zip(prefix_resps_list, resps_list)]
            # originally requested no CFG, safe to ignore if we have a prefix
            if original_cfg_strength == 0:
                sampling_cfg = 0
        else:
            input_resps_list = resps_list

        # setup inputs
        inputs = super().inputs(
            text_list=text_list,
            proms_list=proms_list,
            resps_list=input_resps_list,
            lang_list=lang_list,
            tone_list=tone_list,
            time_list=time_list,
            quant_levels=quant_levels,
        )
        output = super().forward(
            inputs=inputs,
            quant_levels=quant_levels,
            #layer_skip_variables=sampling_layer_skip_variables,
        )

        logits = output.logits

        # the per-step strength gates the null pass, so a zeroed strength skips its cost entirely
        if sampling_cfg > 0:
            null_inputs = super().inputs(
                text_list=null_text,
                proms_list=null_prom,
                resps_list=input_resps_list,
                lang_list=lang_list,
                tone_list=tone_list,
                time_list=time_list,
                quant_levels=quant_levels,
            )
            null_output = super().forward(
                inputs=null_inputs,
                quant_levels=quant_levels,
                #layer_skip_variables=sampling_layer_skip_variables,
            )

            logits = cfg_logits(
                logits=output.logits,
                null=null_output.logits,
                strength=sampling_cfg,
                rescale=cfg_rescale,
                lens=list(len_list),
            )

        # sample with sampler settings
        filtered_sampled = super().sample(
            logits=logits,
            prev_list=prev_list,
            quant_levels=quant_levels,
            temperature=sampling_temperature,
            **sampling_kwargs,
        )

        # retrieves unfiltered logits
        unfiltered_sampled = super().sample(
            logits=logits,
            prev_list=prev_list,
            quant_levels=quant_levels,
            temperature=0.0,
            **sampling_kwargs,
        )

        # get sampled tokens
        sampled_ids = filtered_sampled.ids
        # keep unmasked tokens
        resps_list = [
            torch.where(masked, input_ids, resps)
            for masked, input_ids, resps in zip(is_masked, sampled_ids, resps_list)
        ]
        # get probability scores
        scores = [
            # conjugate to have worse scoring tokens picked for topk
            1.0 -
                # only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
                torch.where(
                    masked,
                    torch.tensor(list(seq_scores), device=device),
                    torch.ones(masked.shape, device=device),
                )
            # use unmodified logit scores for this, as it offers better stability
            for seq_scores, masked in zip(unfiltered_sampled.scores, is_masked)
        ]

    return resps_list
def forward_nar(
    self,

    text_list: list[Tensor],
    proms_list: list[Tensor],
    resps_list: list[Tensor] | None = None,

    task_list: list[Tensor] | None = None,
    lang_list: list[Tensor] | None = None,
    tone_list: list[Tensor] | None = None,
    len_list: list[Tensor] | None = None,

    disable_tqdm=False,
    use_lora=None,
    **sampling_kwargs,
):
    """Generate the remaining RVQ levels, one level per pass.

    When ``len_list`` is given, level 0 is first produced by
    ``forward_nar_masked``; otherwise ``resps_list`` must already hold it.
    Each following pass feeds all levels generated so far, samples the next
    level greedily (temperature forced to 0.0), and appends it along the
    last dimension.

    Args:
        text_list: per-sample text tokens, or ``None`` (device and batch size
            are then read from ``resps_list``).
        proms_list: per-sample prompts, forwarded to ``inputs``.
        resps_list: existing responses; 1D entries are expanded with a
            trailing level axis. The caller's list is not modified.
        task_list: only forwarded to ``forward_nar_masked``.
        lang_list, tone_list: forwarded untouched to ``inputs``.
        len_list: target lengths; triggers the masked level-0 pass.
        disable_tqdm: hides the progress bars, including the masked pass.
        use_lora: passed to ``enable_lora`` instead of the per-level default
            when not ``None``, including in the masked pass.
        **sampling_kwargs: sampler settings; ``max_levels``, ``cfg_strength``
            and ``cfg_rescale`` are read here, the remainder is forwarded to
            ``forward_nar_masked`` and ``sample``.

    Returns:
        A list with one tensor per sample, levels concatenated on the last dimension.
    """
    # deduce batch_size
    if text_list is not None:
        device = text_list[0].device
        batch_size = len(text_list)
    else:
        device = resps_list[0].device
        batch_size = len(resps_list)

    # convert NAR specific args
    sampling_kwargs = convert_kwargs(sampling_kwargs, "nar_")

    max_levels = sampling_kwargs.get("max_levels", 0)
    cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
    cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)

    if max_levels == 0:
        max_levels = self.n_max_levels - 1

    # inference NAR level 0
    # (a "prefix_context" entry left in sampling_kwargs is consumed by the masked pass)
    if len_list is not None:
        resps_list = self.forward_nar_masked(
            text_list=text_list,
            proms_list=proms_list,
            resps_list=resps_list,
            task_list=task_list,
            lang_list=lang_list,
            tone_list=tone_list,
            len_list=len_list,
            # keep the progress bar / LoRA selection consistent with this call
            disable_tqdm=disable_tqdm,
            use_lora=use_lora,
            **sampling_kwargs,
        )

    # expand if given a raw 1D tensor (builds a new list rather than editing the caller's)
    prev_list = [resp.unsqueeze(-1) if resp.dim() == 1 else resp for resp in resps_list]

    null_text = [torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size)]
    null_prom = [None for _ in range(batch_size)]

    iterator = trange(max_levels, desc="NAR", disable=disable_tqdm)
    for n in iterator:
        # the next level to generate is the number of levels already present
        level = prev_list[0].shape[-1]
        if level >= max_levels + 1:  # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
            iterator.close()
            break

        if cfg.lora is not None:
            enable_lora(self, cfg.lora.active_level(level) if use_lora is None else use_lora)

        quant_levels = [level for _ in range(batch_size)]  # torch.full((len(text_list),), level)

        inputs = self.inputs(
            text_list=text_list,
            proms_list=proms_list,
            resps_list=prev_list,
            lang_list=lang_list,
            tone_list=tone_list,
            quant_levels=quant_levels,
        )

        output = super().forward(
            inputs=inputs,
            quant_levels=quant_levels,
            #layer_skip_variables=sampling_layer_skip_variables,
        )
        logits = output.logits

        if cfg_strength > 0:
            null_inputs = super().inputs(
                text_list=null_text,
                proms_list=null_prom,
                resps_list=prev_list,
                lang_list=lang_list,
                tone_list=tone_list,
                quant_levels=quant_levels,
            )
            null_output = super().forward(
                inputs=null_inputs,
                quant_levels=quant_levels,
                #layer_skip_variables=sampling_layer_skip_variables,
            )

            logits = cfg_logits(
                logits=output.logits,
                null=null_output.logits,
                strength=cfg_strength,
                rescale=cfg_rescale,
                lens=[resp.shape[0] for resp in prev_list],
            )

        sampled = super().sample(
            logits=logits,
            prev_list=prev_list,
            quant_levels=quant_levels,
            #temperature=0.0,
            **(sampling_kwargs | {"temperature": 0.0}),
        )

        resps_list = sampled.ids
        prev_list = [
            torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1)
            for rs, r in zip(prev_list, resps_list)
        ]

    return prev_list
def forward_ar (
self ,
text_list : list [ Tensor ] ,
proms_list : list [ Tensor ] ,
resps_list : list [ Tensor ] | None = None ,
task_list : list [ Tensor ] | None = None ,
lang_list : list [ Tensor ] | None = None ,
tone_list : list [ Tensor ] | None = None ,
len_list : list [ Tensor ] | None = None ,
disable_tqdm = False ,
use_lora = None ,
2024-11-12 02:21:16 +00:00
* * sampling_kwargs ,
2024-11-10 18:19:48 +00:00
) :
# deduce batch_size
if text_list is not None :
default_task = " tts "
device = text_list [ 0 ] . device
batch_size = len ( text_list )
else :
default_task = " stt "
device = resps_list [ 0 ] . device
batch_size = len ( resps_list )
if cfg . lora is not None :
enable_lora ( self , cfg . lora . active_level ( 0 ) if use_lora is None else use_lora )
2024-11-12 02:21:16 +00:00
# convert AR specific args
sampling_kwargs = convert_kwargs ( sampling_kwargs , " ar_ " )
temperature = sampling_kwargs . get ( " temperature " , 1.0 )
2024-11-12 02:27:38 +00:00
cfg_strength = sampling_kwargs . get ( " cfg_strength " , 0.0 )
2024-11-19 16:30:05 +00:00
cfg_rescale = sampling_kwargs . pop ( " cfg_rescale " , 0.7 )
2024-11-12 02:21:16 +00:00
min_temperature = sampling_kwargs . get ( " min_temperature " , - 1.0 )
max_duration = sampling_kwargs . get ( " max_duration " , 500 )
beam_width = sampling_kwargs . get ( " beam_width " , 0 )
entropix_sampling = sampling_kwargs . get ( " entropix_sampling " , False )
refine_on_stop = sampling_kwargs . get ( " refine_on_stop " , False )
input_prompt_prefix = sampling_kwargs . get ( " input_prompt_prefix " , False )
layer_skip = sampling_kwargs . get ( " layer_skip " , False )
prefix_silence = sampling_kwargs . get ( " prefix_silence " , 0.0 )
mirostat_tau = sampling_kwargs . get ( " mirostat_tau " , 0.0 )
mirostat_eta = sampling_kwargs . get ( " mirostat_eta " , 0.0 )
2024-11-10 18:19:48 +00:00
# inference len
if task_list is not None and task_list [ 0 ] == " len " :
sequence_list = [ torch . tensor ( [ 0 ] , device = device , dtype = torch . int16 ) for _ in range ( batch_size ) ]
stopped = torch . zeros ( batch_size , device = device ) . bool ( )
stop_token = 10
task_list = [ " len " for _ in range ( batch_size ) ]
2024-11-12 02:21:16 +00:00
quant_levels = [ 0 for _ in range ( max ( batch_size , beam_width ) ) ]
2024-11-10 18:19:48 +00:00
2024-12-05 02:31:44 +00:00
iterator = trange ( 10 , desc = " AR " , disable = disable_tqdm )
for n in iterator :
2024-11-10 18:19:48 +00:00
len_list = sequence_list
2023-09-06 23:58:35 +00:00
2024-04-17 02:04:48 +00:00
inputs = self . inputs (
2023-09-13 02:28:07 +00:00
text_list = text_list ,
proms_list = proms_list ,
2024-11-10 18:19:48 +00:00
resps_list = resps_list ,
2023-10-12 02:21:50 +00:00
lang_list = lang_list ,
2024-04-16 00:54:32 +00:00
tone_list = tone_list ,
2024-11-10 18:19:48 +00:00
len_list = len_list ,
task_list = task_list ,
2024-06-05 04:23:31 +00:00
quant_levels = quant_levels ,
2024-04-17 02:04:48 +00:00
)
2024-10-06 03:53:53 +00:00
output = super ( ) . forward (
2024-04-17 02:04:48 +00:00
inputs = inputs ,
2023-09-13 02:28:07 +00:00
quant_levels = quant_levels ,
)
2024-11-10 18:19:48 +00:00
logits = output . logits
r = [ logit [ - 1 : ] . argmax ( dim = 1 ) for logit in logits ]
# sanitize
for i , token in enumerate ( r ) :
if token > 10 :
r [ i ] [ 0 ] = stop_token
# append tokens
for i , ri in enumerate ( r ) :
if stop_token in ri :
stopped [ i ] = True
sequence_list [ i ] = torch . cat ( [ sequence_list [ i ] , ri . to ( device ) ] )
# stop token found
stopped | = r == stop_token
if stopped . all ( ) . item ( ) :
2024-12-05 02:31:44 +00:00
iterator . close ( )
2024-11-10 18:19:48 +00:00
break
2023-09-06 23:58:35 +00:00
2024-11-10 18:19:48 +00:00
# convert tokens into int
return [ int ( " " . join ( [ str ( token . item ( ) ) for token in r if token != stop_token ] ) ) for r in sequence_list ]
2024-06-18 02:45:03 +00:00
2024-09-06 01:43:20 +00:00
# STT
2024-10-04 23:57:19 +00:00
start_slice = [ 0 for _ in range ( batch_size ) ]
2024-06-08 20:42:02 +00:00
sequence_list = [ torch . zeros ( 0 , device = device ) . to ( torch . int16 ) for _ in range ( batch_size ) ]
2023-09-06 23:58:35 +00:00
stopped = torch . zeros ( batch_size , device = device ) . bool ( )
2024-06-08 20:42:02 +00:00
2024-09-06 19:30:12 +00:00
audio_stop_token = self . stop_token
text_stop_token = 2
2023-09-06 23:58:35 +00:00
2024-06-29 03:44:00 +00:00
state = None
2023-09-18 23:55:41 +00:00
mirostat = [
2024-11-12 02:21:16 +00:00
{ " n " : 1024 , " tau " : mirostat_tau , " eta " : mirostat_eta , " max_surprise " : mirostat_eta * 2 , " error_surprise " : 0 , " running_total_surprise " : 0 }
] * batch_size if mirostat_tau > 0.0 else None
2023-09-06 23:58:35 +00:00
2024-11-12 02:21:16 +00:00
scores = [ 1.0 ] * beam_width
2024-10-23 01:13:54 +00:00
metrics = [ ]
2024-10-12 02:18:26 +00:00
2024-10-23 01:13:54 +00:00
"""
2024-11-02 16:49:05 +00:00
sampling_layer_skip_variables = { } if sampling_layer_skip else None
if sampling_layer_skip :
2024-11-04 00:31:28 +00:00
if sampling_layer_skip_entropy_threshold > = 0 :
sampling_layer_skip_variables [ " entropy_threshold " ] = sampling_layer_skip_entropy_threshold
if sampling_layer_skip_varentropy_threshold > = 0 :
sampling_layer_skip_variables [ " varentropy_threshold " ] = sampling_layer_skip_varentropy_threshold
if sampling_layer_skip_exit_layer > = 0 :
sampling_layer_skip_variables [ " max_layer " ] = sampling_layer_skip_exit_layer
2024-11-12 02:21:16 +00:00
"""
2024-11-02 16:49:05 +00:00
2024-09-06 01:43:20 +00:00
for i , sequence in enumerate ( sequence_list ) :
2024-10-04 23:57:19 +00:00
# add <bos> to text for STT
2024-09-06 01:43:20 +00:00
if task_list [ i ] in text_task :
2024-10-04 23:57:19 +00:00
start_slice [ i ] = 1
2024-09-06 01:43:20 +00:00
sequence_list [ i ] = torch . cat ( [ sequence_list [ i ] , torch . tensor ( [ 1 ] , dtype = torch . int16 , device = device ) ] )
2024-10-04 23:57:19 +00:00
# treat input prompt as initial resp (by prefixing with the prompt instead)
elif input_prompt_prefix :
start_slice [ i ] = proms_list [ i ] . shape [ 0 ]
sequence_list [ i ] , proms_list [ i ] = proms_list [ i ] [ : , 0 ] , sequence_list [ i ]
2024-10-18 22:19:52 +00:00
elif prefix_silence > 0 :
sequence_list [ i ] = get_silence ( prefix_silence , device = sequence_list [ i ] . device )
sequence_list [ i ] = sequence_list [ i ] [ : , 0 ]
# start_slice[i] = sequence_list[i].shape[0]
2024-09-06 01:43:20 +00:00
2024-12-05 02:31:44 +00:00
# prefixed context provided
prefix_context = sampling_kwargs . get ( " prefix_context " , None )
if prefix_context is not None :
prefix_text , prefix_resps , _ = prefix_context
# to-do: check if we actually need to drop the middle "<eos><bos>"
text_list = [ torch . concat ( [ prefix [ : - 1 ] , text [ 1 : ] ] ) for prefix , text in zip ( prefix_text , text_list ) ]
# feeding this into the NAR-len should automatically handle things
sequence_list = [ resps if resps . dim ( ) == 1 else resps [ : , 0 ] for resps in prefix_resps ]
2024-11-12 02:27:38 +00:00
null_text = [ torch . tensor ( [ 1 , 2 ] , device = device , dtype = torch . int16 ) for _ in range ( batch_size ) ]
null_prom = [ None for _ in range ( batch_size ) ]
2023-09-13 18:19:11 +00:00
# get next in sequence
2024-12-05 02:31:44 +00:00
iterator = trange ( max_duration / / max ( 1 , self . causal_size ) , desc = " AR " , disable = disable_tqdm )
for n in iterator :
2024-10-05 03:30:47 +00:00
# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
2024-09-06 19:30:12 +00:00
text_list = [ sequence_list [ i ] if task in text_task else text_list [ i ] for i , task in enumerate ( task_list ) ]
resps_list = [ sequence_list [ i ] if task not in text_task else resps_list [ i ] for i , task in enumerate ( task_list ) ]
2024-11-12 19:42:39 +00:00
quant_levels = [ 0 for _ in range ( max ( batch_size , beam_width ) ) ]
2024-09-06 19:30:12 +00:00
2024-04-17 02:04:48 +00:00
inputs = self . inputs (
text_list = text_list ,
proms_list = proms_list ,
resps_list = resps_list ,
lang_list = lang_list ,
tone_list = tone_list ,
2024-06-08 20:42:02 +00:00
len_list = len_list ,
task_list = task_list ,
2024-11-12 19:42:39 +00:00
quant_levels = quant_levels ,
2024-04-17 02:04:48 +00:00
)
2024-10-05 03:30:47 +00:00
# to-do: find an elegant way to write this
2024-10-06 03:53:53 +00:00
output = super ( ) . forward (
inputs = inputs ,
state = state ,
2024-11-10 18:19:48 +00:00
#layer_skip_variables=sampling_layer_skip_variables,
2024-11-12 02:21:16 +00:00
output_attentions = entropix_sampling ,
2024-10-06 03:53:53 +00:00
)
2024-11-12 02:27:38 +00:00
if cfg_strength > 0 :
null_inputs = super ( ) . inputs (
text_list = null_text ,
proms_list = null_prom ,
resps_list = resps_list ,
lang_list = lang_list ,
tone_list = tone_list ,
quant_levels = quant_levels ,
)
null_output = super ( ) . forward (
inputs = null_inputs ,
quant_levels = quant_levels ,
#layer_skip_variables=sampling_layer_skip_variables,
)
2024-11-19 16:30:05 +00:00
logits = cfg_logits ( logits = output . logits , null = null_output . logits , strength = cfg_strength , rescale = cfg_rescale , lens = [ resp . shape [ 0 ] + 1 for resp in resps_list ] )
2024-11-12 02:27:38 +00:00
2024-10-12 02:18:26 +00:00
logits , state = output . logits , output . state
2023-09-06 23:58:35 +00:00
2024-10-12 02:18:26 +00:00
sampled = super ( ) . sample (
2023-09-13 02:28:07 +00:00
logits = logits ,
2024-11-12 02:21:16 +00:00
prev_list = [ resps_list [ i ] if task not in text_task else text_list [ i ] for i , task in enumerate ( task_list ) ] ,
* * ( sampling_kwargs | { " attentions " : output . attentions if entropix_sampling else None } ) ,
2023-09-13 02:28:07 +00:00
)
2024-11-12 22:41:58 +00:00
ids = sampled . ids
2024-10-12 02:18:26 +00:00
2024-11-02 16:49:05 +00:00
if cfg . experimental :
if sampled . entropy :
metrics . append ( sampled . entropy )
elif sampled . scores :
2024-11-08 03:19:14 +00:00
#metrics.append( [ { "p": p[0], "exited_layer": output.exited_layer } for p in sampled.scores ] )
metrics . append ( [ { " p " : p [ 0 ] } for p in sampled . scores ] )
2024-10-12 02:18:26 +00:00
2023-09-18 23:55:41 +00:00
if mirostat is not None :
2024-10-12 02:18:26 +00:00
mirostat = sampled . scores
2024-11-12 02:21:16 +00:00
elif beam_width > 0 :
2023-09-13 18:19:11 +00:00
# expand tuple
2024-10-23 05:03:35 +00:00
s = sampled . scores
2023-09-13 18:19:11 +00:00
# first step, expand batch
if batch_size == 1 :
2024-11-12 02:21:16 +00:00
batch_size = beam_width
text_list = text_list * beam_width
proms_list = proms_list * beam_width
sequence_list = sequence_list * beam_width
task_list = task_list * beam_width
start_slice = start_slice * beam_width
2023-09-13 18:19:11 +00:00
stopped = torch . zeros ( batch_size , device = device ) . bool ( )
2024-10-23 05:03:35 +00:00
scores = [ scores [ i ] + score for i , score in enumerate ( s ) ]
2023-09-13 02:28:07 +00:00
2023-09-06 23:58:35 +00:00
# append tokens
2024-11-12 22:41:58 +00:00
for i , token in enumerate ( ids ) :
2024-09-06 19:30:12 +00:00
task = task_list [ i ]
stop_token = audio_stop_token if task not in text_task else text_stop_token
2024-11-12 22:41:58 +00:00
if stop_token in token :
2023-09-06 23:58:35 +00:00
stopped [ i ] = True
2024-11-12 22:41:58 +00:00
sequence_list [ i ] = torch . cat ( [ sequence_list [ i ] , token . to ( device ) ] )
2023-09-06 23:58:35 +00:00
# stop token found
2024-09-06 19:30:12 +00:00
# stopped |= r == stop_token
2023-09-06 23:58:35 +00:00
if stopped . all ( ) . item ( ) :
2024-12-05 02:31:44 +00:00
iterator . close ( )
2023-09-06 23:58:35 +00:00
break
2024-11-04 00:31:28 +00:00
# to-do for layerskip / speculative sampling: rerun the last sequence again at max depth
2024-11-12 02:21:16 +00:00
"""
2024-10-23 01:13:54 +00:00
if metrics :
from . . plot import plot_sample_metrics
2024-11-02 16:49:05 +00:00
filename = " metrics "
2024-11-12 02:21:16 +00:00
if entropix_sampling :
filename + = f ' [entropix_sampling] '
2024-11-02 16:49:05 +00:00
if sampling_layer_skip_exit_layer > = 0 :
filename + = f ' [ { sampling_layer_skip_exit_layer + 1 } ] '
plot_sample_metrics ( metrics , filename = f ' { filename } .png ' )
2024-11-12 02:21:16 +00:00
"""
2024-10-12 02:18:26 +00:00
2023-09-13 18:19:11 +00:00
# pick the best scoring candidate
# desu this is always going to be candidate 0
2024-11-12 02:21:16 +00:00
if beam_width :
2024-10-23 03:06:22 +00:00
sequence_list = sequence_list [ : 1 ]
task_list = task_list [ : 1 ]
2023-09-13 02:28:07 +00:00
2024-09-06 19:30:12 +00:00
# remove stop token
sequence_list = [ self . _prune ( r , audio_stop_token if task_list [ i ] not in text_task else text_stop_token ) for i , r in enumerate ( sequence_list ) ]
# remove <bos>
2024-10-04 23:57:19 +00:00
sequence_list = [ sequence_list [ i ] [ start_slice [ i ] : ] for i , task in enumerate ( task_list ) ]
2024-10-12 15:41:35 +00:00
2024-11-12 02:21:16 +00:00
if refine_on_stop :
2024-11-04 00:31:28 +00:00
# get how much we need to slice from the end
slice_lengths = [ sequence . shape [ - 1 ] for sequence in sequence_list ]
# -1 for the stop token
logits = [ logit [ - length - 1 : - 1 ] for logit , length in zip ( logits , slice_lengths ) ]
# greedy sample from the sequence
refined_list = [ logit . argmax ( dim = - 1 ) for logit in logits ]
# to-do: compare scores
# set the "refined" list as the output
2024-12-05 02:31:44 +00:00
sequence_list = refined_list
# slice out prefix
if prefix_context is not None :
prefix_text , prefix_resps , prefix_lens = prefix_context
sequence_list = [ resps [ l : ] for resps , l in zip ( sequence_list , prefix_lens ) ]
2024-11-04 00:31:28 +00:00
2024-06-08 20:42:02 +00:00
return sequence_list
2023-09-06 23:58:35 +00:00
2024-11-10 18:19:48 +00:00
def forward (
self ,
text_list : list [ Tensor ] ,
proms_list : list [ Tensor ] ,
resps_list : list [ Tensor ] | None = None ,
task_list : list [ Tensor ] | None = None ,
lang_list : list [ Tensor ] | None = None ,
tone_list : list [ Tensor ] | None = None ,
len_list : list [ Tensor ] | None = None ,
2024-12-06 05:05:52 +00:00
training : bool | None = None ,
2024-11-10 18:19:48 +00:00
disable_tqdm = False ,
use_lora = None ,
2024-11-12 02:21:16 +00:00
* * sampling_kwargs ,
2024-11-10 18:19:48 +00:00
) :
# deduce batch_size
if text_list is not None :
default_task = " tts "
device = text_list [ 0 ] . device
batch_size = len ( text_list )
else :
default_task = " stt "
device = resps_list [ 0 ] . device
batch_size = len ( resps_list )
# generate task list if not provided
if task_list is None :
task_list = [ default_task for _ in range ( batch_size ) ]
# implicitly set for training
if training is None and text_list is not None and resps_list is not None :
n_levels_set = { r . shape [ - 1 ] for r in resps_list }
n_levels = next ( iter ( n_levels_set ) )
training = n_levels == self . n_resp_levels
# is training
if training :
return self . forward_train (
text_list = text_list ,
proms_list = proms_list ,
resps_list = resps_list ,
task_list = task_list ,
lang_list = lang_list ,
tone_list = tone_list ,
len_list = len_list ,
)
# is NAR
if ( len_list is not None or resps_list is not None ) and text_list is not None :
return self . forward_nar (
text_list = text_list ,
proms_list = proms_list ,
resps_list = resps_list ,
task_list = task_list ,
lang_list = lang_list ,
tone_list = tone_list ,
len_list = len_list ,
2024-12-05 02:31:44 +00:00
disable_tqdm = disable_tqdm ,
use_lora = use_lora ,
2024-11-12 02:21:16 +00:00
* * sampling_kwargs ,
2024-11-10 18:19:48 +00:00
)
# is AR
return self . forward_ar (
text_list = text_list ,
proms_list = proms_list ,
resps_list = resps_list ,
task_list = task_list ,
lang_list = lang_list ,
tone_list = tone_list ,
len_list = len_list ,
2024-12-05 02:31:44 +00:00
disable_tqdm = disable_tqdm ,
use_lora = use_lora ,
2024-11-12 02:21:16 +00:00
* * sampling_kwargs ,
2024-11-10 18:19:48 +00:00
)
2023-09-06 23:58:35 +00:00
def example_usage():
    """Smoke-test the AR+NAR model by overfitting one utterance, then sampling.

    Steps, all driven by the global ``cfg``:
      1. load one pre-quantized artifact from ``./data/qnt.dac`` or
         ``./data/qnt.enc`` (chosen by ``cfg.audio_backend``),
      2. build an ``AR_NAR`` model plus optimizer (and optional schedule-free
         wrapper) and hand them to an ``Engine``,
      3. train for ``1000 // batch_size`` steps on that single utterance,
      4. for every available task, run inference and write the decoded audio
         to ``data/<arch>.<backend>.<i>.final.<task>.wav``.

    Raises:
        ValueError: if the configured optimizer name is not recognized.
    """
    # Prefer CUDA, but fall back to the CPU so the example still runs without a GPU.
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg.trainer.backend = "local"
    if cfg.audio_backend == "dac":
        cfg.sample_rate = 44_100

    from tqdm import tqdm

    from ..emb.qnt import decode_to_file, unload_model
    from ..engines import Engine, Engines
    from ..utils import wrapper as ml
    from ..utils import setup_logging

    import numpy as np

    # cfg.model.experimental.masking_train_p = 0.5
    cfg.hyperparameters.batch_size = 1
    cfg.hyperparameters.gradient_accumulation_steps = 1

    setup_logging()

    def load_artifact(path):
        """Load a quantized artifact and return its (text, audio) tensors on cfg.device."""
        # NOTE(security): allow_pickle=True unpickles the file and can execute
        # arbitrary code; only ever point this at trusted, locally produced data.
        artifact = np.load(path, allow_pickle=True)[()]

        text = torch.tensor(cfg.tokenizer.encode(artifact["metadata"]["phonemes"])).to(dtype=torch.uint8, device=cfg.device)
        audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=cfg.device)

        return text, audio

    text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
    batch_size = cfg.hyperparameters.batch_size

    # the same utterance is repeated to fill the batch
    text_list = [text] * batch_size
    proms_list = [audio[:cfg.dataset.frames_per_second, :]] * batch_size
    resps_list = [audio[:cfg.dataset.frames_per_second * 4, :]] * batch_size

    kwargs = {
        'n_text_tokens': 256,
        'n_audio_tokens': 1024,

        'd_model': 1024,  # 256, # 1024, # 1536
        'n_heads': 16,  # 4, # 16, # 24
        'n_layers': 12,  # 32
        'n_experts': 1 if not cfg.model else cfg.model.experts,

        'p_dropout': 0.1,

        'l_padding': 8 if cfg.optimizations.fp8 else 0,

        'config': cfg.model,
    }

    # not used further below; kept so the tokenizer call (and its 3-way unpack) still runs
    bos_id, space_id, eos_id = cfg.tokenizer.encode(" ")

    available_tasks = []
    if "ar" in cfg.model.capabilities:
        available_tasks.append("tts-ar")
    if "len" in cfg.model.capabilities:
        available_tasks.append("tts-nar")

    model = AR_NAR(**kwargs).to(cfg.device)
    steps = 1000 // batch_size

    # only honor the configured hyperparameters when a YAML config was actually loaded
    has_yaml = cfg.yaml_path is not None
    optimizer = cfg.hyperparameters.optimizer.lower() if has_yaml else "prodigy"
    scheduler = cfg.hyperparameters.scheduler.lower() if has_yaml else ""
    learning_rate = cfg.hyperparameters.learning_rate if has_yaml else None

    if cfg.optimizations.dadaptation:
        # do not combine the two
        if scheduler == "schedulefree":
            scheduler = ""

        learning_rate = 1.0

    # resolve the optimizer class and, if unset, its default learning rate
    if optimizer == "prodigy":
        if learning_rate is None:
            learning_rate = 1.0

        optimizer = ml.Prodigy
    elif optimizer == "adagrad":
        if learning_rate is None:
            learning_rate = 1.0e-2

        optimizer = ml.Adagrad
    elif optimizer == "adamw":
        if learning_rate is None:
            learning_rate = 1.0e-4

        optimizer = ml.AdamW
    elif optimizer in ("sgd", "sdg"):
        # "sdg" was the historical (misspelt) key; keep accepting it for old configs
        if learning_rate is None:
            learning_rate = 1.0e-4

        optimizer = ml.SGD
    else:
        raise ValueError(f"Unrecognized optimizer: {optimizer}")

    _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")

    optimizer = optimizer(model.parameters(), lr=learning_rate)

    if scheduler == "schedulefree":
        # schedule-free variants exist only for AdamW and SGD; otherwise keep the plain optimizer
        if isinstance(optimizer, ml.AdamW):
            scheduler = ml.schedulefree.AdamWScheduleFree
        elif isinstance(optimizer, ml.SGD):
            scheduler = ml.schedulefree.SGDScheduleFree
        else:
            scheduler = None

        if scheduler is not None:
            _logger.info(f"Scheduler: {scheduler}")
            optimizer = scheduler(model.parameters(), lr=learning_rate)

    if cfg.optimizations.replace and cfg.optimizations.linear:
        model = ml.replace_linear(model)

    if cfg.optimizations.replace and cfg.optimizations.embedding:
        model = ml.replace_embedding(model)

    # disabled: manual model offloading policy
    # cfg.optimizations.model_offloading = {
    #     "devices": ["cuda:0", "cpu"],
    #     # "limits": [ 0.9, -1 ],
    #     "assign": [[ f'layers.{i}.' for i in range(0,10) ], [ f'layers.{i}.' for i in range(11,12) ] + [ "model.norm" ]],
    #     # "limits": [ 256 * (1024 ** 2), -1 ]
    # }

    engine = Engine(model=model, optimizer=optimizer)
    engines = Engines({"ar+nar": engine})
    engines.setup()

    # disabled: apply the offloading policy
    # if cfg.optimizations.model_offloading:
    #     model = ml.offload_model( model, policy=cfg.optimizations.model_offloading )

    # disabled: dump the freshly initialized weights
    # torch.save( {
    #     'module': model.state_dict()
    # }, f"./data/{cfg.model.arch_type}.pth" )

    _logger.info(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    @torch.no_grad()
    def sample_data(t=None):
        """Build one batch as (texts, proms, resps, tasks).

        ``t`` may be a list of task names (one sample per entry, built from the
        first utterance), a single task name, or None.
        """
        if isinstance(t, list):
            tasks = t
            texts = [text_list[0].to(cfg.device) if task not in text_task else None for i, task in enumerate(tasks)]
            proms = [proms_list[0].to(cfg.device) if task not in text_task else ["stt"] for i, task in enumerate(tasks)]
            resps = [None if task not in text_task else resps_list[0].to(cfg.device) for i, task in enumerate(tasks)]

            return texts, proms, resps, tasks

        texts = []
        proms = []
        resps = []
        tasks = []

        for i in range(batch_size):
            task = random.choice(available_tasks) if t is None else t

            text = text_list[i].to(cfg.device)
            prom = proms_list[i].to(cfg.device)
            resp = resps_list[i].to(cfg.device)

            # do nothing
            if task == "stt":
                prom = [task]
            else:
                # mostly "tts"; roughly 10% "len" when the model supports length prediction
                task = "tts" if random.random() > 0.1 or "len" not in cfg.model.capabilities else "len"

            texts.append(text)
            proms.append(prom)
            resps.append(resp)
            tasks.append(task)

        return texts, proms, resps, tasks

    @torch.inference_mode()
    def sample(name, steps=500, task=None):
        """Run inference for ``task`` and write each result to a .wav file."""
        engine.eval()

        text_list, proms_list, resp_list, task_list = sample_data(task)

        if task == "tts-nar":
            len_list = engine(text_list, proms_list, task_list=["len"], max_steps=5, temperature=0.0)
            # the predicted lengths are discarded in favor of the reference length
            len_list = [resp_list[0].shape[0] for l in len_list]
            resps_list = engine(text_list, proms_list, len_list=len_list)
        else:
            # AR pass for the first level, then a second pass fed with that output
            resps_list = engine(text_list, proms_list, task_list=["tts"], max_duration=steps, temperature=1.0)
            resps_list = engine(text_list, proms_list, resps_list=resps_list, temperature=0.0)

        for i, o in enumerate(resps_list):
            _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.{task}.wav", device=cfg.device)

        unload_model()

    def train():
        """Train for ``steps`` iterations, printing the stats of every step."""
        engine.train()
        t = trange(steps)
        for i in t:
            texts, proms, resps, tasks = sample_data()

            stats = {"step": i}
            stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
            stats |= {"grad_norm": engine.get_global_grad_norm()}

            tqdm.write(f"{stats}")

        # disabled: dump the trained weights
        # torch.save( {
        #     'module': model.state_dict()
        # }, f"./data/{cfg.model.arch_type}.pth" )

    #sample("init", 5)
    train()

    # disabled: compile the model before sampling
    # if cfg.optimizations.compile:
    #     model = ml.compile_model(model, backend=cfg.optimizations.compile)

    for task in available_tasks:
        sample("final", task=task)

    engines.quit()
2023-09-06 23:58:35 +00:00
# Script entry point: run the train-then-sample example when executed directly.
if __name__ == "__main__":
    example_usage()