import copy
import json
import math
import os
import subprocess
import sys
import time
from collections.abc import Mapping
from dataclasses import asdict, dataclass, field
from functools import cached_property
from pathlib import Path

import diskcache
import h5py
import torch
from omegaconf import OmegaConf

from .utils.distributed import world_size


@dataclass ( )
class _Config :
cfg_path : str | None = None
@property
def relpath ( self ) :
return Path ( self . cfg_path )
2023-08-27 00:53:23 +00:00
@property
def cache_dir ( self ) :
return self . relpath / " .cache "
2023-08-02 21:53:35 +00:00
@property
def ckpt_dir ( self ) :
return self . relpath / " ckpt "
@property
def log_dir ( self ) :
return self . relpath / " logs " / str ( self . start_time )
@cached_property
def start_time ( self ) :
return int ( time . time ( ) )
@cached_property
def git_commit ( self ) :
try :
cmd = " git rev-parse HEAD "
return subprocess . check_output ( cmd . split ( ) ) . decode ( " utf8 " ) . strip ( )
except :
return " "
@cached_property
def git_status ( self ) :
try :
cmd = " git status "
return subprocess . check_output ( cmd . split ( ) ) . decode ( " utf8 " ) . strip ( )
except :
return " "
def dumps ( self ) :
data = { k : getattr ( self , k ) for k in dir ( self ) if not k . startswith ( " __ " ) }
data = { k : v for k , v in data . items ( ) if not callable ( v ) }
return json . dumps ( data , indent = 2 , default = str )
def dump ( self , path = None ) :
if path is None :
path = self . log_dir / " cfg.json "
path . parent . mkdir ( parents = True , exist_ok = True )
with open ( path , " w " ) as f :
f . write ( self . dumps ( ) )
@staticmethod
def _is_cfg_argv ( s ) :
return " = " in s and " -- " not in s
@classmethod
def from_yaml ( cls , yaml_path ) :
return cls . from_cli ( [ f ' yaml= " { yaml_path } " ' ] )
@classmethod
def from_cli ( cls , args = sys . argv ) :
cli_cfg = OmegaConf . from_cli ( [ s for s in args if cls . _is_cfg_argv ( s ) ] )
# Replace argv to ensure there are no omegaconf options, for compatibility with argparse.
sys . argv = [ s for s in sys . argv if not cls . _is_cfg_argv ( s ) ]
if cli_cfg . get ( " help " ) :
print ( f " Configurable hyperparameters with their default values: " )
print ( json . dumps ( asdict ( cls ( ) ) , indent = 2 , default = str ) )
exit ( )
if " yaml " in cli_cfg :
yaml_cfg = OmegaConf . load ( cli_cfg . yaml )
yaml_path = Path ( cli_cfg . yaml ) . absolute ( )
cfg_path = Path ( * yaml_path . relative_to ( Path . cwd ( ) ) . parts [ : - 1 ] )
cfg_path = cfg_path . with_suffix ( " " )
cfg_path = f ' ./ { cfg_path } '
yaml_cfg . setdefault ( " cfg_path " , cfg_path )
cli_cfg . pop ( " yaml " )
else :
yaml_cfg = { }
merged = OmegaConf . merge ( yaml_cfg , cli_cfg )
return cls ( * * dict ( merged ) )
def __repr__ ( self ) :
return str ( self )
def __str__ ( self ) :
return self . dumps ( )
@dataclass()
class Dataset:
    """Dataset section of the configuration.

    ``phones_range`` and ``duration_range`` are ``[min, max]`` pairs; the
    ``min_*`` / ``max_*`` properties expose their first and second element.
    """

    # Path lists; each instance gets its own empty list.
    training: list[Path] = field(default_factory=list)
    validation: list[Path] = field(default_factory=list)
    noise: list[Path] = field(default_factory=list)
    temp: list[Path] = field(default_factory=list)

    # Source text of a lambda mapping a path to a speaker name.
    speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"

    hdf5_name: str = "data.h5"
    use_hdf5: bool = False
    use_metadata: bool = False
    hdf5_flag: str = "a"
    validate: bool = True

    workers: int = 8
    cache: bool = True

    phones_range: list[int] = field(default_factory=lambda: [4, 256])
    duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0])

    random_utterance: float = 1.0
    max_prompts: int = 3
    prompt_duration: float = 3.0

    sample_type: str = "path"  # path | speaker

    tasks_list: list[str] = field(default_factory=lambda: ["tts"])

    @property
    def min_phones(self):
        """First element of ``phones_range``."""
        lower = self.phones_range[0]
        return lower

    @property
    def max_phones(self):
        """Second element of ``phones_range``."""
        upper = self.phones_range[1]
        return upper

    @property
    def min_duration(self):
        """First element of ``duration_range``."""
        lower = self.duration_range[0]
        return lower

    @property
    def max_duration(self):
        """Second element of ``duration_range``."""
        upper = self.duration_range[1]
        return upper
@dataclass ( )
class Model :
name : str = " "
2023-09-01 22:19:34 +00:00
size : str | float | dict = " full "
2023-08-02 21:53:35 +00:00
resp_levels : int = 1
2023-08-19 01:58:07 +00:00
prom_levels : int = 8
2023-08-27 03:00:43 +00:00
tasks : int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
2023-08-02 21:53:35 +00:00
arch_type : str = " transformer "
2023-08-27 17:26:12 +00:00
training : bool = True
2023-09-04 03:46:08 +00:00
interleave : bool = False
2023-09-07 23:19:51 +00:00
frozen_params : list [ str ] = field ( default_factory = lambda : [ ] )
2023-08-02 21:53:35 +00:00
@property
def full_name ( self ) :
name = [ self . name ]
2023-09-07 14:14:03 +00:00
if self . size != " full " and isinstance ( self . size , str ) :
2023-08-02 21:53:35 +00:00
name . append ( self . size )
if self . arch_type != " transformer " :
name . append ( self . arch_type . replace ( " / " , " - " ) )
2023-09-04 03:46:08 +00:00
if self . interleave :
name . append ( " interleaved " )
2023-08-19 01:58:07 +00:00
name . append ( f ' { cfg . models . prom_levels } ' )
2023-08-02 21:53:35 +00:00
return " - " . join ( name )
@property
def tokens ( self ) :
2023-09-01 22:19:34 +00:00
if isinstance ( self . size , dict ) and hasattr ( self . size , " tokens " ) :
return self . size [ ' tokens ' ]
2023-08-02 21:53:35 +00:00
return 1024
@property
def dim ( self ) :
2023-09-01 22:19:34 +00:00
if isinstance ( self . size , dict ) and hasattr ( self . size , " dim " ) :
return self . size [ ' dim ' ]
if isinstance ( self . size , float ) :
return math . floor ( 1024 * self . size )
2023-08-02 21:53:35 +00:00
if self . size == " quarter " :
return 256
if self . size == " half " :
return 512
2023-09-02 02:33:51 +00:00
return 1024
2023-08-02 21:53:35 +00:00
@property
def heads ( self ) :
2023-09-01 22:19:34 +00:00
if isinstance ( self . size , dict ) and hasattr ( self . size , " heads " ) :
return self . size [ ' heads ' ]
if isinstance ( self . size , float ) :
return math . floor ( 16 * self . size )
2023-08-02 21:53:35 +00:00
if self . size == " quarter " :
return 4
if self . size == " half " :
return 8
2023-09-02 02:33:51 +00:00
return 16
2023-08-02 21:53:35 +00:00
@property
def layers ( self ) :
2023-09-01 22:19:34 +00:00
if isinstance ( self . size , dict ) and hasattr ( self . size , " layers " ) :
return self . size [ ' layers ' ]
2023-09-02 02:33:51 +00:00
if self . size == " double " :
return 24
2023-08-02 21:53:35 +00:00
return 12
2023-09-05 20:38:21 +00:00
@property
def activation_checkpointing ( self ) :
return cfg . trainer . activation_checkpointing
2023-08-02 21:53:35 +00:00
@dataclass()
class Models:
    """Collection of ``Model`` definitions plus aggregate limits over them.

    Entries of ``_models`` may be ``Model`` instances (the defaults) or
    mappings of ``Model`` keyword arguments (as produced when the section is
    loaded from a config file); every accessor normalises them to ``Model``.
    """

    _max_levels: int = 0
    _prom_levels: int = 1

    _models: list[Model] = field(default_factory=lambda: [
        Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True, interleave=False),
        Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True, interleave=False),
    ])

    @staticmethod
    def _as_model(entry):
        """Return ``entry`` unchanged if it is a ``Model``, else build one from the mapping."""
        if isinstance(entry, Model):
            return entry
        return Model(**entry)

    def get(self, name=None):
        """Return every model (no ``name``) or the first model called ``name``.

        Raises:
            ValueError: if ``name`` is given and no model carries that name.
        """
        models = [self._as_model(entry) for entry in self._models]
        if not name:
            return models
        for model in models:
            if model.name == name:
                return model
        raise ValueError(f"no model named {name!r} is configured")

    @property
    def ar(self):
        """The model named ``"ar"``."""
        return self.get("ar")

    @property
    def ar_nar(self):
        """The model named ``"ar+nar"``."""
        return self.get("ar+nar")

    @property
    def nar(self):
        """The model named ``"nar"``."""
        return self.get("nar")

    @property
    def prom_levels(self):
        """Largest ``prom_levels`` across all models, never below ``_prom_levels``."""
        return max([self._prom_levels, *(model.prom_levels for model in self.get())])

    @property
    def tasks(self):
        """Largest ``tasks`` across all models, never below 1."""
        return max([1, *(model.tasks for model in self.get())])

    @property
    def max_levels(self):
        """``_max_levels`` when positive, otherwise ``prom_levels``."""
        return self._max_levels if self._max_levels > 0 else self.prom_levels
@dataclass()
class Hyperparameters:
    """Optimisation settings: batching, optimiser and learning-rate scheduler."""

    batch_size: int = 8
    gradient_accumulation_steps: int = 32
    gradient_clipping: int = 100

    optimizer: str = "Adamw"
    torch_optimizer: bool = False
    # Fresh dict per instance.
    optimizer_params: dict = field(default_factory=dict)
    learning_rate: float = 3.25e-4

    # An empty scheduler_type means "no scheduler".
    scheduler_type: str = ""
    scheduler_params: dict = field(default_factory=dict)
@dataclass()
class Evaluation:
    """Settings for the evaluation pass."""

    batch_size: int = 64
    frequency: int = 250
    size: int = 64

    steps: int = 500
    # Separate temperatures for the "ar" and "nar" models.
    ar_temperature: float = 1.0
    nar_temperature: float = 0.2

    load_disabled_engines: bool = True
@dataclass()
class DeepSpeed:
    """DeepSpeed-specific options and the engine config dict built from them."""

    zero_optimization_level: int = 0
    use_compression_training: bool = False
    compression_bits: int = 8

    @cached_property
    def ds_cfg(self):
        """Build the DeepSpeed config dict from this object and the module-level ``cfg``.

        Top-level entries whose value is falsy (``None``, ``0``, empty) are
        removed. If ``./data/ds_config.json`` exists, its top-level keys are
        merged over the result. The dict is computed once per instance.
        """
        hp = cfg.hyperparameters

        # Copy so that adding total_num_steps never mutates the user's config.
        scheduler_params = {k: hp.scheduler_params[k] for k in hp.scheduler_params}
        if hp.scheduler_type == "WarmupDecayLR" and "total_num_steps" not in scheduler_params:
            scheduler_params["total_num_steps"] = cfg.trainer.iterations

        weight_dtype = cfg.trainer.weight_dtype.lower()
        # The fp16/bf16 engine modes are only requested when AMP is off.
        without_amp = not cfg.trainer.amp

        ds_cfg = {
            "train_micro_batch_size_per_gpu": hp.batch_size,
            "gradient_accumulation_steps": hp.gradient_accumulation_steps,
            "optimizer": {
                "type": hp.optimizer,
                "params": {
                    "lr": hp.learning_rate,
                },
            } if not hp.torch_optimizer else None,
            "scheduler": {
                "type": hp.scheduler_type,
                "params": scheduler_params,
            } if hp.scheduler_type != "" else None,
            "gradient_clipping": hp.gradient_clipping,
            "fp16": {
                "enabled": True,
                "auto_cast": True,
            } if weight_dtype == "float16" and without_amp else None,
            "bf16": {
                "enabled": weight_dtype == "bfloat16" and without_amp,
            },
            "compression_training": {
                "weight_quantization": {
                    "shared_parameters": {
                        "enabled": True,
                        "quantizer_kernel": True,
                        "schedule_offset": 0,
                        "quantize_groups": 64,
                        "quantize_verbose": True,
                        "quantization_type": "symmetric",
                        "rounding": "nearest",
                        "quantize_weight_in_forward": True,
                        "fp16_mixed_quantize": {
                            "enabled": False,
                            "quantize_change_ratio": 1,
                        },
                    },
                    "different_groups": {
                        "wq1": {
                            "params": {
                                "start_bits": self.compression_bits,
                                "target_bits": self.compression_bits,
                                "quantization_period": 0,
                            },
                            "modules": [
                                "blocks",  # for transformer-based models
                                "retnet",  # for RetNets-based models
                            ],
                        },
                    },
                },
            } if self.use_compression_training else None,
            "zero_optimization": {
                "stage": self.zero_optimization_level,
                "contiguous_gradients": True,
                "overlap_comm": True,
                "reduce_scatter": True,
                "reduce_bucket_size": 5e8,
                "allgather_bucket_size": 5e8,
                "sub_group_size": 5e8,
                "round_robin_gradients": True,
                "offload_optimizer": {
                    "device": "cpu",
                    "pin_memory": True,
                },
                "offload_param": {
                    "device": "cpu",
                    "pin_memory": True,
                },
                "zero_quantized_weights": self.use_compression_training,
                "zero_hpz_partition_size": world_size(),
                "zero_quantized_gradients": self.use_compression_training,
            } if self.zero_optimization_level > 0 else None,
            "comms_logger": {
                "enabled": False,
            },
        }

        # Drop disabled (falsy) top-level sections.
        ds_cfg = {key: value for key, value in ds_cfg.items() if value}

        override_path = "./data/ds_config.json"
        if os.path.exists(override_path):
            # Context manager so the handle is closed even if parsing fails.
            with open(override_path, "r", encoding="utf-8") as f:
                ds_cfg.update(json.load(f))

        return ds_cfg
@dataclass()
class Trainer:
    """Training-loop settings: checkpointing, loading, precision and backend."""

    iterations: int = 100_000

    save_tag: str = "step"
    load_tag: str | None = None

    save_on_oom: bool = True
    save_on_quit: bool = True

    export_on_save: bool = False
    export_on_quit: bool = False

    save_frequency: int = 100

    keep_last_checkpoints: int = 0

    load_state_dict: bool = False
    load_states: bool = True
    strict_loading: bool = True
    load_module_only: bool = False
    restart_step_count: bool = False

    activation_checkpointing: bool = True

    aggressive_optimizations: bool = False
    check_for_oom: bool = True
    gc_mode: str | None = None
    load_disabled_engines: bool = False

    weight_dtype: str = "float16"
    amp: bool = False

    backend: str = "local"
    # The factory must produce an instance: with the class itself as the
    # default, attribute access would return descriptors, not values.
    deepspeed: DeepSpeed = field(default_factory=DeepSpeed)

    @cached_property
    def dtype(self):
        """``torch`` dtype named by ``weight_dtype``; ``torch.float32`` for unknown names.

        Matching is case-insensitive, in line with how the DeepSpeed config
        builder lowercases the same field. Computed once per instance.
        """
        name = str(self.weight_dtype).lower()
        if name == "float16":
            return torch.float16
        if name == "bfloat16":
            return torch.bfloat16
        return torch.float32
@dataclass()
class Inference:
    """Inference-time settings."""

    weight_dtype: str = "float32"
    amp: bool = False

    normalize: bool = False  # do NOT enable this unless you know exactly what you're doing
    use_vocos: bool = True

    recurrent_chunk_size: int = 0
    recurrent_forward: bool = False

    @cached_property
    def dtype(self):
        """``torch`` dtype named by ``weight_dtype``; ``torch.float32`` for unknown names.

        Matching is case-insensitive so that e.g. ``"BFloat16"`` does not
        silently fall back to float32. Computed once per instance.
        """
        by_name = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }
        return by_name.get(str(self.weight_dtype).lower(), torch.float32)
@dataclass()
class BitsAndBytes:
    """Feature switches of the ``bitsandbytes`` config section."""

    # Master switch and injection switch, both off by default.
    enabled: bool = False
    injects: bool = False

    # Per-layer-type switches, both on by default.
    linear: bool = True
    embedding: bool = True
@dataclass()
class Config(_Config):
    """Top-level configuration: one typed section per concern.

    After construction from YAML/CLI the sections may still be raw mappings;
    ``format()`` converts them into their dataclasses.
    """

    device: str = "cuda"
    mode: str = "training"  # "inferencing"

    # Factories must return instances; a class as the default value would
    # expose descriptors and miss every default_factory field.
    dataset: Dataset = field(default_factory=Dataset)
    models: Models = field(default_factory=Models)
    hyperparameters: Hyperparameters = field(default_factory=Hyperparameters)
    evaluation: Evaluation = field(default_factory=Evaluation)
    trainer: Trainer = field(default_factory=Trainer)
    inference: Inference = field(default_factory=Inference)
    bitsandbytes: BitsAndBytes = field(default_factory=BitsAndBytes)

    @property
    def sample_rate(self):
        """Fixed sample rate: 24000."""
        return 24_000

    @property
    def distributed(self):
        """True when ``world_size()`` reports more than one process."""
        return world_size() > 1

    @cached_property
    def get_spkr(self):
        """Callable obtained by evaluating ``dataset.speaker_name_getter``.

        SECURITY: this runs ``eval`` on a string taken from the config, so a
        config file can execute arbitrary code. Only load trusted configs.
        """
        return eval(self.dataset.speaker_name_getter)

    @cached_property
    def diskcache(self):
        """Memoising decorator factory, or a no-op factory when caching is off.

        Caching is used only when ``cfg_path`` is set and ``dataset.cache``
        is true; otherwise calling the result returns an identity decorator.
        """
        if self.cfg_path is not None and self.dataset.cache:
            return diskcache.Cache(self.cache_dir).memoize
        return lambda: lambda x: x

    def load_yaml(self, config_path):
        """Replace this object's state with the config loaded from ``config_path``.

        The loaded sections are not passed through ``format()`` here.
        """
        tmp = Config.from_yaml(config_path)
        self.__dict__.update(tmp.__dict__)

    def load_hdf5(self, write=False):
        """(Re)open ``<cfg_path>/<dataset.hdf5_name>`` and store it as ``self.hdf5``.

        The file is opened with mode ``"a"`` when ``write`` is true, otherwise
        with ``dataset.hdf5_flag`` (forced to ``"r"`` when running
        distributed). On failure the error is printed and
        ``dataset.use_hdf5`` is switched off.
        """
        if hasattr(self, "hdf5"):
            self.hdf5.close()

        if self.distributed:
            self.dataset.hdf5_flag = "r"

        hdf5_path = f"{self.cfg_path}/{self.dataset.hdf5_name}"
        try:
            self.hdf5 = h5py.File(hdf5_path, "a" if write else self.dataset.hdf5_flag)  # to-do, have an easy to set flag that determines if training or creating the dataset
        except Exception as e:
            print("Error while opening HDF5 file:", hdf5_path, str(e))
            self.dataset.use_hdf5 = False

    @staticmethod
    def _coerce(section_cls, value):
        """Return ``value`` as an instance of ``section_cls``.

        Accepts an existing instance (returned as is), the class itself or
        ``None`` (replaced by a default instance), or a mapping of keyword
        arguments.
        """
        if isinstance(value, section_cls):
            return value
        if value is None or value is section_cls:
            return section_cls()
        return section_cls(**value)

    def format(self):
        """Convert every section into its dataclass and path lists into ``Path`` objects.

        Safe to call repeatedly: sections that are already instances are kept.
        """
        self.dataset = self._coerce(Dataset, self.dataset)
        self.models = self._coerce(Models, self.models)
        self.hyperparameters = self._coerce(Hyperparameters, self.hyperparameters)
        self.evaluation = self._coerce(Evaluation, self.evaluation)
        self.trainer = self._coerce(Trainer, self.trainer)
        self.inference = self._coerce(Inference, self.inference)
        self.bitsandbytes = self._coerce(BitsAndBytes, self.bitsandbytes)

        self.trainer.deepspeed = self._coerce(DeepSpeed, self.trainer.deepspeed)

        self.dataset.training = [Path(entry) for entry in self.dataset.training]
        self.dataset.validation = [Path(entry) for entry in self.dataset.validation]
        self.dataset.noise = [Path(entry) for entry in self.dataset.noise]

# Module-level singleton, built from sys.argv at import time.
cfg = Config.from_cli()

# OmegaConf might not coerce the dicts into the @dataclass decorated classes, so we (try to) coerce them ourselves
try:
    cfg.format()

    # cached_property stopped working...
    if cfg.dataset.use_hdf5:
        cfg.load_hdf5()
except Exception as e:
    # Stay importable (best-effort), but never hide the failure: a silently
    # half-formatted config is far harder to debug than a warning.
    print(f"Warning: could not finalise configuration: {type(e).__name__}: {e}", file=sys.stderr)

if __name__ == "__main__":
    print(cfg)