import copy import diskcache import h5py import json import os import subprocess import sys import time import torch from dataclasses import asdict, dataclass, field from functools import cached_property from pathlib import Path from omegaconf import OmegaConf from .utils.distributed import world_size @dataclass() class _Config: cfg_path: str | None = None @property def relpath(self): return Path(self.cfg_path) @property def cache_dir(self): return self.relpath / ".cache" @property def ckpt_dir(self): return self.relpath / "ckpt" @property def log_dir(self): return self.relpath / "logs" / str(self.start_time) @cached_property def start_time(self): return int(time.time()) @cached_property def git_commit(self): try: cmd = "git rev-parse HEAD" return subprocess.check_output(cmd.split()).decode("utf8").strip() except: return "" @cached_property def git_status(self): try: cmd = "git status" return subprocess.check_output(cmd.split()).decode("utf8").strip() except: return "" def dumps(self): data = {k: getattr(self, k) for k in dir(self) if not k.startswith("__")} data = {k: v for k, v in data.items() if not callable(v)} return json.dumps(data, indent=2, default=str) def dump(self, path=None): if path is None: path = self.log_dir / "cfg.json" path.parent.mkdir(parents=True, exist_ok=True) with open(path, "w") as f: f.write(self.dumps()) @staticmethod def _is_cfg_argv(s): return "=" in s and "--" not in s @classmethod def from_yaml( cls, yaml_path ): return cls.from_cli( [f'yaml="{yaml_path}"'] ) @classmethod def from_cli(cls, args=sys.argv): cli_cfg = OmegaConf.from_cli([s for s in args if cls._is_cfg_argv(s)]) # Replace argv to ensure there are no omegaconf options, for compatibility with argparse. 
sys.argv = [s for s in sys.argv if not cls._is_cfg_argv(s)] if cli_cfg.get("help"): print(f"Configurable hyperparameters with their default values:") print(json.dumps(asdict(cls()), indent=2, default=str)) exit() if "yaml" in cli_cfg: yaml_cfg = OmegaConf.load(cli_cfg.yaml) yaml_path = Path(cli_cfg.yaml).absolute() cfg_path = Path(*yaml_path.relative_to(Path.cwd()).parts[:-1]) cfg_path = cfg_path.with_suffix("") cfg_path = f'./{cfg_path}' yaml_cfg.setdefault("cfg_path", cfg_path) cli_cfg.pop("yaml") else: yaml_cfg = {} merged = OmegaConf.merge(yaml_cfg, cli_cfg) return cls(**dict(merged)) def __repr__(self): return str(self) def __str__(self): return self.dumps() @dataclass() class Dataset: training: list[Path] = field(default_factory=lambda: []) validation: list[Path] = field(default_factory=lambda: []) noise: list[Path] = field(default_factory=lambda: []) temp: list[Path] = field(default_factory=lambda: []) speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'" speaker_group_getter: str = "lambda p: f'{p.parts[-3]}'" speaker_languages: dict = field(default_factory=lambda: {}) # dict where keys are the language codes and values are the speaker groups hdf5_name: str = "data.h5" use_hdf5: bool = False use_metadata: bool = False hdf5_flag: str = "a" validate: bool = True workers: int = 8 cache: bool = True phones_range: list[int] = field(default_factory=lambda: [4, 256]) duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) min_utterances: int = 2 random_utterance: float = 1.0 max_prompts: int = 3 prompt_duration: float = 3.0 max_resps: int = 1 p_resp_append: float = 1.0 sample_type: str = "path" # path | speaker tasks_list: list[str] = field(default_factory=lambda: ["tts"]) @property def min_phones(self): return self.phones_range[0] @property def max_phones(self): return self.phones_range[1] @property def min_duration(self): return self.duration_range[0] @property def max_duration(self): return self.duration_range[1] @dataclass() 
class Model:
    """Model hyperparameters.

    `size` is either a preset string ("full", "half", "quarter", "double"), a float scale
    factor, or a mapping (plain dict or OmegaConf DictConfig) giving explicit values for
    any of `label`, `tokens`, `dim`, `heads`, `layers`.
    """

    _max_levels: int = 0
    _embeddings: str | None = None
    name: str = ""  # vanity name for the model
    version: int = 1  # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding
    size: str | dict = "full"  # preset string or explicitly defined dimensionality
    resp_levels: int = 1  # RVQ-bin levels this model targets for outputs
    prom_levels: int = 8  # RVQ-bin levels this model accepts as an input prompt
    tasks: int = 8  # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
    langs: int = 1  # defined languages
    tones: int = 1  # defined tones
    experts: int = 1
    arch_type: str = "retnet"  # or "transformer"
    training: bool = True  # unneeded now
    interleave: bool = False  # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
    p_ar_level: float | str = "auto"  # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
    frozen_params: list[str] = field(default_factory=list)  # frozen parameters that are not updated when training
    attention: str = "eager"  # or flash_attention_2

    def get(self, name=None):
        """Return `[self]` when *name* is empty or matches this model's name, else `[]`."""
        return [self] if not name or self.name == name else []

    def _size_lookup(self, key, default=None):
        """Return `self.size[key]` when `size` is a mapping containing *key*, else *default*.

        Membership is tested with `in`: `hasattr()` on a dict checks attributes, not keys,
        and a DictConfig is a Mapping but not a `dict`.
        """
        from collections.abc import Mapping

        if isinstance(self.size, Mapping) and key in self.size:
            return self.size[key]
        return default

    @property
    def max_levels(self):
        return self._max_levels if self._max_levels > 0 else self.prom_levels

    @property
    def input_alignment(self):
        # required for fp8 as the lengths needs to be divisible by 8
        return 8 if cfg.fp8.enabled else 0

    @property
    def full_name(self):
        """Dash-joined descriptive name: name, size label/preset, arch, bitnet, interleave/levels."""
        name = [self.name]

        if isinstance(self.size, str):
            if self.size != "full":
                name.append(self.size)
        else:
            label = self._size_lookup("label")
            if label:
                name.append(f"{label}")

        if self.arch_type != "transformer":
            arch = self.arch_type.replace("/", "-")
            name.append(f"{self.experts}x{arch}" if self.experts > 1 else arch)

        if cfg.bitsandbytes.bitnet:
            name.append("bitnet")

        if self.interleave:
            name.append("interleaved")
        else:
            name.append(f"{self.prom_levels}")

        return "-".join(name)

    @property
    def tokens(self):
        return self._size_lookup("tokens", 1024)

    @property
    def dim(self):
        import math

        explicit = self._size_lookup("dim")
        if explicit is not None:
            return explicit
        if isinstance(self.size, float):
            return math.floor(1024 * self.size)
        if self.size == "quarter":
            return 256
        if self.size == "half":
            return 512
        return 1024

    @property
    def heads(self):
        import math

        explicit = self._size_lookup("heads")
        if explicit is not None:
            return explicit
        if isinstance(self.size, float):
            return math.floor(16 * self.size)
        if self.size == "quarter":
            return 4
        if self.size == "half":
            return 8
        return 16

    @property
    def layers(self):
        explicit = self._size_lookup("layers")
        if explicit is not None:
            return explicit
        if self.size == "double":
            return 24
        return 12

    @property
    def activation_checkpointing(self):
        return cfg.trainer.activation_checkpointing


@dataclass()
class Hyperparameters:
    batch_size: int = 8
    gradient_accumulation_steps: int = 32
    gradient_clipping: int | float = 100

    optimizer: str = "Adamw"
    torch_optimizer: bool = False
    optimizer_params: dict = field(default_factory=dict)
    learning_rate: float = 3.25e-4

    scheduler_type: str = ""
    scheduler_params: dict = field(default_factory=dict)


@dataclass()
class Evaluation:
    batch_size: int = 64
    frequency: int = 250
    size: int = 64

    steps: int = 500
    ar_temperature: float = 1.0
    nar_temperature: float = 0.2

    load_disabled_engines: bool = True


@dataclass()
class DeepSpeed:
    zero_optimization_level: int = 0
    use_compression_training: bool = False
    compression_bits: int = 8
    inferencing: bool = False

    @cached_property
    def ds_cfg(self):
        """DeepSpeed config dict assembled from the global `cfg`.

        Sections with a falsy value (disabled sections are built as None) are dropped.
        Then, if ./data/ds_config.json exists, its top-level keys override the result.
        """
        hparams = cfg.hyperparameters
        weight_dtype = cfg.trainer.weight_dtype.lower()

        # Shallow copy into a plain dict so adding total_num_steps does not touch the config value.
        scheduler_params = {k: hparams.scheduler_params[k] for k in hparams.scheduler_params}
        if hparams.scheduler_type == "WarmupDecayLR" and "total_num_steps" not in scheduler_params:
            scheduler_params["total_num_steps"] = cfg.trainer.iterations

        def quantization_shared_parameters():
            # Fresh dict per call so the weight/activation sections do not alias each other.
            return {
                "enabled": True,
                "quantizer_kernel": True,
                "schedule_offset": 0,
                "quantize_groups": 64,
                "quantize_verbose": True,
                "quantization_type": "symmetric",
                "rounding": "nearest",
                # MoQ (quantize in optimization step) weight quantization is only supported for FP16
                "quantize_weight_in_forward": weight_dtype != "float16",
                "fp16_mixed_quantize": {
                    "enabled": False,
                    "quantize_change_ratio": 1,
                },
            }

        def quantized_modules():
            return [
                # "^.+?$"
                "blocks",  # for transformer-based models
                "retnet",  # for RetNets-based models
            ]

        ds_cfg = {
            "train_micro_batch_size_per_gpu": hparams.batch_size,
            "gradient_accumulation_steps": hparams.gradient_accumulation_steps,
            "optimizer": {
                "type": hparams.optimizer,
                "params": {
                    "lr": hparams.learning_rate,
                },
            } if not hparams.torch_optimizer else None,
            "scheduler": {
                "type": hparams.scheduler_type,
                "params": scheduler_params,
            } if hparams.scheduler_type != "" else None,
            "gradient_clipping": hparams.gradient_clipping,
            "fp16": {
                "enabled": True,
                "auto_cast": True,
            } if weight_dtype == "float16" and not cfg.trainer.amp else None,
            "bf16": {
                "enabled": weight_dtype == "bfloat16" and not cfg.trainer.amp,
            },
            "compression_training": {
                "weight_quantization": {
                    "shared_parameters": quantization_shared_parameters(),
                    "different_groups": {
                        "wq1": {
                            "params": {
                                "start_bits": self.compression_bits,
                                "target_bits": self.compression_bits,
                                "quantization_period": 0,
                            },
                            "modules": quantized_modules(),
                        },
                    },
                },
                "activation_quantization": {
                    "shared_parameters": quantization_shared_parameters(),
                    "different_groups": {
                        "aq1": {
                            "params": {
                                "bits": self.compression_bits,
                            },
                            "modules": quantized_modules(),
                        },
                    },
                },
            } if self.use_compression_training else None,
            "zero_optimization": {
                "stage": self.zero_optimization_level,
                "contiguous_gradients": True,
                "overlap_comm": True,
                "reduce_scatter": True,
                "reduce_bucket_size": 5e8,
                "allgather_bucket_size": 5e8,
                "sub_group_size": 5e8,
                "round_robin_gradients": True,
                "offload_optimizer": {
                    "device": "cpu",
                    "pin_memory": True,
                },
                "offload_param": {
                    "device": "cpu",
                    "pin_memory": True,
                },
                "zero_quantized_weights": self.use_compression_training,
                "zero_hpz_partition_size": world_size(),
                "zero_quantized_gradients": self.use_compression_training,
            } if self.zero_optimization_level > 0 else None,
            "comms_logger": {
                "enabled": False,
            },
        }

        ds_cfg = {k: v for k, v in ds_cfg.items() if v}

        override_path = "./data/ds_config.json"
        if os.path.exists(override_path):
            with open(override_path, "r", encoding="utf-8") as f:
                ds_cfg.update(json.load(f))

        return ds_cfg


def _torch_dtype(name, allowed):
    """Return the `torch` dtype named *name* if it is in *allowed*, else `torch.float32`.

    The dtype names used in the config match `torch` attribute names. The attribute is
    looked up lazily, so a torch build lacking e.g. the float8 types only fails if one is
    actually requested.
    """
    if name in allowed:
        return getattr(torch, name)
    return torch.float32


@dataclass()
class Trainer:
    iterations: int = 100_000

    save_tag: str = "step"
    load_tag: str | None = None

    save_on_oom: bool = True
    save_on_quit: bool = True

    export_on_save: bool = False
    export_on_quit: bool = False

    save_frequency: int = 100

    keep_last_checkpoints: int = 0

    load_state_dict: bool = False
    load_states: bool = True
    strict_loading: bool = True
    load_module_only: bool = False
    restart_step_count: bool = False

    activation_checkpointing: bool = True
    aggressive_optimizations: bool = False
    check_for_oom: bool = True
    gc_mode: str | None = None
    load_disabled_engines: bool = False

    weight_dtype: str = "float16"
    amp: bool = False

    load_webui: bool = False
    no_logger: bool = False

    backend: str = "local"
    # An instance, not the class: a YAML without a `deepspeed` section must still yield a usable DeepSpeed.
    deepspeed: DeepSpeed = field(default_factory=DeepSpeed)

    @cached_property
    def dtype(self):
        return _torch_dtype(self.weight_dtype, ("float16", "bfloat16", "float8_e5m2", "float8_e4m3fn"))


@dataclass()
class Inference:
    backend: str = "local"
    weight_dtype: str = "float32"
    amp: bool = False

    normalize: bool = False  # do NOT enable this unless you know exactly what you're doing
    audio_backend: str = "dac"

    # legacy / backwards compat
    use_vocos: bool = True
    use_encodec: bool = True
    use_dac: bool = True

    recurrent_chunk_size: int = 0
    recurrent_forward: bool = False

    @cached_property
    def dtype(self):
        return _torch_dtype(self.weight_dtype, ("float16", "bfloat16", "int8", "float8_e5m2", "float8_e4m3fn"))


@dataclass()
class BitsAndBytes:
    enabled: bool = False
    injects: bool = False
    replace: bool = False
    linear: bool = True
    embedding: bool = True
    bitnet: bool = False


@dataclass()
class FP8:
    enabled: bool = False
    backend: str = "te"


@dataclass()
class Config(_Config):
    """Top-level configuration; nested sections are coerced into their dataclasses by `format()`."""

    device: str = "cuda"
    mode: str = "training"  # "inferencing"
    experimental: bool = False  # So I can stop commenting out things when committing

    # Defaults are instances (default_factory=Cls), not the classes themselves.
    dataset: Dataset = field(default_factory=Dataset)
    model: Model = field(default_factory=Model)
    models: dict | list | None = None  # deprecated
    hyperparameters: Hyperparameters = field(default_factory=Hyperparameters)
    evaluation: Evaluation = field(default_factory=Evaluation)
    trainer: Trainer = field(default_factory=Trainer)
    inference: Inference = field(default_factory=Inference)
    bitsandbytes: BitsAndBytes = field(default_factory=BitsAndBytes)
    fp8: FP8 = field(default_factory=FP8)

    @property
    def sample_rate(self):
        return 24_000

    @property
    def distributed(self):
        return world_size() > 1

    @cached_property
    def get_spkr(self):
        # NOTE(security): eval() of a config string -- only load configs you trust.
        return eval(self.dataset.speaker_name_getter)

    @cached_property
    def get_spkr_group(self):
        # NOTE(security): eval() of a config string -- only load configs you trust.
        return eval(self.dataset.speaker_group_getter)

    @cached_property
    def diskcache(self):
        """Memoizing decorator factory: diskcache-backed when a cache dir is configured, identity otherwise."""
        if self.cfg_path is not None and self.dataset.cache:
            return diskcache.Cache(self.cache_dir).memoize
        return lambda: lambda x: x

    def load_yaml(self, config_path):
        """Replace this config's attributes with those loaded from *config_path*."""
        tmp = Config.from_yaml(config_path)
        self.__dict__.update(tmp.__dict__)

    def load_hdf5(self, write=False):
        """(Re)open `<cfg_path>/<hdf5_name>` as `self.hdf5`; on failure, print and disable HDF5 use.

        The file is opened with mode "a" when *write* is set, otherwise with
        `dataset.hdf5_flag`, which is forced to "r" in distributed runs.
        """
        if hasattr(self, "hdf5"):
            self.hdf5.close()

        if self.distributed:
            self.dataset.hdf5_flag = "r"

        path = f"{self.cfg_path}/{self.dataset.hdf5_name}"
        try:
            # to-do, have an easy to set flag that determines if training or creating the dataset
            self.hdf5 = h5py.File(path, "a" if write else self.dataset.hdf5_flag)
        except Exception as e:
            print("Error while opening HDF5 file:", path, str(e))
            self.dataset.use_hdf5 = False

    @staticmethod
    def _coerce(value, kind):
        """Return *value* as an instance of dataclass *kind*.

        Existing instances pass through unchanged, so calling `format()` twice is safe.
        Mappings (dict / DictConfig) are expanded as keyword arguments. The bare class,
        as produced by older class-valued defaults, is instantiated with its defaults.
        """
        from collections.abc import Mapping

        if isinstance(value, kind):
            return value
        if isinstance(value, Mapping):
            return kind(**value)
        if value is kind:
            return kind()
        raise TypeError(f"expected a mapping or {kind.__name__} instance, got {type(value).__name__}")

    def format(self):
        """Coerce the nested config sections into their dataclasses and normalize dataset paths."""
        self.dataset = self._coerce(self.dataset, Dataset)
        self.dataset.training = [Path(dir) for dir in self.dataset.training]
        self.dataset.validation = [Path(dir) for dir in self.dataset.validation]
        self.dataset.noise = [Path(dir) for dir in self.dataset.noise]

        if self.models is not None:
            # Deprecated multi-model layout: only the first entry is used.
            self.model = Model(**next(iter(self.models)))
        else:
            self.model = self._coerce(self.model, Model)

        self.hyperparameters = self._coerce(self.hyperparameters, Hyperparameters)
        self.evaluation = self._coerce(self.evaluation, Evaluation)

        self.trainer = self._coerce(self.trainer, Trainer)
        self.trainer.deepspeed = self._coerce(self.trainer.deepspeed, DeepSpeed)

        self.inference = self._coerce(self.inference, Inference)
        self.bitsandbytes = self._coerce(self.bitsandbytes, BitsAndBytes)
        self.fp8 = self._coerce(self.fp8, FP8)


cfg = Config.from_cli()

# OmegaConf might not coerce the dicts into the @dataclass decorated classes, so we (try to) coerce them ourselves
try:
    cfg.format()  # cached_property stopped working...
    if cfg.dataset.use_hdf5:
        cfg.load_hdf5()
except Exception as e:
    # Best-effort: report the problem with context, but keep the module importable.
    print(f"Error while formatting config: {e}", file=sys.stderr)

if __name__ == "__main__":
    print(cfg)