"""
# an AR model that (should) handle:
* all RVQ levels, but in an autoregressive manner

It's in a mess of a state, because I want this to be an interleaved model, but it just seems better to use the vall_e.models.experimental model.
"""
from .base import Base, list_to_tensor, Categorical
from ..config import cfg

import torch
from torch.nn.utils.rnn import pad_sequence

import random
import math
from einops import rearrange
from torch import Tensor
from tqdm import trange

import logging

_logger = logging.getLogger(__name__)

# get_silence is assumed to live alongside the other quantizer helpers
from ..emb.qnt import trim, encode_as_embedding, get_silence
from .lora import enable_lora

def clamp(n, lo, hi):
    return max(lo, min(n, hi))
class AR(Base):
    def forward(
        self,

        text_list: list[Tensor],
        proms_list: list[Tensor],
        resps_list: list[Tensor] | None = None,

        task_list: list[Tensor] | None = None,
        lang_list: list[Tensor] | None = None,
        tone_list: list[Tensor] | None = None,
        len_list: list[Tensor] | None = None,

        training: bool | int | None = None,

        max_steps: int = 1000,
        max_levels: int = 0,

        input_prompt_prefix: bool = False,
        prefix_silence: float = 1.0,

        sampling_temperature: float = 1.0,
        sampling_min_temperature: float = -1.0,
        sampling_top_k: int = -100,
        sampling_top_p: float = 1.0,
        sampling_min_p: float = 0.0,
        sampling_repetition_penalty: float = 1.0,
        sampling_repetition_penalty_decay: float = 0.0,
        sampling_length_penalty: float = 0.0,
        sampling_beam_width: int = 0,
        sampling_mirostat_tau: float = 0.0,
        sampling_mirostat_eta: float = 0.1,
        sampling_dry_multiplier=0.0,
        sampling_dry_base=1.75,
        sampling_dry_allowed_length=2,
        sampling_entropix=False,

        sampling_layer_skip: bool = False,
        sampling_layer_skip_exit_layer: int = -1,
        sampling_layer_skip_entropy_threshold: float = -1,
        sampling_layer_skip_varentropy_threshold: float = -1,

        sampling_refine_on_stop: bool = False,

        disable_tqdm=False,
        use_lora=None,
    ):
        text_task = ["stt"]

        if text_list is not None:
            default_task = "tts"
            device = text_list[0].device
            batch_size = len(text_list)
        else:
            default_task = "stt"
            device = resps_list[0].device
            batch_size = len(resps_list)

        # generate task list if not provided
        if task_list is None:
            task_list = [default_task for _ in range(batch_size)]

        has_none = resps_list is None or text_list is None
        if not has_none:
            for i, task in enumerate(task_list):
                if resps_list[i] is None or text_list[i] is None:
                    has_none = True
                    break
        # is training or NAR
        if not has_none:
            n_levels_set = {r.shape[-1] for r in resps_list}
            n_levels = next(iter(n_levels_set))

            # implicit
            if training is None:
                training = 0 if n_levels == self.n_resp_levels else None

        # is training
        if training is not None:
            # specifies how to sample probabilities of which RVQ levels to train against
            rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
            # determines which RVQ level to target per batch
            quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [0 if self.causal else 1, self.n_resp_levels - 1]
            # rate to perform token dropout errors
            token_dropout_error = self.config.experimental.token_dropout_error
            # RVQ levels to apply token dropout on
            token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
            # implicitly set it to all levels
            if not token_dropout_rvq_levels:
                token_dropout_rvq_levels = [0, self.resp_levels - 1]

            # allow passing a specific distribution of RVQ levels
            # (keep the original string around, since rvq_levels_p gets replaced with a list below)
            rvq_levels_mode = rvq_levels_p if isinstance(rvq_levels_p, str) else ""
            rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else []
            if not rvq_levels_p:
                lo, hi = quant_level_range[0], quant_level_range[1] + 1
                # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
                if rvq_levels_mode == "equal":
                    rvq_levels_p = [i for i in range(lo, hi)]
                else:
                    # yuck
                    rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range(lo, hi)], [])
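                    # e.g. lo=0, hi=8 yields [0]*8 + [1]*7 + ... + [7]*1, so the
                    # AR level (0) is sampled most often and the highest NAR
                    # level the least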
            # input RVQ levels
            quant_levels = [random.choice(rvq_levels_p) for i in range(batch_size)]
            for i, task in enumerate(task_list):
                if task in text_task:
                    quant_levels[i] = 0  # self.n_resp_levels - 1

            # trim resps to only contain all levels below the target level
            resps_list = [r if t in text_task else r[..., :l + 1] for r, l, t in zip(resps_list, quant_levels, task_list)]

            # tensor to cat for RVQ level 0
            text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16)
            audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)

            # I hate python's value/reference semantics so much
            for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
                # cap quant_level if it exceeds its corresponding resp/prom
                if quant_level >= resps.shape[-1]:
                    quant_levels[i] = resps.shape[-1] - 1

                # proms could be a Tensor, list[Tensor], or None
                if isinstance(proms, torch.Tensor):
                    if quant_level >= proms.shape[-1]:
                        quant_levels[i] = proms.shape[-1] - 1
                elif isinstance(proms, list):
                    for j, prom in enumerate(proms):
                        if not isinstance(prom, torch.Tensor):
                            continue
                        if quant_level >= prom.shape[-1]:
                            quant_levels[i] = prom.shape[-1] - 1

                # apply token dropout error compensation
                if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
                    steps = resps.shape[0]
                    for l in range(quant_level):
                        for t in range(steps):
                            token = resps[t, l].item()
                            if random.random() < token_dropout_error:
                                offset = 1 * (1 if random.random() < 0.5 else -1)
                                resps_list[i][t, l] = clamp(token + offset, 1, 1022)  # +- 1
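                # the intent here (an inference from the comments above) is to
                # simulate small quantization errors in the lower levels, so the
                # model learns to be robust to imperfect prior-level codes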
                # only apply stop token for RVQ level 0
                if quant_level <= 0:
                    # append stop tokens for AR
                    if task in text_task:
                        #text_list[i] = torch.cat([ resps, text_stop_sequence ])
                        ...
                    else:
                        resps_list[i] = torch.cat([resps, audio_stop_sequence])

            inputs = self.inputs(
                text_list=text_list,
                proms_list=proms_list,
                resps_list=resps_list,
                lang_list=lang_list,
                tone_list=tone_list,
                task_list=task_list,
                quant_levels=quant_levels,
            )

            return super().forward(
                inputs=inputs,
                quant_levels=quant_levels,  # could technically just grab this from the above inputs since they're included as an RVQ level token
            )

        # is AR
        if cfg.lora is not None:
            enable_lora(self, cfg.lora.active_level(0) if use_lora is None else use_lora)

        # STT
        start_slice = [0 for _ in range(batch_size)]

        sequence_list = [torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size)]
        stopped = torch.zeros(batch_size, device=device).bool()

        audio_stop_token = self.stop_token
        text_stop_token = 2

        state = None
        mirostat = [
            {"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
        ] * batch_size if sampling_mirostat_tau > 0.0 else None
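        # per-sequence mirostat state; "tau" is the target surprise and "eta"
        # the learning rate, with the remaining fields presumably being running
        # statistics the sampler updates each step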
        scores = [1.0] * sampling_beam_width
        metrics = []

        # ick
        """
        low_temperature = False # sampling_temperature < 0.6 # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
        low_temperature_range = cfg.dataset.frames_per_second * 5

        original_sampling_temperature = sampling_temperature
        original_sampling_repetition_penalty = sampling_repetition_penalty
        original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
        """

        sampling_layer_skip_variables = {} if sampling_layer_skip else None

        if sampling_layer_skip:
            if sampling_layer_skip_entropy_threshold >= 0:
                sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
            if sampling_layer_skip_varentropy_threshold >= 0:
                sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
            if sampling_layer_skip_exit_layer >= 0:
                sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
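        # e.g. with illustrative values, this might end up as
        # { "entropy_threshold": 0.1, "varentropy_threshold": 0.1, "max_layer": 11 }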
        for i, sequence in enumerate(sequence_list):
            # add <bos> to text for STT
            if task_list[i] in text_task:
                start_slice[i] = 1
                sequence_list[i] = torch.cat([sequence_list[i], torch.tensor([1], dtype=torch.int16, device=device)])
            # treat input prompt as initial resp (by prefixing with the prompt instead)
            elif input_prompt_prefix:
                start_slice[i] = proms_list[i].shape[0]
                sequence_list[i], proms_list[i] = proms_list[i][:, 0], sequence_list[i]
            elif prefix_silence > 0:
                sequence_list[i] = get_silence(prefix_silence, device=sequence_list[i].device)
                sequence_list[i] = sequence_list[i][:, 0]
                # start_slice[i] = sequence_list[i].shape[0]
        # get next in sequence
        for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
            # it would technically be faster to just append the new token's embedding to the inputs, but the performance gain from doing so is VERY small, so it's not worth it
            text_list = [sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list)]
            resps_list = [sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list)]

            # greedy sampling in the AR *does* work, but requires some quasi-exotic sampling to keep the initial burst of garbage from polluting the rest of the sequence
            # naturally, rep pen wrangles this initial burst of noise, but naively relying on rep pen is no good, as it fails after ~6 seconds of audio
            # however, switching to a default sampling temperature with "clean greedy sampled codes" will make the rest of the sequence sound as if it were greedy sampled
            # to-do: tune these values, maybe have it factor based on confidence scores or something
            """
            if low_temperature:
                enabled = n < low_temperature_range
                sampling_repetition_penalty = 1.125 if enabled else 1.25
                #sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
                #sampling_temperature = original_sampling_temperature if enabled else 1.0
            """

            inputs = self.inputs(
                text_list=text_list,
                proms_list=proms_list,
                resps_list=resps_list,
                lang_list=lang_list,
                tone_list=tone_list,
                len_list=len_list,
                task_list=task_list,
                # the AR only ever targets RVQ level 0
                quant_levels=[0 for _ in range(max(batch_size, sampling_beam_width))]
            )

            # to-do: find an elegant way to write this
            output = super().forward(
                inputs=inputs,
                state=state,
                layer_skip_variables=sampling_layer_skip_variables,
                output_attentions=sampling_entropix,
            )
            logits, state = output.logits, output.state

            sampled = super().sample(
                logits=logits,
                prev_list=None if sampling_repetition_penalty == 1.0 and sampling_length_penalty == 0.0 else [resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate(task_list)],

                temperature=sampling_temperature,
                min_temperature=sampling_min_temperature,
                top_p=sampling_top_p,
                top_k=sampling_top_k,
                min_p=sampling_min_p,
                repetition_penalty=sampling_repetition_penalty,
                repetition_penalty_decay=sampling_repetition_penalty_decay,
                length_penalty=sampling_length_penalty,
                beam_width=sampling_beam_width,

                mirostat=mirostat,

                dry_multiplier=sampling_dry_multiplier,
                dry_base=sampling_dry_base,
                dry_allowed_length=sampling_dry_allowed_length,

                attentions=output.attentions if sampling_entropix else None,
            )
            r = sampled[0]

            if cfg.experimental:
                if sampled.entropy:
                    metrics.append(sampled.entropy)
                elif sampled.scores:
                    metrics.append([{"p": p[0], "exited_layer": output.exited_layer} for p in sampled.scores])

            if mirostat is not None:
                mirostat = sampled.scores
            elif sampling_beam_width > 0:
                # expand tuple
                s = sampled.scores
                # first step, expand batch
                if batch_size == 1:
                    batch_size = sampling_beam_width
                    text_list = text_list * sampling_beam_width
                    proms_list = proms_list * sampling_beam_width
                    sequence_list = sequence_list * sampling_beam_width
                    task_list = task_list * sampling_beam_width
                    start_slice = start_slice * sampling_beam_width
                    stopped = torch.zeros(batch_size, device=device).bool()

                scores = [scores[i] + score for i, score in enumerate(s)]
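                # running per-beam tally; sampled.scores is assumed to hold each
                # candidate's per-step score, so the best beam is the one with
                # the highest accumulated total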
            # append tokens
            for i, ri in enumerate(r):
                task = task_list[i]
                stop_token = audio_stop_token if task not in text_task else text_stop_token
                if stop_token in ri:
                    stopped[i] = True
                sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])

            # stop token found
            # stopped |= r == stop_token
            if stopped.all().item():
                break

        # to-do for layerskip / speculative sampling: rerun the last sequence again at max depth
        if metrics:
            from ..plot import plot_sample_metrics

            filename = "metrics"
            if sampling_entropix:
                filename += f'[entropix]'
            if sampling_layer_skip_exit_layer >= 0:
                filename += f'[{sampling_layer_skip_exit_layer + 1}]'

            plot_sample_metrics(metrics, filename=f'{filename}.png')

        # pick the best scoring candidate
        # desu this is always going to be candidate 0
        if sampling_beam_width:
            sequence_list = sequence_list[:1]
            task_list = task_list[:1]

        # remove stop token
        sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
        # remove <bos>
        sequence_list = [sequence_list[i][start_slice[i]:] for i, task in enumerate(task_list)]
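        # refine-on-stop: the final forward pass scored every position of the
        # finished sequence, so its logits can be greedily re-decoded into a
        # "cleaned" copy of the sequence in one shot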
        if sampling_refine_on_stop:
            # get how much we need to slice from the end
            slice_lengths = [sequence.shape[-1] for sequence in sequence_list]
            # -1 for the stop token
            logits = [logit[-length - 1:-1] for logit, length in zip(logits, slice_lengths)]
            # greedy sample from the sequence
            refined_list = [logit.argmax(dim=-1) for logit in logits]
            # to-do: compare scores
            # set the "refined" list as the output
            sequence_list = refined_list

        return sequence_list
def example_usage():
    cfg.trainer.backend = "local"
    cfg.hyperparameters.gradient_accumulation_steps = 1
    if cfg.audio_backend == "dac":
        cfg.sample_rate = 44_100

    from functools import partial
    from einops import repeat
    from tqdm import tqdm

    from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
    from ..engines import Engine, Engines
    from ..utils import wrapper as ml

    import numpy as np
    import re

    device = "cuda"

    # mamba seems to ONLY be usable as an AR (any NAR attempt lobotomizes it)
    """
    if "mamba" in cfg.model.arch_type:
        cfg.model.resp_levels = 1
    """
    # cfg.model.loss_factors = {}

    def tokenize(content):
        return torch.tensor(cfg.tokenizer.encode(content))

    def _load_quants(path) -> Tensor:
        qnt = np.load(path, allow_pickle=True)[()]
        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
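    # note: the .t() above yields codes shaped (timesteps, resp_levels), which
    # is the (t, l) layout forward() indexes resps_list entries with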
    qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
    noise = _load_quants(f"./data/noise.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")

    text_list = [
        tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
        #tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
    ]
    proms_list = [
        qnt[:cfg.dataset.frames_per_second, :].to(device),
        #qnt[:cfg.dataset.frames_per_second, :].to(device),
    ]
    resps_list = [
        qnt[:, :].to(device),
        #qnt[:cfg.dataset.frames_per_second, :].to(device),
    ]

    text_list = text_list[:1]
    proms_list = proms_list[:1]
    resps_list = resps_list[:1]

    batch_size = len(text_list)

    # retnet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
    kwargs = {
        'n_text_tokens': 256,
        'n_audio_tokens': 1024,

        'd_model': 1024,  # 256, # 1024, # 1536
        'n_heads': 16,  # 4, # 16, # 24
        'n_layers': 12,  # 32
        'n_experts': 1,

        'p_dropout': 0.1,

        'l_padding': 8 if cfg.optimizations.fp8 else 0,

        'config': cfg.model
    }

    """
    try:
        kwargs['config'] = cfg.model
    except Exception as e:
        pass
    """

    bos_id, space_id, eos_id = cfg.tokenizer.encode(" ")

    tasks = cfg.dataset.tasks_list

    model = AR(**kwargs).to(device)
    steps = 75 * len(tasks) * cfg.model.experimental.causal_size

    optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
    scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
    learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None

    if cfg.optimizations.dadaptation:
        # do not combine the two
        if scheduler == "schedulefree":
            scheduler = ""
        learning_rate = 1.0

    if optimizer == "prodigy":
        if learning_rate is None:
            learning_rate = 1.0
        optimizer = ml.Prodigy
    elif optimizer == "adagrad":
        if learning_rate is None:
            learning_rate = 1.0e-2
        optimizer = ml.Adagrad
    elif optimizer == "adamw":
        if learning_rate is None:
            learning_rate = 1.0e-4
        optimizer = ml.AdamW
    elif optimizer == "sgd":
        if learning_rate is None:
            learning_rate = 1.0e-4
        optimizer = ml.SGD
    else:
        raise ValueError(f"Unrecognized optimizer: {optimizer}")
    _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")

    optimizer = optimizer(model.parameters(), lr=learning_rate)

    if scheduler == "schedulefree":
        if isinstance(optimizer, ml.AdamW):
            scheduler = ml.schedulefree.AdamWScheduleFree
        elif isinstance(optimizer, ml.SGD):
            scheduler = ml.schedulefree.SGDScheduleFree
        else:
            scheduler = None

        if scheduler is not None:
            _logger.info(f"Scheduler: {scheduler}")
            optimizer = scheduler(model.parameters(), lr=learning_rate)

    if cfg.optimizations.replace and cfg.optimizations.linear:
        model = ml.replace_linear(model)

    if cfg.optimizations.replace and cfg.optimizations.embedding:
        model = ml.replace_embedding(model)

    """
    cfg.optimizations.model_offloading = {
        "devices": ["cuda:0", "cpu"],
        # "limits": [ 0.9, -1 ],
        "assign": [[f'layers.{i}.' for i in range(0, 10)], [f'layers.{i}.' for i in range(11, 12)] + ["model.norm"]],
        # "limits": [ 256 * (1024 ** 2), -1 ]
    }
    """

    engine = Engine(model=model, optimizer=optimizer)
    engines = Engines({"ar": engine})
    engines.setup()

    """
    if cfg.optimizations.model_offloading:
        model = ml.offload_model(model, policy=cfg.optimizations.model_offloading)
    """

    """
    torch.save({
        'module': model.state_dict()
    }, f"./data/{cfg.model.arch_type}.pth")
    """

    _logger.info(f"AR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    @torch.no_grad()
    def sample_data(task=None):
        texts = []
        proms = []
        resps = []

        for i in range(batch_size):
            if task is None:
                task = random.choice(tasks)

            text = text_list[i]
            prom = proms_list[i]
            resp = resps_list[i]

            # do nothing
            if task == "tts":
                ...
            elif task == "tts-c":
                trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)

                prom = resp[:trim_length]
                resp = resp[trim_length:]
            elif task == "ns" or task == "sr":
                # extend the noise to fill the target audio
                noise_ext = repeat_extend_audio(noise, resp.shape[0])
                # create the input prompt by merging the target audio with the noise
                prom = merge_audio(resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device)
                # set the target to just be the noise if <sr>
                if task == "sr":
                    resp = noise_ext

                # set the text prompt to empty to train without a guided text prompt
                if random.random() < 0.5:
                    text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8)

            texts.append(text.to(device))
            proms.append(prom.to(device))
            resps.append(resp.to(device))

        return texts, proms, resps
    @torch.inference_mode()
    def sample(name, steps=1000, task=None):
        engine.eval()

        texts, proms, resps = sample_data(task)

        resps = engine(texts, proms, max_steps=steps, sampling_temperature=0.95)

        for i, o in enumerate(resps):
            _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)

        unload_model()

    def train():
        engine.train()
        t = trange(steps)
        for i in t:
            texts, proms, resps = sample_data()

            stats = {"step": i}
            stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps)
            stats |= {"grad_norm": engine.get_global_grad_norm()}

            tqdm.write(f"{stats}")

        """
        torch.save({
            'module': model.state_dict()
        }, f"./data/{cfg.model.arch_type}.pth")
        """

    #sample("init", 5)

    train()

    """
    if cfg.optimizations.compile:
        model = ml.compile_model(model, backend=cfg.optimizations.compile)
    """

    for task in tasks:
        sample("final", task=task)

    engines.quit()

if __name__ == "__main__":
    example_usage()