2023-08-02 21:53:35 +00:00
# todo: clean this mess up
from . config import cfg
2024-10-10 18:40:25 +00:00
from . data import create_train_val_dataloader , get_random_prompt , tokenize
from . emb import qnt , g2p
2023-08-02 21:53:35 +00:00
from . utils import setup_logging , to_device , trainer , flatten_dict , do_gc
2024-06-04 02:28:49 +00:00
from . data import fold_inputs , unfold_outputs
2024-06-09 16:39:43 +00:00
from . utils . distributed import is_global_leader
2023-08-02 21:53:35 +00:00
import auraloss
import json
import logging
import random
import torch
import torch . nn . functional as F
import traceback
2024-06-09 16:22:52 +00:00
import shutil
2023-08-02 21:53:35 +00:00
from collections import defaultdict
from tqdm import tqdm
2024-05-25 22:39:51 +00:00
import argparse
2023-08-02 21:53:35 +00:00
# module-level logger for this file
_logger = logging.getLogger(name=__name__)

# mel-STFT loss instance shared by run_eval as a naive hypothesis-vs-reference audio metric;
# constructed on the CPU with the configured sample rate
mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
2024-12-06 05:05:52 +00:00
def train_feeder ( engine , batch , teacher = None ) :
2025-01-06 01:05:00 +00:00
engine . tokens_processed + = sum ( [ text . shape [ 0 ] for text in batch [ " text " ] ] )
engine . tokens_processed + = sum ( [ resps . shape [ 0 ] for resps in batch [ " resps " ] ] )
2025-01-06 05:53:17 +00:00
2023-09-02 17:23:40 +00:00
with torch . autocast ( " cuda " , dtype = cfg . trainer . dtype , enabled = cfg . trainer . amp ) :
2024-06-29 03:39:05 +00:00
batch_size = len ( batch [ " text " ] )
engine . current_batch_size = batch_size
2024-06-29 14:11:28 +00:00
2024-12-07 05:53:46 +00:00
output = engine (
2024-07-27 20:36:05 +00:00
text_list = batch [ " text " ] ,
proms_list = batch [ " proms " ] ,
resps_list = batch [ " resps " ] ,
lang_list = batch [ " lang " ] ,
tone_list = batch [ " tone " ] ,
task_list = batch [ " task " ] ,
2025-01-05 18:47:03 +00:00
raw_text_list = batch [ " raw_text " ] ,
2024-07-27 20:36:05 +00:00
training = True ,
)
2023-08-04 01:26:36 +00:00
2024-12-07 05:53:46 +00:00
# get soft targets from teacher
if teacher is not None :
# extract inputs forwarded to model
inputs = output . inputs
# grab the teacher's logits
with torch . no_grad ( ) :
teacher_output = teacher . forward_super (
inputs = inputs ,
)
# KD hyperparameters
T = cfg . hyperparameters . teacher_temperature
A = cfg . hyperparameters . teacher_alpha
L = cfg . hyperparameters . teacher_loss_fn
# determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
# we could recreate the target sequence with the ignore indices put in, but that's agony
2024-12-07 15:52:51 +00:00
student_logits = [ logit / T for logit in output . logits ]
teacher_logits = [ logit / T for logit in teacher_output . logits ]
2024-12-07 18:31:54 +00:00
if engine . module . ignore_inputs_for_loss :
2024-12-07 05:53:46 +00:00
task_outputs = {
" tts " : " resp " ,
" stt " : " text " ,
" len " : " len " ,
}
output_lens = [ 0 for _ in range ( batch_size ) ]
for batch_index , _batch in enumerate ( inputs ) :
task_type = " tts "
for name , input in _batch :
if name == " task " :
task_type = input
for name , input in _batch :
if name == task_outputs . get ( task_type , name ) :
output_lens [ batch_index ] = input . shape [ 0 ]
# create probability distributions (literature says to have the students already log'd but not the teacher)
2024-12-07 15:52:51 +00:00
student_logits = [ logit [ - l : ] for logit , l in zip ( student_logits , output_lens ) ]
teacher_logits = [ logit [ - l : ] for logit , l in zip ( teacher_logits , output_lens ) ]
2024-12-07 05:53:46 +00:00
if L == " kl " :
2024-12-07 18:31:54 +00:00
student_probs = [ F . log_softmax ( logit , dim = - 1 ) for logit in student_logits ]
teacher_probs = [ F . log_softmax ( logit , dim = - 1 ) for logit in teacher_logits ]
2024-12-07 15:52:51 +00:00
2024-12-07 18:39:01 +00:00
soft_losses = [ F . kl_div ( student , teacher , reduction = ' batchmean ' , log_target = True ) for student , teacher in zip ( student_probs , teacher_probs ) ]
2024-12-07 18:31:54 +00:00
elif L == " mse " :
soft_losses = [ F . mse_loss ( student , teacher ) for student , teacher in zip ( student_logits , teacher_logits ) ]
2024-12-07 05:53:46 +00:00
for k in engine . module . loss . keys ( ) :
engine . module . loss [ k ] * = ( 1.0 - A )
2024-12-07 18:31:54 +00:00
engine . module . loss [ L ] = torch . stack ( soft_losses ) . sum ( ) * A * ( T * * 2 ) / batch_size
2024-12-07 05:53:46 +00:00
2023-09-02 01:58:29 +00:00
losses = engine . gather_attribute ( " loss " )
stat = engine . gather_attribute ( " stats " )
2023-08-04 01:26:36 +00:00
2023-09-02 01:58:29 +00:00
loss = torch . stack ( [ * losses . values ( ) ] ) . sum ( )
2023-08-04 01:26:36 +00:00
stats = { }
stats | = { k : v . item ( ) for k , v in losses . items ( ) }
2023-08-05 20:25:41 +00:00
stats | = { k : v . item ( ) for k , v in stat . items ( ) }
2023-08-04 01:26:36 +00:00
return loss , stats
@torch.inference_mode()
def run_eval(engines, eval_name, dl, args=None):
    """Generate audio for evaluation samples with every engine and log the resulting metrics.

    Batches are drawn until ``cfg.evaluation.size`` samples have been processed. For
    ``eval_name == "subtrain"`` samples are picked at random directly from ``dl.dataset``;
    otherwise ``next(iter(dl))`` is used. Non-TTS tasks are coerced to TTS. Hypothesis,
    reference and prompt audio are decoded to ``.wav`` files under
    ``cfg.log_dir / <global_step>``, and a naive mel-STFT loss between hypothesis and
    reference is averaged and logged. Returns nothing.

    Args:
        engines: engine collection, iterated by name and exposing ``global_step``.
        eval_name: label for this run; ``"subtrain"`` selects random dataset sampling.
        dl: dataloader to evaluate from.
        args: optional parsed CLI arguments (reads ``eval_random_text_prompts``).
    """
    stats = defaultdict(list)
    stats['loss'] = []

    if cfg.evaluation.size == 0:
        return

    def process(name, batch, resps_list):
        """Decode hyp/ref/prom audio to disk and record the mel-STFT loss for each sample."""
        for speaker, path, ref, hyp, prom, task in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"], batch["task"]):
            if len(hyp) == 0:
                continue

            filename = f'{speaker}_{path.parts[-1]}'

            if task != "tts":
                filename = f"{filename}_{task}"

            # flatten prom
            if not isinstance(prom, torch.Tensor) and prom is not None:
                prom = torch.concat([p for p in prom if isinstance(p, torch.Tensor)])

            # to-do, refine the output dir to be sane-er
            ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
            hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
            prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")

            hyp_path.parent.mkdir(parents=True, exist_ok=True)
            ref_path.parent.mkdir(parents=True, exist_ok=True)
            prom_path.parent.mkdir(parents=True, exist_ok=True)

            hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)

            if ref is not None:
                ref_audio, sr = qnt.decode_to_file(ref, ref_path)

            if prom is not None:
                prom_audio, sr = qnt.decode_to_file(prom, prom_path)

            # naive loss calculation
            # to-do: find a better way to calculate this / a better metric
            if ref is not None:
                min_length = min(ref_audio.shape[-1], hyp_audio.shape[-1])
                ref_audio = ref_audio[..., 0:min_length]
                hyp_audio = hyp_audio[..., 0:min_length]
                stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())

    processed = 0
    while processed < cfg.evaluation.size:
        # directly randomly sample
        if eval_name == "subtrain":
            # sample from dataset; randrange excludes len(dataset), so no out-of-range index
            # to-do: derive from current iteration
            samples = [
                to_device(dl.dataset[random.randrange(len(dl.dataset))], cfg.device)
                for _ in range(cfg.evaluation.batch_size)
            ]
            # collate manually
            batch = {k: [s[k] for s in samples] for k in samples[0]}
        else:
            # NOTE(review): a fresh iterator every pass — unless `dl` shuffles, this
            # re-evaluates the same first batch each time
            batch = to_device(next(iter(dl)), cfg.device)

        # limit to eval batch size in the event we somehow have a weird dataloader
        for key in batch.keys():
            batch[key] = batch[key][:cfg.evaluation.batch_size]

        batch_size = len(batch["text"])

        # to-do: eval for text tasks
        has_stt = False
        for i, task in enumerate(batch["task"]):
            # easier to just change it to a tts task than drop stt tasks from the batch
            if task == "stt":
                # has_stt = True
                batch["task"][i] = "tts"
                # use the first 75 * 3 frames of the response as the prompt
                batch["proms"][i] = batch["resps"][i][:75 * 3, :]
            elif task != "tts":
                batch["task"][i] = "tts"

        # random prompts requested
        if args and args.eval_random_text_prompts and eval_name == "subtrain":
            for i, _ in enumerate(batch["text"]):
                batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
                batch["resps"][i] = None

        processed += batch_size
        for name in engines:
            engine = engines[name]

            base_kwargs = dict(
                text_list=batch["text"],
                proms_list=batch["proms"],
                lang_list=batch["lang"],
                task_list=batch["task"],
                training=False,
            )

            if engine.hyper_config.experimental.hf:
                resps_list = engine(**base_kwargs)
            elif "len" in engine.hyper_config.capabilities:
                kwargs = base_kwargs | cfg.evaluation.kwargs
                max_steps = kwargs.pop("max_steps", 500)

                if "denoise_start" in kwargs:
                    len_list = [resp.shape[0] for resp in batch["resps"]]
                    kwargs["resps_list"] = [resp[:, :1] for resp in batch["resps"]]
                else:
                    len_list = engine(max_steps=5, **kwargs)
                    len_list = [min(l, max_steps) for l in len_list]

                # NOTE(review): rebuilding kwargs restores any configured "max_steps" but also
                # discards the "resps_list" set in the denoise branch above — confirm intended
                kwargs = base_kwargs | cfg.evaluation.kwargs
                resps_list = engine(**kwargs, len_list=len_list)
            else:
                if "ar" in engine.hyper_config.capabilities:
                    kwargs = base_kwargs | cfg.evaluation.kwargs
                    resps_list = engine(**kwargs)
                else:
                    resps_list = [resp[:, 0] for resp in batch["resps"]]

                if "nar" in engine.hyper_config.capabilities:
                    kwargs = base_kwargs | cfg.evaluation.kwargs
                    resps_list = engine(**kwargs, resps_list=resps_list)

            process(name, batch, resps_list)

            # evaluate why it's so slow
            if has_stt:
                max_steps = max([text.shape[0] for text in batch["text"]])

                kwargs["text_list"] = None
                kwargs["task_list"] = ["stt" for _ in range(batch_size)]
                kwargs["proms_list"] = [["stt"] for _ in range(batch_size)]
                kwargs["resps_list"] = batch["resps"]

                text_list = engine(**kwargs, max_steps=max_steps, sampling_temperature=0.0)
                text_list = [cfg.tokenizer.decode(text) for i, text in enumerate(text_list)]

                _logger.info(f"Validation Metrics (STT): {text_list}")

    # average every non-empty metric
    stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
    engines_stats = {
        f'{name}.{eval_name}': stats,
        "it": engines.global_step,
    }
    #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)

    _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
2023-08-04 01:26:36 +00:00
2023-08-02 21:53:35 +00:00
2023-10-21 14:55:38 +00:00
def train():
    """Script entry point: parse CLI flags, prepare logging and data, then evaluate or train.

    Flags (unrecognized arguments are tolerated and ignored here):
      --eval                      load engines, run one evaluation pass, and return
      --eval-random-text-prompts  during "subtrain" evaluation, swap texts for random prompts
    """
    cli = argparse.ArgumentParser("VALL-E TTS")
    cli.add_argument("--eval", action="store_true", default=None)
    cli.add_argument("--eval-random-text-prompts", action="store_true", default=None)
    # cli.add_argument("--eval-random-audio-prompts", action="store_true", default=None)
    opts, _ = cli.parse_known_args()

    # create log folder
    setup_logging(cfg.log_dir)

    # back up the config yaml next to the logs (global leader only)
    if cfg.yaml_path is not None and is_global_leader():
        shutil.copy(cfg.yaml_path, cfg.log_dir / "config.yaml")

    train_dl, val_dl = create_train_val_dataloader()

    def evaluate(engines):
        """Evaluate on "subtrain" then "val", restore training mode, and free memory."""
        do_gc()
        engines.eval()

        # best-effort: evaluation is prone to breaking, so failures are only logged
        try:
            for split_name, loader in (("subtrain", train_dl), ("val", val_dl)):
                run_eval(engines, split_name, loader, opts)
        except Exception as err:
            _logger.warning(f"Error occurred while performing eval: {str(err)}")
            _logger.warning(traceback.format_exc())

        engines.train()
        qnt.unload_model()
        do_gc()

    # unload EnCodec if it's already loaded
    qnt.unload_model()

    # evaluation-only mode
    if opts.eval:
        return evaluate(engines=trainer.load_engines())

    # start web UI (disabled):
    # if cfg.trainer.load_webui:
    #     from .webui import start
    #     start(lock=False)

    # pre-training config validation
    layerskip_in_fp16 = cfg.model.experimental.layerskip and cfg.trainer.weight_dtype == "float16"
    if layerskip_in_fp16:
        _logger.warning("Training with LayerSkip enabled with float16 may result in frying the model if the loss scale gets too small (<=8K) or with too large of a de facto batch size (>512 samples).")

    trainer.train(
        train_dl=train_dl,
        train_feeder=train_feeder,
        eval_fn=evaluate,
    )
# script entry point: runs train(), which trains or, with --eval, only evaluates
if __name__ == "__main__":
    # to-do: for DDP, spawn multiprocess instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"`
    train()