"""
# an AR + NAR model that handles:
* inferencing the primary RVQ level in an autoregressive manner ( AR )
* inferencing the remaining RVQ levels in parallel ( NAR )
This model can fully handle being trained as a unified model ( AR + NAR ) or separate models ( AR | NAR ) .
It ' s recommended to train as a unified model, then " distill " knowledge of each tasks separately, just in case.
"""
from .base import Base, list_to_tensor, Categorical
from ..config import cfg

import torch
from torch.nn.utils.rnn import pad_sequence

import random
import math
import time

from einops import rearrange
from torch import Tensor
from tqdm import trange, tqdm

import logging

_logger = logging.getLogger(__name__)

from ..emb.qnt import trim, encode_as_embedding, get_silence
from ..utils import get_devices, setup_logging, timer, clamp, convert_kwargs

from .lora import enable_lora

text_task = ["stt"]

class AR_NAR(Base):
    def forward_train(
        self,
        text_list: list[Tensor],
        proms_list: list[Tensor],
        resps_list: list[Tensor],
        task_list: list[Tensor] | None = None,
        lang_list: list[Tensor] | None = None,
        tone_list: list[Tensor] | None = None,
        len_list: list[Tensor] | None = None,
    ):
        # deduce batch_size
        if text_list is not None:
            default_task = "tts"
            device = text_list[0].device
            batch_size = len(text_list)
        else:
            default_task = "stt"
            device = resps_list[0].device
            batch_size = len(resps_list)
        # specifies how to sample probabilities of which RVQ levels to train against
        rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
        # determines which RVQ level to target per batch
        quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [0 if self.causal else 1, self.n_resp_levels - 1]
        # rate to perform token dropout errors
        token_dropout_error = self.config.experimental.token_dropout_error
        # RVQ levels to apply token dropout on
        token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
        # RVQ levels to apply masking training on
        masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels
        # rate to train RVQ level AR-ly or NAR-ly
        masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5
        # force mask training off/on (force the rate rather than the level range, since the range is indexed below)
        if "len" not in self.capabilities:
            masking_train_p = 0.0
        elif "ar" not in self.capabilities:
            masking_train_p = 1.0
        # CFG
        cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
        cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
        cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0

        # implicitly set it to all levels
        if not token_dropout_rvq_levels:
            token_dropout_rvq_levels = [0, self.resp_levels - 1]

        # allow passing a specific distribution of RVQ levels
        if not isinstance(rvq_levels_p, list) or not rvq_levels_p:
            lo, hi = quant_level_range[0], quant_level_range[1] + 1
            # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
            if rvq_levels_p == "equal":
                rvq_levels_p = [i for i in range(lo, hi)]
            else:
                # bias towards lower levels: level i appears (hi - i) times
                rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range(lo, hi)], [])
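        # For example, with lo=0 and hi=4, "equal" yields [0, 1, 2, 3] (a uniform draw),
        # while the summed form yields [0, 0, 0, 0, 1, 1, 1, 2, 2, 3], biasing training
        # towards the lower (more perceptually important) RVQ levels.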
        # input RVQ levels
        quant_levels = [random.choice(rvq_levels_p) for i in range(batch_size)]
        # timestep levels (for TTS NAR)
        timesteps = [None for _ in range(batch_size)]

        for i, task in enumerate(task_list):
            lo, hi = masking_train_rvq_levels[0], masking_train_rvq_levels[1]
            if task in text_task:
                quant_levels[i] = 0  # self.n_resp_levels - 1
            elif lo <= quant_levels[i] <= hi and random.random() < masking_train_p:
                timesteps[i] = random.random()

        # trim resps to only contain all levels below the target level
        resps_list = [r if t in text_task else r[..., :l + 1] for r, l, t in zip(resps_list, quant_levels, task_list)]

        # tensors to cat for RVQ level 0
        text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
        text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
        audio_stop_sequence = torch.tensor([[self.stop_token]], device=device, dtype=torch.int16)
        # I hate python's value/reference semantics so much
        for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
            # cap quant_level if it exceeds its corresponding resp/prom
            if quant_level >= resps.shape[-1]:
                quant_levels[i] = resps.shape[-1] - 1

            # proms could be a Tensor, list[Tensor], or None
            if isinstance(proms, torch.Tensor):
                if quant_level >= proms.shape[-1]:
                    quant_levels[i] = proms.shape[-1] - 1
            elif isinstance(proms, list):
                for j, prom in enumerate(proms):
                    if not isinstance(prom, torch.Tensor):
                        continue
                    if quant_level >= prom.shape[-1]:
                        quant_levels[i] = prom.shape[-1] - 1

            # apply token dropout error compensation
            if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level <= token_dropout_rvq_levels[1]):
                steps = resps.shape[0]
                for l in range(quant_level):
                    for t in range(steps):
                        token = resps[t, l].item()
                        if random.random() < token_dropout_error:
                            offset = 1 * (1 if random.random() < 0.5 else -1)
                            resps_list[i][t, l] = clamp(token + offset, 1, 1022)  # +- 1

            # only apply stop token for RVQ level 0
            if quant_level <= 0:
                # append stop tokens for AR
                if task in text_task:
                    #text_list[i] = torch.cat([ resps, text_stop_sequence ])
                    ...
                else:
                    resps_list[i] = torch.cat([resps, audio_stop_sequence])

            if task == "len":
                quant_levels[i] = 0
            # apply CFG (should probably only apply to NAR quant level 0)
            if task not in text_task + ["len"]:
                drop_text = False
                drop_audio = False

                if random.random() < cfg_prom_dropout_p:
                    drop_audio = True

                if random.random() < cfg_cond_dropout_p:
                    drop_audio = True
                    drop_text = True

                if drop_text:
                    text_list[i] = text_start_stop_sequence

                if drop_audio:
                    proms_list[i] = None

        inputs = self.inputs(
            text_list=text_list,
            proms_list=proms_list,
            resps_list=resps_list,
            lang_list=lang_list,
            tone_list=tone_list,
            task_list=task_list,
            time_list=timesteps,
            quant_levels=quant_levels,
        )

        return super().forward(
            inputs=inputs,
            quant_levels=quant_levels,
        )

    def forward_nar_masked(
        self,
        text_list: list[Tensor],
        proms_list: list[Tensor],
        resps_list: list[Tensor] | None = None,
        task_list: list[Tensor] | None = None,
        lang_list: list[Tensor] | None = None,
        tone_list: list[Tensor] | None = None,
        len_list: list[Tensor] | None = None,
        disable_tqdm=False,
        use_lora=None,
        **sampling_kwargs,
    ):
        device = text_list[0].device
        batch_size = len(text_list)

        # special "scheduling" to inference RVQ-level 0
        level = 0
        if cfg.lora is not None:
            enable_lora(self, cfg.lora.active_level(level) if use_lora is None else use_lora)

        def log(x, eps=1e-20):
            return torch.log(x.clamp(min=eps))

        def gumbel_sample(x, temperature=1., dim=-1):
            return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim=dim)
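        # Note: gumbel_sample is the standard Gumbel-max trick: adding -log(-log(U)) noise
        # (U ~ Uniform(0, 1)) to the scaled logits and taking the argmax is equivalent to
        # drawing a sample from softmax(x / temperature), without an explicit softmax.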
        # convert (N)AR specific args
        sampling_kwargs = convert_kwargs(sampling_kwargs, "ar_")

        max_length = sampling_kwargs.pop("max_duration", 500)
        max_steps = sampling_kwargs.get("max_steps", 25)
        temperature = sampling_kwargs.pop("temperature", 1.0)
        cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
        start_noise = sampling_kwargs.get("denoise_start", 0.0)
        end_noise = sampling_kwargs.get("denoise_end", 1.0)
        max_steps = math.floor(max_steps * (end_noise - start_noise))

        len_list = [clamp(l, 1, max_length) for l in len_list]

        # if we're denoising from an existing sequence
        if start_noise > 0.0 and resps_list is not None:
            noise_p = math.cos(start_noise * math.pi * 0.5)
            mask = [torch.tensor([random.random() < noise_p for _ in range(seq_len)], dtype=torch.bool, device=device) for seq_len in len_list]
            resps_list = [torch.where(m, self.stop_token, resps[:, 0]) for m, resps in zip(mask, resps_list)]
        else:
            resps_list = [torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list]

        scores = [torch.zeros((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list]

        quant_levels = [level for _ in range(batch_size)]
        null_text = [torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size)]
        null_prom = [None for _ in range(batch_size)]

        prev_list = resps_list
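        # The loop below follows a cosine masking schedule: at timestep t in [0, 1],
        # a fraction cos(t * pi / 2) of the tokens remain masked, e.g. t=0.0 -> 100%,
        # t=0.5 -> ~71%, t=1.0 -> 0%, so the sequence is progressively committed.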
        for timestep, steps_until_x0 in tqdm(zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))), desc="NAR Masked", disable=disable_tqdm, total=max_steps):
            # get noise level, per cosine scheduling
            noise_p = math.cos(timestep * math.pi * 0.5)
            # pick the worst scoring tokens to mask off
            masked_indices = [score.topk(max(int(noise_p * seq_len), 1), dim=-1).indices for score, seq_len in zip(scores, len_list)]
            # mask off inputs
            resps_list = [resp.scatter(0, indices, self.stop_token) for resp, indices in zip(resps_list, masked_indices)]
            # boolean mask
            is_masked = [resps == self.stop_token for resps in resps_list]

            time_list = [timestep for _ in range(batch_size)]

            # setup inputs
            inputs = super().inputs(
                text_list=text_list,
                proms_list=proms_list,
                resps_list=resps_list,
                lang_list=lang_list,
                tone_list=tone_list,
                time_list=time_list,
                quant_levels=quant_levels,
            )
            output = super().forward(
                inputs=inputs,
                quant_levels=quant_levels,
                #layer_skip_variables=sampling_layer_skip_variables,
            )
            logits = output.logits

            if cfg_strength > 0:
                null_inputs = super().inputs(
                    text_list=null_text,
                    proms_list=null_prom,
                    resps_list=resps_list,
                    lang_list=lang_list,
                    tone_list=tone_list,
                    time_list=time_list,
                    quant_levels=quant_levels,
                )
                null_output = super().forward(
                    inputs=null_inputs,
                    quant_levels=quant_levels,
                    #layer_skip_variables=sampling_layer_skip_variables,
                )
                for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
                    logit[-seq_len:] = null_logit[-seq_len:] + (logit[-seq_len:] - null_logit[-seq_len:]) * cfg_strength
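                # This is standard classifier-free guidance: mixing conditional and null
                # (unconditioned) logits as null + (cond - null) * strength, so a strength
                # of 1.0 recovers the conditional logits and > 1.0 pushes past them.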
            # sample with sampler settings
            filtered_sampled = super().sample(
                logits=logits,
                prev_list=prev_list,
                quant_levels=quant_levels,
                temperature=temperature * (steps_until_x0 / max_steps),
                **sampling_kwargs,
            )

            # retrieves unfiltered logits
            unfiltered_sampled = super().sample(
                logits=logits,
                prev_list=prev_list,
                quant_levels=quant_levels,
                temperature=0.0,
                **sampling_kwargs,
            )

            # update previous list of tokens
            prev_list = resps_list

            # sample with Gumbel noise
            # I actually feel like this doesn't matter? it's hard to judge with a partially trained NAR-len model
            sampled_ids = [gumbel_sample(logits, temperature=temperature, dim=-1) for logits in filtered_sampled.logits[0]]
            #sampled_ids = filtered_sampled[0]

            # keep unmasked tokens
            resps_list = [torch.where(masked, input_ids, resps) for masked, input_ids, resps in zip(is_masked, sampled_ids, resps_list)]
            # update scores (inverted, so the least confident tokens are the first to be re-masked)
            scores = [1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores]

            if cfg.experimental and max_steps > 0:
                print(timestep, steps_until_x0, noise_p, resps_list, scores)

        return resps_list

    def forward_nar(
        self,
        text_list: list[Tensor],
        proms_list: list[Tensor],
        resps_list: list[Tensor] | None = None,
        task_list: list[Tensor] | None = None,
        lang_list: list[Tensor] | None = None,
        tone_list: list[Tensor] | None = None,
        len_list: list[Tensor] | None = None,
        disable_tqdm=False,
        use_lora=None,
        **sampling_kwargs,
    ):
        # deduce batch_size
        if text_list is not None:
            default_task = "tts"
            device = text_list[0].device
            batch_size = len(text_list)
        else:
            default_task = "stt"
            device = resps_list[0].device
            batch_size = len(resps_list)

        # convert NAR specific args
        sampling_kwargs = convert_kwargs(sampling_kwargs, "nar_")

        max_levels = sampling_kwargs.get("max_levels", 0)
        cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)

        if max_levels == 0:
            max_levels = self.n_max_levels - 1
        """
        sampling_layer_skip_variables = {} if sampling_layer_skip else None

        if sampling_layer_skip:
            if sampling_layer_skip_entropy_threshold >= 0:
                sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
            if sampling_layer_skip_varentropy_threshold >= 0:
                sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
            if sampling_layer_skip_exit_layer >= 0:
                sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
        """
        # inference NAR level 0
        if len_list is not None:
            resps_list = self.forward_nar_masked(
                text_list=text_list,
                proms_list=proms_list,
                resps_list=resps_list,
                task_list=task_list,
                lang_list=lang_list,
                tone_list=tone_list,
                len_list=len_list,
                disable_tqdm=disable_tqdm,
                use_lora=use_lora,
                **sampling_kwargs,
            )

        # expand if given a raw 1D tensor
        for i, resp in enumerate(resps_list):
            if resp.dim() == 1:
                resps_list[i] = resp.unsqueeze(-1)

        prev_list = resps_list

        null_text = [torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size)]
        null_prom = [None for _ in range(batch_size)]
        for n in trange(max_levels, desc="NAR", disable=disable_tqdm):
            level = prev_list[0].shape[-1]
            if level >= max_levels + 1:  # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
                break

            if cfg.lora is not None:
                enable_lora(self, cfg.lora.active_level(level) if use_lora is None else use_lora)

            quant_levels = [level for _ in range(batch_size)]  # torch.full((len(text_list),), level)

            inputs = self.inputs(
                text_list=text_list,
                proms_list=proms_list,
                resps_list=prev_list,
                lang_list=lang_list,
                tone_list=tone_list,
                quant_levels=quant_levels,
            )

            output = super().forward(
                inputs=inputs,
                quant_levels=quant_levels,
                #layer_skip_variables=sampling_layer_skip_variables,
            )
            logits, state = output.logits, output.state

            if cfg_strength > 0:
                null_inputs = super().inputs(
                    text_list=null_text,
                    proms_list=null_prom,
                    resps_list=resps_list,
                    lang_list=lang_list,
                    tone_list=tone_list,
                    quant_levels=quant_levels,
                )
                null_output = super().forward(
                    inputs=null_inputs,
                    quant_levels=quant_levels,
                    #layer_skip_variables=sampling_layer_skip_variables,
                )
                for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
                    logit[-seq_len:] = null_logit[-seq_len:] + (logit[-seq_len:] - null_logit[-seq_len:]) * cfg_strength

            sampled = super().sample(
                logits=logits,
                prev_list=prev_list,
                quant_levels=quant_levels,
                **sampling_kwargs,
            )

            resps_list = sampled[0]
            prev_list = [torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list)]

        return prev_list
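    # Shape sketch for the loop above (illustrative, not part of the model): each pass
    # samples one more RVQ level and concatenates it, so a response tensor grows from
    # (seq_len, 1) after level 0 to (seq_len, n_resp_levels) once all levels are filled.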

    def forward_ar(
        self,
        text_list: list[Tensor],
        proms_list: list[Tensor],
        resps_list: list[Tensor] | None = None,
        task_list: list[Tensor] | None = None,
        lang_list: list[Tensor] | None = None,
        tone_list: list[Tensor] | None = None,
        len_list: list[Tensor] | None = None,
        disable_tqdm=False,
        use_lora=None,
        **sampling_kwargs,
    ):
        # deduce batch_size
        if text_list is not None:
            default_task = "tts"
            device = text_list[0].device
            batch_size = len(text_list)
        else:
            default_task = "stt"
            device = resps_list[0].device
            batch_size = len(resps_list)

        if cfg.lora is not None:
            enable_lora(self, cfg.lora.active_level(0) if use_lora is None else use_lora)
        # convert AR specific args
        sampling_kwargs = convert_kwargs(sampling_kwargs, "ar_")

        temperature = sampling_kwargs.get("temperature", 1.0)
        cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
        min_temperature = sampling_kwargs.get("min_temperature", -1.0)
        max_duration = sampling_kwargs.get("max_duration", 500)
        beam_width = sampling_kwargs.get("beam_width", 0)
        entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
        refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
        input_prompt_prefix = sampling_kwargs.get("input_prompt_prefix", False)
        layer_skip = sampling_kwargs.get("layer_skip", False)
        prefix_silence = sampling_kwargs.get("prefix_silence", 0.0)
        mirostat_tau = sampling_kwargs.get("mirostat_tau", 0.0)
        mirostat_eta = sampling_kwargs.get("mirostat_eta", 0.0)
        # inference len
        if task_list is not None and task_list[0] == "len":
            sequence_list = [torch.tensor([0], device=device, dtype=torch.int16) for _ in range(batch_size)]
            stopped = torch.zeros(batch_size, device=device).bool()

            stop_token = 10
            task_list = ["len" for _ in range(batch_size)]
            quant_levels = [0 for _ in range(max(batch_size, beam_width))]

            for n in trange(10, desc="AR", disable=disable_tqdm):
                len_list = sequence_list

                inputs = self.inputs(
                    text_list=text_list,
                    proms_list=proms_list,
                    resps_list=resps_list,
                    lang_list=lang_list,
                    tone_list=tone_list,
                    len_list=len_list,
                    task_list=task_list,
                    quant_levels=quant_levels,
                )

                output = super().forward(
                    inputs=inputs,
                    quant_levels=quant_levels,
                )
                logits = output.logits

                r = [logit[-1:].argmax(dim=1) for logit in logits]
                # sanitize: only digit tokens (0-9) and the stop token (10) are valid
                for i, token in enumerate(r):
                    if token > 10:
                        r[i][0] = stop_token

                # append tokens
                for i, ri in enumerate(r):
                    if stop_token in ri:
                        stopped[i] = True
                    sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])

                # stop token found
                # (already flagged per-sequence above; `r` is a list of tensors, so a
                # direct elementwise comparison here would be a no-op)
                # stopped |= r == stop_token
                if stopped.all().item():
                    break

            # convert tokens into int
            return [int("".join([str(token.item()) for token in r if token != stop_token])) for r in sequence_list]
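            # For example, a sequence of digit tokens [0, 1, 2, 0, <stop>] decodes to
            # int("0120") == 120 frames; the leading 0 is the seed token each sequence
            # was initialized with above.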
        # STT
        start_slice = [0 for _ in range(batch_size)]
        sequence_list = [torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size)]
        stopped = torch.zeros(batch_size, device=device).bool()

        audio_stop_token = self.stop_token
        text_stop_token = 2

        state = None
        mirostat = [
            {"n": 1024, "tau": mirostat_tau, "eta": mirostat_eta, "max_surprise": mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
        ] * batch_size if mirostat_tau > 0.0 else None

        scores = [1.0] * beam_width
        metrics = []

        """
        sampling_layer_skip_variables = {} if sampling_layer_skip else None

        if sampling_layer_skip:
            if sampling_layer_skip_entropy_threshold >= 0:
                sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
            if sampling_layer_skip_varentropy_threshold >= 0:
                sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
            if sampling_layer_skip_exit_layer >= 0:
                sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
        """

        for i, sequence in enumerate(sequence_list):
            # add <bos> to text for STT
            if task_list[i] in text_task:
                start_slice[i] = 1
                sequence_list[i] = torch.cat([sequence_list[i], torch.tensor([1], dtype=torch.int16, device=device)])
            # treat input prompt as initial resp (by prefixing with the prompt instead)
            elif input_prompt_prefix:
                start_slice[i] = proms_list[i].shape[0]
                sequence_list[i], proms_list[i] = proms_list[i][:, 0], sequence_list[i]
            elif prefix_silence > 0:
                sequence_list[i] = get_silence(prefix_silence, device=sequence_list[i].device)
                sequence_list[i] = sequence_list[i][:, 0]
                # start_slice[i] = sequence_list[i].shape[0]

        null_text = [torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size)]
        null_prom = [None for _ in range(batch_size)]
        # get next in sequence
        for n in trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
            # it would technically be faster to just append the new token's embedding to the inputs, but the performance gain is VERY small, so it's not worth it
            text_list = [sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list)]
            resps_list = [sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list)]

            quant_levels = [0 for _ in range(max(batch_size, beam_width))]

            inputs = self.inputs(
                text_list=text_list,
                proms_list=proms_list,
                resps_list=resps_list,
                lang_list=lang_list,
                tone_list=tone_list,
                len_list=len_list,
                task_list=task_list,
                quant_levels=quant_levels,
            )

            # to-do: find an elegant way to write this
            output = super().forward(
                inputs=inputs,
                state=state,
                #layer_skip_variables=sampling_layer_skip_variables,
                output_attentions=entropix_sampling,
            )

            if cfg_strength > 0:
                null_inputs = super().inputs(
                    text_list=null_text,
                    proms_list=null_prom,
                    resps_list=resps_list,
                    lang_list=lang_list,
                    tone_list=tone_list,
                    quant_levels=quant_levels,
                )
                null_output = super().forward(
                    inputs=null_inputs,
                    quant_levels=quant_levels,
                    #layer_skip_variables=sampling_layer_skip_variables,
                )
                for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
                    logit[-seq_len:] = null_logit[-seq_len:] + (logit[-seq_len:] - null_logit[-seq_len:]) * cfg_strength

            logits, state = output.logits, output.state

            sampled = super().sample(
                logits=logits,
                prev_list=[resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate(task_list)],
                **(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
            )
            r = sampled[0]

            if cfg.experimental:
                if sampled.entropy:
                    metrics.append(sampled.entropy)
                elif sampled.scores:
                    #metrics.append( [ { "p": p[0], "exited_layer": output.exited_layer } for p in sampled.scores ] )
                    metrics.append([{"p": p[0]} for p in sampled.scores])

            if mirostat is not None:
                mirostat = sampled.scores
            elif beam_width > 0:
                # expand tuple
                s = sampled.scores
                # first step, expand batch
                if batch_size == 1:
                    batch_size = beam_width
                    text_list = text_list * beam_width
                    proms_list = proms_list * beam_width
                    sequence_list = sequence_list * beam_width
                    task_list = task_list * beam_width
                    start_slice = start_slice * beam_width
                    stopped = torch.zeros(batch_size, device=device).bool()

                scores = [scores[i] + score for i, score in enumerate(s)]

            # append tokens
            for i, ri in enumerate(r):
                task = task_list[i]
                stop_token = audio_stop_token if task not in text_task else text_stop_token
                if stop_token in ri:
                    stopped[i] = True
                sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])

            # stop token found
            # stopped |= r == stop_token
            if stopped.all().item():
                break

        # to-do for layerskip / speculative sampling: rerun the last sequence again at max depth
        """
        if metrics:
            from ..plot import plot_sample_metrics
            filename = "metrics"
            if entropix_sampling:
                filename += f'[entropix_sampling]'
            if sampling_layer_skip_exit_layer >= 0:
                filename += f'[{sampling_layer_skip_exit_layer + 1}]'
            plot_sample_metrics( metrics, filename=f'{filename}.png' )
        """
        # pick the best scoring candidate
        # desu this is always going to be candidate 0
        if beam_width:
            sequence_list = sequence_list[:1]
            task_list = task_list[:1]

        # remove stop token
        sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
        # remove <bos>
        sequence_list = [sequence_list[i][start_slice[i]:] for i, task in enumerate(task_list)]

        if refine_on_stop:
            # get how much we need to slice from the end
            slice_lengths = [sequence.shape[-1] for sequence in sequence_list]
            # -1 for the stop token
            logits = [logit[-length - 1:-1] for logit, length in zip(logits, slice_lengths)]
            # greedy sample from the sequence
            refined_list = [logit.argmax(dim=-1) for logit in logits]
            # to-do: compare scores
            # set the "refined" list as the output
            sequence_list = refined_list

        return sequence_list

    def forward(
        self,
        text_list: list[Tensor],
        proms_list: list[Tensor],
        resps_list: list[Tensor] | None = None,
        task_list: list[Tensor] | None = None,
        lang_list: list[Tensor] | None = None,
        tone_list: list[Tensor] | None = None,
        len_list: list[Tensor] | None = None,
        training: bool | int | None = None,
        disable_tqdm=False,
        use_lora=None,
        **sampling_kwargs,
    ):
        # deduce batch_size
        if text_list is not None:
            default_task = "tts"
            device = text_list[0].device
            batch_size = len(text_list)
        else:
            default_task = "stt"
            device = resps_list[0].device
            batch_size = len(resps_list)

        # generate task list if not provided
        if task_list is None:
            task_list = [default_task for _ in range(batch_size)]

        # implicitly set for training
        if training is None and text_list is not None and resps_list is not None:
            n_levels_set = {r.shape[-1] for r in resps_list}
            n_levels = next(iter(n_levels_set))
            training = n_levels == self.n_resp_levels

        # is training
        if training:
            return self.forward_train(
                text_list=text_list,
                proms_list=proms_list,
                resps_list=resps_list,
                task_list=task_list,
                lang_list=lang_list,
                tone_list=tone_list,
                len_list=len_list,
            )

        # is NAR
        if (len_list is not None or resps_list is not None) and text_list is not None:
            return self.forward_nar(
                text_list=text_list,
                proms_list=proms_list,
                resps_list=resps_list,
                task_list=task_list,
                lang_list=lang_list,
                tone_list=tone_list,
                len_list=len_list,
                **sampling_kwargs,
            )

        # is AR
        return self.forward_ar(
            text_list=text_list,
            proms_list=proms_list,
            resps_list=resps_list,
            task_list=task_list,
            lang_list=lang_list,
            tone_list=tone_list,
            len_list=len_list,
            **sampling_kwargs,
        )

def example_usage():
    cfg.device = "cuda"
    cfg.trainer.backend = "local"
    if cfg.audio_backend == "dac":
        cfg.sample_rate = 44_100

    from functools import partial
    from einops import repeat
    from tqdm import tqdm

    from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
    from ..engines import Engine, Engines
    from ..utils import wrapper as ml
    from ..utils import setup_logging

    import numpy as np
    import re

    setup_logging()

    def load_artifact(path):
        artifact = np.load(path, allow_pickle=True)[()]

        text = torch.tensor(cfg.tokenizer.encode(artifact["metadata"]["phonemes"])).to(dtype=torch.uint8, device=cfg.device)
        audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=cfg.device)

        return text, audio

    text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
    batch_size = cfg.hyperparameters.batch_size
    cfg.model.experimental.masking_train_p = 0.5

    text_list = [text] * batch_size
    proms_list = [audio[:cfg.dataset.frames_per_second, :]] * batch_size
    resps_list = [audio] * batch_size

    kwargs = {
        'n_text_tokens': 256,
        'n_audio_tokens': 1024,

        'd_model': 1024,  # 256, # 1024, # 1536
        'n_heads': 16,  # 4, # 16, # 24
        'n_layers': 12,  # 32
        'n_experts': 1 if not cfg.model else cfg.model.experts,

        'p_dropout': 0.1,

        'l_padding': 8 if cfg.optimizations.fp8 else 0,

        'config': cfg.model
    }
    bos_id, space_id, eos_id = cfg.tokenizer.encode(" ")

    available_tasks = ["tts-ar", "tts-nar"]

    model = AR_NAR(**kwargs).to(cfg.device)
    steps = 500 // batch_size

    optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
    scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
    learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None
    if cfg.optimizations.dadaptation:
        # do not combine the two
        if scheduler == "schedulefree":
            scheduler = ""

        learning_rate = 1.0

    if optimizer == "prodigy":
        if learning_rate is None:
            learning_rate = 1.0
        optimizer = ml.Prodigy
    elif optimizer == "adagrad":
        if learning_rate is None:
            learning_rate = 1.0e-2
        optimizer = ml.Adagrad
    elif optimizer == "adamw":
        if learning_rate is None:
            learning_rate = 1.0e-4
        optimizer = ml.AdamW
    elif optimizer == "sgd":
        if learning_rate is None:
            learning_rate = 1.0e-4
        optimizer = ml.SGD
    else:
        raise ValueError(f"Unrecognized optimizer: {optimizer}")

    _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")

    optimizer = optimizer(model.parameters(), lr=learning_rate)

    if scheduler == "schedulefree":
        if isinstance(optimizer, ml.AdamW):
            scheduler = ml.schedulefree.AdamWScheduleFree
        elif isinstance(optimizer, ml.SGD):
            scheduler = ml.schedulefree.SGDScheduleFree
        else:
            scheduler = None

        if scheduler is not None:
            _logger.info(f"Scheduler: {scheduler}")
            optimizer = scheduler(model.parameters(), lr=learning_rate)

    if cfg.optimizations.replace and cfg.optimizations.linear:
        model = ml.replace_linear(model)

    if cfg.optimizations.replace and cfg.optimizations.embedding:
        model = ml.replace_embedding(model)
    """
    cfg.optimizations.model_offloading = {
        "devices": ["cuda:0", "cpu"],
        # "limits": [ 0.9, -1 ],
        "assign": [[f'layers.{i}.' for i in range(0, 10)], [f'layers.{i}.' for i in range(11, 12)] + ["model.norm"]],
        # "limits": [ 256 * (1024 ** 2), -1 ]
    }
    """

    engine = Engine(model=model, optimizer=optimizer)
    engines = Engines({"ar+nar": engine})
    engines.setup()

    """
    if cfg.optimizations.model_offloading:
        model = ml.offload_model(model, policy=cfg.optimizations.model_offloading)
    """

    """
    torch.save({
        'module': model.state_dict()
    }, f"./data/{cfg.model.arch_type}.pth")
    """

    _logger.info(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    @torch.no_grad()
    def sample_data(t=None):
        if isinstance(t, list):
            tasks = t
            texts = [text_list[0].to(cfg.device) if task not in text_task else None for i, task in enumerate(tasks)]
            proms = [proms_list[0].to(cfg.device) if task not in text_task else ["stt"] for i, task in enumerate(tasks)]
            resps = [None if task not in text_task else resps_list[0].to(cfg.device) for i, task in enumerate(tasks)]

            return texts, proms, resps, tasks

        texts = []
        proms = []
        resps = []
        tasks = []

        for i in range(batch_size):
            task = random.choice(available_tasks) if t is None else t

            text = text_list[i].to(cfg.device)
            prom = proms_list[i].to(cfg.device)
            resp = resps_list[i].to(cfg.device)

            # do nothing
            if task == "stt":
                prom = [task]
            else:
                task = "tts" if random.random() > 0.1 else "len"

            texts.append(text)
            proms.append(prom)
            resps.append(resp)
            tasks.append(task)

        return texts, proms, resps, tasks
    @torch.inference_mode()
    def sample(name, steps=500, task=None):
        engine.eval()

        text_list, proms_list, resp_list, task_list = sample_data(task)

        if task == "tts-nar":
            len_list = engine(text_list, proms_list, task_list=["len"], max_steps=5, temperature=0.0)
            # cheat: override the predicted duration with the ground-truth length
            len_list = [resp_list[0].shape[0] for l in len_list]
            resps_list = engine(text_list, proms_list, len_list=len_list, temperature=0.0)
        else:
            resps_list = engine(text_list, proms_list, task_list=["tts"], max_duration=steps, temperature=1.0)
            resps_list = engine(text_list, proms_list, resps_list=resps_list, temperature=0.0)

        for i, o in enumerate(resps_list):
            _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.{task}.wav", device=cfg.device)

        unload_model()
    def train():
        engine.train()
        t = trange(steps)
        for i in t:
            texts, proms, resps, tasks = sample_data()

            stats = {"step": i}
            stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
            stats |= {"grad_norm": engine.get_global_grad_norm()}

            tqdm.write(f"{stats}")

        """
        torch.save({
            'module': model.state_dict()
        }, f"./data/{cfg.model.arch_type}.pth")
        """

    #sample("init", 5)

    train()

    """
    if cfg.optimizations.compile:
        model = ml.compile_model(model, backend=cfg.optimizations.compile)
    """

    for task in available_tasks:
        sample("final", task=task)

    engines.quit()

if __name__ == "__main__":
    example_usage()