2023-08-02 21:53:35 +00:00
import copy
import diskcache
import h5py
import json
import os
import subprocess
import sys
import time
2023-08-14 03:07:45 +00:00
import torch
2024-04-17 02:04:48 +00:00
from dataclasses import asdict , dataclass , field
2023-08-02 21:53:35 +00:00
from functools import cached_property
from pathlib import Path
from omegaconf import OmegaConf
2023-08-14 03:56:28 +00:00
from . utils . distributed import world_size
2024-04-21 19:49:18 +00:00
# Yuck
from transformers import PreTrainedTokenizerFast
2023-08-02 21:53:35 +00:00
@dataclass ( )
class _Config :
cfg_path : str | None = None
@property
def relpath ( self ) :
return Path ( self . cfg_path )
2023-08-27 00:53:23 +00:00
@property
def cache_dir ( self ) :
return self . relpath / " .cache "
2024-04-29 03:28:29 +00:00
@property
def data_dir ( self ) :
return self . relpath / " data "
@property
def metadata_dir ( self ) :
return self . relpath / " metadata "
2023-08-02 21:53:35 +00:00
@property
def ckpt_dir ( self ) :
return self . relpath / " ckpt "
@property
def log_dir ( self ) :
return self . relpath / " logs " / str ( self . start_time )
@cached_property
def start_time ( self ) :
return int ( time . time ( ) )
@cached_property
def git_commit ( self ) :
try :
cmd = " git rev-parse HEAD "
return subprocess . check_output ( cmd . split ( ) ) . decode ( " utf8 " ) . strip ( )
except :
return " "
@cached_property
def git_status ( self ) :
try :
cmd = " git status "
return subprocess . check_output ( cmd . split ( ) ) . decode ( " utf8 " ) . strip ( )
except :
return " "
def dumps ( self ) :
data = { k : getattr ( self , k ) for k in dir ( self ) if not k . startswith ( " __ " ) }
data = { k : v for k , v in data . items ( ) if not callable ( v ) }
return json . dumps ( data , indent = 2 , default = str )
def dump ( self , path = None ) :
if path is None :
path = self . log_dir / " cfg.json "
path . parent . mkdir ( parents = True , exist_ok = True )
with open ( path , " w " ) as f :
f . write ( self . dumps ( ) )
@staticmethod
def _is_cfg_argv ( s ) :
return " = " in s and " -- " not in s
@classmethod
def from_yaml ( cls , yaml_path ) :
return cls . from_cli ( [ f ' yaml= " { yaml_path } " ' ] )
@classmethod
def from_cli ( cls , args = sys . argv ) :
cli_cfg = OmegaConf . from_cli ( [ s for s in args if cls . _is_cfg_argv ( s ) ] )
# Replace argv to ensure there are no omegaconf options, for compatibility with argparse.
sys . argv = [ s for s in sys . argv if not cls . _is_cfg_argv ( s ) ]
if cli_cfg . get ( " help " ) :
print ( f " Configurable hyperparameters with their default values: " )
print ( json . dumps ( asdict ( cls ( ) ) , indent = 2 , default = str ) )
exit ( )
if " yaml " in cli_cfg :
yaml_cfg = OmegaConf . load ( cli_cfg . yaml )
yaml_path = Path ( cli_cfg . yaml ) . absolute ( )
cfg_path = Path ( * yaml_path . relative_to ( Path . cwd ( ) ) . parts [ : - 1 ] )
cfg_path = cfg_path . with_suffix ( " " )
cfg_path = f ' ./ { cfg_path } '
yaml_cfg . setdefault ( " cfg_path " , cfg_path )
cli_cfg . pop ( " yaml " )
else :
yaml_cfg = { }
merged = OmegaConf . merge ( yaml_cfg , cli_cfg )
return cls ( * * dict ( merged ) )
def __repr__ ( self ) :
return str ( self )
def __str__ ( self ) :
return self . dumps ( )
@dataclass()
class Dataset:
    """Dataset section of the configuration: sources, filtering and sampling."""

    training: list[Path] = field(default_factory=list)
    validation: list[Path] = field(default_factory=list)
    noise: list[Path] = field(default_factory=list)

    temp: list[Path] = field(default_factory=list)

    # Source text of lambdas mapping an utterance path to a speaker name / group.
    speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
    speaker_group_getter: str = "lambda p: f'{p.parts[-3]}'"

    speaker_languages: dict = field(default_factory=dict)  # dict where keys are the language codes and values are the speaker groups

    hdf5_name: str = "data.h5"
    use_hdf5: bool = False
    use_metadata: bool = False
    hdf5_flag: str = "a"
    validate: bool = True
    workers: int = 8
    cache: bool = True

    phones_range: list[int] = field(default_factory=lambda: [4, 256])
    duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0])
    prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0])
    min_utterances: int = 2

    random_utterance: float = 1.0
    max_prompts: int = 3

    prompt_duration: float = 0.0  # legacy

    max_resps: int = 1
    p_resp_append: float = 1.0

    sample_type: str = "path"  # path | speaker

    tasks_list: list[str] = field(default_factory=lambda: ["tts"])

    _frames_per_second: int = 0  # allows setting your own hint

    @cached_property
    def frames_per_second(self):
        """Codec frame rate in Hz.

        Returns ``_frames_per_second`` when it is positive; otherwise the rate
        is derived from the module-level ``cfg`` (audio backend and sample rate).
        """
        if self._frames_per_second > 0:
            return self._frames_per_second

        if cfg.audio_backend == "dac":
            # using the 44KHz model with 24KHz sources has a frame rate of 41Hz
            if cfg.variable_sample_rate and cfg.sample_rate == 24_000:
                return 41
            # accept both the rounded 44_000 and the exact 44_100 Hz spelling
            if cfg.sample_rate in (44_000, 44_100):
                return 86
            if cfg.sample_rate == 16_000:
                return 50

        # 24Khz Encodec / Vocos and incidentally DAC are all at 75Hz
        return 75

    @property
    def min_phones(self):
        return self.phones_range[0]

    @property
    def max_phones(self):
        return self.phones_range[1]

    @property
    def min_duration(self):
        return self.duration_range[0]

    @property
    def max_duration(self):
        return self.duration_range[1]
2024-06-04 02:28:49 +00:00
# I really need to clean this up
2023-08-02 21:53:35 +00:00
@dataclass ( )
class Model :
2024-04-16 00:54:32 +00:00
_max_levels : int = 0
_embeddings : str | None = None
2023-10-02 21:52:42 +00:00
name : str = " " # vanity name for the model
version : int = 1 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding
size : str | dict = " full " # preset string or explicitly defined dimensionality
resp_levels : int = 1 # RVQ-bin levels this model targets for outputs
prom_levels : int = 8 # RVQ-bin levels this model accepts as an input prompt
2023-10-12 01:38:40 +00:00
tasks : int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
langs : int = 1 # defined languages
2024-04-16 00:54:32 +00:00
tones : int = 1 # defined tones
2023-12-21 00:45:58 +00:00
experts : int = 1
2023-10-02 21:52:42 +00:00
arch_type : str = " retnet " # or "transformer""
training : bool = True # unneeded now
interleave : bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
p_ar_level : float | str = " auto " # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
frozen_params : list [ str ] = field ( default_factory = lambda : [ ] ) # frozen parameters that are not updated when training
2024-05-11 22:14:05 +00:00
attention : str = " auto "
2024-04-29 23:24:05 +00:00
audio_embedding_sums : bool = True
2024-05-11 21:31:05 +00:00
dropout : float = 0.1 # adjustable dropout value
2024-06-01 17:34:59 +00:00
loss_factors : dict = field ( default_factory = lambda : { " text " : 0.1 , " prom " : 0.0 , " resp " : 1.0 } )
2024-06-01 03:22:28 +00:00
kv_heads : int = 0
2024-06-04 02:28:49 +00:00
experimental : bool = False
2023-08-02 21:53:35 +00:00
2024-04-16 00:54:32 +00:00
def get ( self , name = None ) :
return [ self ] if not name or self . name == name else [ ]
2024-05-19 16:23:56 +00:00
def loss_factor ( self , k ) :
return self . loss_factors [ k ] if k in self . loss_factors else 1.0
2024-04-16 00:54:32 +00:00
@property
def max_levels ( self ) :
return self . _max_levels if self . _max_levels > 0 else self . prom_levels
2024-04-09 01:14:51 +00:00
@property
# required for fp8 as the lengths needs to be divisible by 8
def input_alignment ( self ) :
2024-05-03 01:08:59 +00:00
return 8 if cfg . optimizations . fp8 else 0
2024-04-09 01:14:51 +00:00
2023-08-02 21:53:35 +00:00
@property
def full_name ( self ) :
name = [ self . name ]
2024-04-13 17:43:35 +00:00
if isinstance ( self . size , dict ) :
if hasattr ( self . size , " label " ) and self . size [ ' label ' ] :
name . append ( f " { self . size [ ' label ' ] } " )
elif isinstance ( self . size , str ) and self . size != " full " :
2023-08-02 21:53:35 +00:00
name . append ( self . size )
if self . arch_type != " transformer " :
2023-12-26 03:20:32 +00:00
if self . experts > 1 :
2023-12-21 00:45:58 +00:00
name . append ( f ' { self . experts } x ' + self . arch_type . replace ( " / " , " - " ) )
else :
name . append ( self . arch_type . replace ( " / " , " - " ) )
2023-08-02 21:53:35 +00:00
2024-05-03 01:08:59 +00:00
if cfg . optimizations . bitnet :
2024-03-01 16:32:35 +00:00
name . append ( " bitnet " )
2023-09-04 03:46:08 +00:00
if self . interleave :
name . append ( " interleaved " )
2023-12-21 00:45:58 +00:00
else :
2024-04-16 00:54:32 +00:00
name . append ( f ' { cfg . model . prom_levels } ' )
2023-09-04 03:46:08 +00:00
2023-08-02 21:53:35 +00:00
return " - " . join ( name )
@property
def tokens ( self ) :
2023-09-01 22:19:34 +00:00
if isinstance ( self . size , dict ) and hasattr ( self . size , " tokens " ) :
return self . size [ ' tokens ' ]
2023-08-02 21:53:35 +00:00
return 1024
@property
def dim ( self ) :
2023-09-01 22:19:34 +00:00
if isinstance ( self . size , dict ) and hasattr ( self . size , " dim " ) :
return self . size [ ' dim ' ]
if isinstance ( self . size , float ) :
return math . floor ( 1024 * self . size )
2023-08-02 21:53:35 +00:00
if self . size == " quarter " :
return 256
if self . size == " half " :
return 512
2023-09-02 02:33:51 +00:00
return 1024
2023-08-02 21:53:35 +00:00
@property
def heads ( self ) :
2023-09-01 22:19:34 +00:00
if isinstance ( self . size , dict ) and hasattr ( self . size , " heads " ) :
return self . size [ ' heads ' ]
if isinstance ( self . size , float ) :
return math . floor ( 16 * self . size )
2023-08-02 21:53:35 +00:00
if self . size == " quarter " :
return 4
if self . size == " half " :
return 8
2023-09-02 02:33:51 +00:00
return 16
2023-08-02 21:53:35 +00:00
@property
def layers ( self ) :
2023-09-01 22:19:34 +00:00
if isinstance ( self . size , dict ) and hasattr ( self . size , " layers " ) :
return self . size [ ' layers ' ]
2023-09-02 02:33:51 +00:00
if self . size == " double " :
return 24
2023-08-02 21:53:35 +00:00
return 12
2023-09-05 20:38:21 +00:00
@property
def activation_checkpointing ( self ) :
return cfg . trainer . activation_checkpointing
2024-06-04 02:28:49 +00:00
@property
def gradient_checkpointing ( self ) :
return cfg . trainer . gradient_checkpointing
2023-08-02 21:53:35 +00:00
@dataclass ( )
class Hyperparameters :
batch_size : int = 8
gradient_accumulation_steps : int = 32
2024-03-02 01:20:10 +00:00
gradient_clipping : int | float = 100
2023-08-02 21:53:35 +00:00
optimizer : str = " Adamw "
2024-05-10 02:00:26 +00:00
optimizer_params : dict = field ( default_factory = lambda : { } ) # to pass through deepspeed config
2023-08-02 21:53:35 +00:00
learning_rate : float = 3.25e-4
2024-05-10 02:00:26 +00:00
warmup_steps : int = 0
2023-08-02 21:53:35 +00:00
2024-05-10 01:28:20 +00:00
scheduler : str = " "
scheduler_type : str = " " # deprecated
2024-05-10 02:00:26 +00:00
scheduler_params : dict = field ( default_factory = lambda : { } ) # to pass through deepspeed config
2024-05-10 02:25:40 +00:00
autotune : bool = False
autotune_params : dict = field ( default_factory = lambda : { } ) # to pass through deepspeed config
2024-05-10 02:00:26 +00:00
torch_optimizer : bool = False
2024-05-10 01:28:20 +00:00
torch_scheduler : bool = False
2023-08-02 21:53:35 +00:00
@dataclass()
class Evaluation:
    """Evaluation section of the configuration."""

    batch_size: int = 64
    frequency: int = 250
    size: int = 64

    steps: int = 500
    ar_temperature: float = 1.0
    nar_temperature: float = 0.2

    load_disabled_engines: bool = True
2023-08-04 01:26:36 +00:00
@dataclass()
class DeepSpeed:
    """DeepSpeed section of the configuration."""

    zero_optimization_level: int = 0
    use_compression_training: bool = False
    compression_bits: int = 8
    inferencing: bool = False

    amp: bool = False

    config: dict = field(default_factory=dict)  # to pass through deepspeed config

    @cached_property
    def ds_cfg(self):
        """Assemble the DeepSpeed configuration dict (computed once, then cached).

        Values come from the module-level ``cfg`` and from this section.
        Top-level entries that end up falsy are dropped. Finally the result is
        overlaid with ``./data/ds_config.json`` if that file exists, otherwise
        with ``self.config``.

        Side effects: fills missing defaults into ``cfg.hyperparameters``'
        ``optimizer_params`` / ``scheduler_params`` / ``autotune_params`` in
        place, may clear ``self.amp`` and may clear ``cfg.trainer.amp``.
        """
        hp = cfg.hyperparameters
        weight_dtype = cfg.trainer.weight_dtype.lower()

        optimizer_params = hp.optimizer_params
        if "lr" not in optimizer_params:
            # plain float: a trailing comma here would store a 1-tuple instead
            optimizer_params["lr"] = hp.learning_rate

        scheduler_params = hp.scheduler_params
        if "warmup_num_steps" not in scheduler_params:
            scheduler_params["warmup_num_steps"] = hp.warmup_steps
        if "total_num_steps" not in scheduler_params:
            scheduler_params["total_num_steps"] = cfg.trainer.iterations

        autotune_params = hp.autotune_params
        if "enabled" not in autotune_params:
            autotune_params["enabled"] = True
        if "results_dir" not in autotune_params:
            autotune_params["results_dir"] = str(cfg.relpath / "autotune" / "results")
        if "exps_dir" not in autotune_params:
            autotune_params["exps_dir"] = str(cfg.relpath / "autotune" / "exps_")

        # DeepSpeed fp16 is incompatible with its AMP
        if weight_dtype == "float16":
            self.amp = False

        # disable local AMP
        if self.amp:
            cfg.trainer.amp = False

        def shared_quantization_parameters():
            # A fresh dict per call so the weight and activation sections
            # never share a mutable object.
            return {
                "enabled": True,
                "quantizer_kernel": True,
                "schedule_offset": 0,
                "quantize_groups": 64,
                "quantize_verbose": True,
                "quantization_type": "symmetric",
                "rounding": "nearest",
                "quantize_weight_in_forward": weight_dtype != "float16",  # MoQ (quantize in optimization step) weight quantization is only supported for FP16
                "fp16_mixed_quantize": {
                    "enabled": False,
                    "quantize_change_ratio": 1,
                },
            }

        ds_cfg = {
            "train_micro_batch_size_per_gpu": hp.batch_size,
            "gradient_accumulation_steps": hp.gradient_accumulation_steps,
            "optimizer": {
                "type": hp.optimizer,
                "params": optimizer_params,
            } if not hp.torch_optimizer else None,
            "scheduler": {
                "type": hp.scheduler,
                "params": scheduler_params,
            } if not hp.torch_scheduler else None,
            "gradient_clipping": hp.gradient_clipping,
            "fp16": {
                "enabled": weight_dtype == "float16",
                "auto_cast": True,  # ???
            },
            "bf16": {
                "enabled": weight_dtype == "bfloat16",
            },
            "amp": {
                "enabled": self.amp,
            },
            "autotuning": autotune_params if hp.autotune else None,
            "compression_training": {
                "weight_quantization": {
                    "shared_parameters": shared_quantization_parameters(),
                    "different_groups": {
                        "wq1": {
                            "params": {
                                "start_bits": self.compression_bits,
                                "target_bits": self.compression_bits,
                                "quantization_period": 0,
                            },
                            "modules": ["self_attn", "mlp"],  # for LLaMA, need to find for other arches
                        },
                    },
                },
                "activation_quantization": {
                    "shared_parameters": shared_quantization_parameters(),
                    "different_groups": {
                        "aq1": {
                            "params": {
                                "bits": self.compression_bits,
                            },
                            "modules": ["self_attn", "mlp"],  # for LLaMA, need to find for other arches
                        },
                    },
                },
            } if self.use_compression_training else None,
            "zero_optimization": {
                "stage": self.zero_optimization_level,
                "contiguous_gradients": True,
                "overlap_comm": True,
                "reduce_scatter": True,
                "reduce_bucket_size": 5e8,
                "allgather_bucket_size": 5e8,
                "sub_group_size": 5e8,
                "round_robin_gradients": True,
                "offload_optimizer": {
                    "device": "cpu",
                    "pin_memory": True,
                },
                "offload_param": {
                    "device": "cpu",
                    "pin_memory": True,
                },
                "zero_quantized_weights": self.use_compression_training,
                "zero_hpz_partition_size": world_size(),
                "zero_quantized_gradients": self.use_compression_training,
            } if self.zero_optimization_level > 0 else None,
            "comms_logger": {
                "enabled": False,
            },
        }

        # drop every top-level entry that is falsy (disabled sections, zeros)
        ds_cfg = {k: v for k, v in ds_cfg.items() if v}

        override_path = "./data/ds_config.json"
        if os.path.exists(override_path):
            # context manager so the file handle is closed deterministically
            with open(override_path, "r", encoding="utf-8") as f:
                ds_cfg.update(json.load(f))
        else:
            ds_cfg.update(self.config)

        return ds_cfg
2023-08-02 21:53:35 +00:00
2024-05-19 16:23:56 +00:00
@dataclass ( )
2023-08-02 21:53:35 +00:00
class Trainer :
iterations : int = 100_000
save_tag : str = " step "
load_tag : str | None = None
save_on_oom : bool = True
save_on_quit : bool = True
2023-08-23 21:43:03 +00:00
export_on_save : bool = False
export_on_quit : bool = False
2023-08-02 21:53:35 +00:00
save_frequency : int = 100
2023-08-17 01:12:12 +00:00
keep_last_checkpoints : int = 0
2023-08-02 21:53:35 +00:00
load_state_dict : bool = False
load_states : bool = True
strict_loading : bool = True
2023-08-20 18:42:18 +00:00
load_module_only : bool = False
2023-08-02 21:53:35 +00:00
restart_step_count : bool = False
2024-06-04 02:28:49 +00:00
activation_checkpointing : bool | None = None # deprecated
gradient_checkpointing : bool = True
2023-09-05 20:38:21 +00:00
2023-08-02 21:53:35 +00:00
aggressive_optimizations : bool = False
2023-08-04 01:36:19 +00:00
check_for_oom : bool = True
2023-08-02 21:53:35 +00:00
gc_mode : str | None = None
2023-09-02 01:58:29 +00:00
load_disabled_engines : bool = False
2023-08-02 21:53:35 +00:00
weight_dtype : str = " float16 "
2023-09-02 01:58:29 +00:00
amp : bool = False
2024-05-04 16:48:26 +00:00
ddp : bool = False
2023-08-02 21:53:35 +00:00
2023-10-21 14:55:38 +00:00
load_webui : bool = False
2024-03-01 16:32:35 +00:00
no_logger : bool = False
2023-10-21 14:55:38 +00:00
2023-09-02 01:58:29 +00:00
backend : str = " local "
2023-08-04 01:26:36 +00:00
deepspeed : DeepSpeed = field ( default_factory = lambda : DeepSpeed )
2023-08-02 21:53:35 +00:00
2023-08-05 03:22:15 +00:00
@cached_property
def dtype ( self ) :
if self . weight_dtype == " float16 " :
return torch . float16
2023-08-14 03:07:45 +00:00
if self . weight_dtype == " bfloat16 " :
2023-08-05 03:22:15 +00:00
return torch . bfloat16
2024-04-09 01:14:51 +00:00
if self . weight_dtype == " float8_e5m2 " :
return torch . float8_e5m2
if self . weight_dtype == " float8_e4m3fn " :
return torch . float8_e4m3fn
2023-08-05 03:22:15 +00:00
return torch . float32
2024-05-12 03:23:29 +00:00
@cached_property
def scale_loss ( self ) :
# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
if self . backend != " local " :
return False
return self . dtype == torch . float16
2023-08-02 23:36:26 +00:00
@dataclass()
class Inference:
    """Inference section of the configuration."""

    backend: str = "local"
    weight_dtype: str = "float32"
    amp: bool = False

    normalize: bool = False  # do NOT enable this unless you know exactly what you're doing

    audio_backend: str = ""  # encodec, vocos, dac

    # legacy / backwards compat
    use_vocos: bool = True
    use_encodec: bool = True
    use_dac: bool = True

    # options that are known not to work
    recurrent_chunk_size: int = 0
    recurrent_forward: bool = False

    @cached_property
    def dtype(self):
        """``torch`` dtype named by ``weight_dtype``; unknown names give float32."""
        if self.weight_dtype in ("float16", "bfloat16", "int8", "float8_e5m2", "float8_e4m3fn"):
            # resolved lazily so float8 types are only touched when requested
            return getattr(torch, self.weight_dtype)
        return torch.float32
2024-05-03 01:08:59 +00:00
# Optimization toggles; supersedes the deprecated `bitsandbytes` config section
2023-08-02 23:36:26 +00:00
@dataclass()
class Optimizations:
    """Optimization toggles section of the configuration."""

    injects: bool = False  # overwrites default torch classes (not recommended)
    replace: bool = False  # replaces modules in place with the optimized version (recommended)

    linear: bool = True  # inject/replace linear for BnB
    embedding: bool = True  # inject/replace embedding for BnB
    optimizers: bool = True  # inject/replace optimizers (BnB, DAdaptation)

    bitsandbytes: bool = False  # use bitsandbytes
    dadaptation: bool = True  # use dadaptation optimizer
    bitnet: bool = False  # use bitnet
    fp8: bool = False  # use fp8
2023-08-02 21:53:35 +00:00
@dataclass()
class Config(_Config):
    """Top-level configuration.

    Section fields default to the section *class* (not an instance) and, when
    loaded from YAML/CLI, hold plain mappings. ``format()`` turns all of them
    into instances of their dataclasses.
    """

    device: str = "cuda"
    mode: str = "training"  # "inferencing"
    experimental: bool = False  # So I can stop commenting out things when committing

    dataset: Dataset = field(default_factory=lambda: Dataset)
    model: Model = field(default_factory=lambda: Model)
    models: dict | list | None = None  # deprecated
    hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
    evaluation: Evaluation = field(default_factory=lambda: Evaluation)
    trainer: Trainer = field(default_factory=lambda: Trainer)
    inference: Inference = field(default_factory=lambda: Inference)
    bitsandbytes: dict | list | None = None  # deprecated
    optimizations: Optimizations = field(default_factory=lambda: Optimizations)

    tokenizer: str = "./tokenizer.json"

    sample_rate: int = 24_000
    variable_sample_rate: bool = False  # NOT recommended, as running directly 24Khz audio in the 44Khz DAC model will have detrimental quality loss

    audio_backend: str = "vocos"

    @property
    def distributed(self):
        return world_size() > 1

    @cached_property
    def get_spkr(self):
        # SECURITY NOTE(review): eval() executes arbitrary code taken from the
        # configuration; only load configuration files you trust.
        return eval(self.dataset.speaker_name_getter)

    @cached_property
    def get_spkr_group(self):
        # SECURITY NOTE(review): see get_spkr; the same eval() caveat applies.
        return eval(self.dataset.speaker_group_getter)

    @cached_property
    def diskcache(self):
        """Decorator factory: disk-backed memoize, or a no-op when caching is off."""
        if self.cfg_path is not None and self.dataset.cache:
            return diskcache.Cache(self.cache_dir).memoize
        return lambda: lambda x: x

    def load_yaml(self, config_path):
        """Replace this instance's state with the configuration in ``config_path``."""
        tmp = Config.from_yaml(config_path)
        self.__dict__.update(tmp.__dict__)

    def load_hdf5(self, write=False):
        """(Re)open the dataset HDF5 file as ``self.hdf5``.

        Uses mode ``"a"`` when ``write`` is true, otherwise ``dataset.hdf5_flag``
        (forced to ``"r"`` when running distributed). On failure the error is
        printed and ``dataset.use_hdf5`` is switched off.
        """
        if hasattr(self, 'hdf5'):
            self.hdf5.close()

        if self.distributed:
            self.dataset.hdf5_flag = "r"
        try:
            self.hdf5 = h5py.File(f'{self.cfg_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag)  # to-do, have an easy to set flag that determines if training or creating the dataset
        except Exception as e:
            print("Error while opening HDF5 file:", f'{self.cfg_path}/{self.dataset.hdf5_name}', str(e))
            self.dataset.use_hdf5 = False

    def format(self):
        """Coerce every section into its dataclass and apply legacy fix-ups."""
        # A section left untouched still holds its class; treat it as "no overrides".
        for section in (
            "dataset",
            "model",
            "hyperparameters",
            "evaluation",
            "trainer",
            "inference",
            "optimizations",
        ):
            if isinstance(getattr(self, section), type):
                setattr(self, section, dict())

        self.dataset = Dataset(**self.dataset)
        self.dataset.training = [Path(dir) for dir in self.dataset.training]
        self.dataset.validation = [Path(dir) for dir in self.dataset.validation]
        self.dataset.noise = [Path(dir) for dir in self.dataset.noise]

        # deprecated `models` takes precedence over `model` when present
        if self.models is not None:
            self.model = Model(**next(iter(self.models)))
        else:
            self.model = Model(**self.model)

        self.hyperparameters = Hyperparameters(**self.hyperparameters)

        self.evaluation = Evaluation(**self.evaluation)

        self.trainer = Trainer(**self.trainer)

        if isinstance(self.trainer.deepspeed, type):
            # Section omitted: instantiate the defaults. Leaving the bare class
            # here would make `.ds_cfg` a descriptor object instead of a dict.
            self.trainer.deepspeed = DeepSpeed()
        else:
            self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)

        self.inference = Inference(**self.inference)

        # deprecated `bitsandbytes` takes precedence over `optimizations` when present
        if self.bitsandbytes is not None:
            self.optimizations = Optimizations(**self.bitsandbytes)
        else:
            self.optimizations = Optimizations(**self.optimizations)

        # migrate the deprecated scheduler_type into scheduler
        if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
            self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
            self.hyperparameters.scheduler_type = ""

        # do not combine the two
        if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation:
            self.hyperparameters.scheduler = ""

        if self.hyperparameters.scheduler == "":
            self.hyperparameters.torch_scheduler = True

        # legacy fixed prompt duration becomes a degenerate range
        if self.dataset.prompt_duration != 0:
            self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration]

        if self.trainer.backend == "local" and self.distributed:
            self.trainer.ddp = True

        if self.inference.audio_backend != "" and self.audio_backend == "":
            self.audio_backend = self.inference.audio_backend

        # deprecated activation_checkpointing overrides gradient_checkpointing when set
        if self.trainer.activation_checkpointing is not None:
            self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing
2024-04-30 03:14:01 +00:00
# Preserves the old behavior
class NaiveTokenizer:
    """Fallback phoneme tokenizer backed by a fixed symbol map."""

    def get_vocab(self):
        """Return the built-in phoneme -> token id map.

        A lookup of the symbol map stored in the dataset HDF5 file
        (``cfg.hdf5['symmap']``) used to live here and is disabled.

        NOTE(review): the keys were restored from a whitespace-garbled listing;
        confirm them against a known-good copy before relying on rare symbols.
        '?' and 'p' both map to 7, as in the original table.
        """
        return {
            '<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7,
            'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15,
            'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23,
            't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31,
            'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39,
            'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47,
            'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55,
            'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63,
            'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71,
            'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79,
            'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87,
            'sˌ': 88, 'nˌ': 89, 'lˌ': 90,
            '\u0329': 91,  # combining vertical line below (syllabic)
            'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98,
            'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105,
            'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112,
            ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119,
            'ɟˈ': 120,
            '\u0303': 121,  # combining tilde (nasalisation)
            'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128,
            '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135,
            'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141,
            '\u0329ˈ': 142,
            '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149,
            ',ˌ': 150, '—ˌ': 151,
            '\u0329ˌ': 152,
            'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159,
            'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166,
            'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173,
            'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '”': 179, '“': 180,
            '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, ';ˌ': 184, ':ˈ': 185, '1': 186, 'rˈ': 187,
            'qˈ': 188, 'ᵻˌ': 189, 'ä': 190,
            '\u031eˌ': 191,  # combining down tack below (lowered)
            '\u031e': 192,
            'ũˌ': 193, 'ʑˌ': 194, 'ᵝ': 195, 'ɽ': 196, 'ʲˌ': 197, 'ᵝˌ': 198, 'ũ': 199,
            'ũˈ': 200, 'äˌ': 201, 'ɕ': 202, 'ɕˌ': 203, 'ɽˌ': 204, 'çˌ': 205, '…ˌ': 206,
            '\u031eˈ': 207,
            'äˈ': 208, 'ɽˈ': 209, 'ɸˌ': 210, 'ɴ': 211, 'ɸˈ': 212, 'ɕˈ': 213, 'ɸ': 214,
            'ᵝˈ': 215, 'ʲˈ': 216, 'ĩ': 217, 'çˈ': 218, 'ĩˌ': 219, 'oˌ': 220, 'eˈ': 221,
            'ʍ': 222, 'eˌ': 223, 'uˌ': 224, 'ʍˌ': 225, 'uˈ': 226, 'oˈ': 227, 'aˈ': 228,
        }

    def encode(self, s):
        """Convert a phoneme string into a list of token ids.

        Each character becomes one symbol, except that the stress and length
        marks U+02C8, U+02CC and U+02D0 are glued onto the preceding symbol.
        The result is wrapped in ``<s>`` / ``</s>``. Symbols missing from the
        vocabulary come out as ``None``.
        """
        symmap = self.get_vocab()
        phones = " ".join(list(s))

        # do merge
        for merge in ["\u02C8", "\u02CC", "\u02D0"]:
            phones = phones.replace(f' {merge}', merge)

        phones = phones.split(" ")

        # cleanup: after split(" ") a space in the input shows up as a run of
        # empty strings; keep only the first of each run so one space yields
        # one token. Index -1 wraps to the last element for i == 0.
        phones = [p for i, p in enumerate(phones) if p != "" or phones[i - 1] != ""]

        # add bos / eos
        phones = ["<s>"] + [" " if not p else p for p in phones] + ["</s>"]

        # tokenize
        return [*map(symmap.get, phones)]
2023-08-02 21:53:35 +00:00
# Module-level singleton: parsed once, at import time, from the command line.
cfg = Config.from_cli()

# OmegaConf might not coerce the dicts into the @dataclass decorated classes, so we (try to) coerce them ourselves
try:
    cfg.format()
    if cfg.dataset.use_hdf5:
        cfg.load_hdf5()
except Exception as e:
    # best effort: keep importing with HDF5 disabled rather than aborting
    cfg.dataset.use_hdf5 = False
    print("Error while parsing config YAML:", e)

try:
    from transformers import PreTrainedTokenizerFast

    # the tokenizer path is relative to the config directory, or ./data/ without one
    cfg.tokenizer = (cfg.relpath if cfg.cfg_path is not None else Path("./data/")) / cfg.tokenizer
    cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(cfg.tokenizer))
except Exception as e:
    # fall back to the built-in symbol map
    cfg.tokenizer = NaiveTokenizer()
    print("Error while parsing tokenizer:", e)

if __name__ == "__main__":
    print(cfg)