vall-e/vall_e/config.py
2023-09-08 15:36:26 -05:00

564 lines
14 KiB
Python
Executable File

import copy
import json
import math
import os
import subprocess
import sys
import time
from collections.abc import Mapping
from dataclasses import asdict, dataclass
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path

import diskcache
import h5py
import torch
from omegaconf import OmegaConf

from .utils.distributed import world_size
@dataclass()
class _Config:
cfg_path: str | None = None
@property
def relpath(self):
return Path(self.cfg_path)
@property
def cache_dir(self):
return self.relpath / ".cache"
@property
def ckpt_dir(self):
return self.relpath / "ckpt"
@property
def log_dir(self):
return self.relpath / "logs" / str(self.start_time)
@cached_property
def start_time(self):
return int(time.time())
@cached_property
def git_commit(self):
try:
cmd = "git rev-parse HEAD"
return subprocess.check_output(cmd.split()).decode("utf8").strip()
except:
return ""
@cached_property
def git_status(self):
try:
cmd = "git status"
return subprocess.check_output(cmd.split()).decode("utf8").strip()
except:
return ""
def dumps(self):
data = {k: getattr(self, k) for k in dir(self) if not k.startswith("__")}
data = {k: v for k, v in data.items() if not callable(v)}
return json.dumps(data, indent=2, default=str)
def dump(self, path=None):
if path is None:
path = self.log_dir / "cfg.json"
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w") as f:
f.write(self.dumps())
@staticmethod
def _is_cfg_argv(s):
return "=" in s and "--" not in s
@classmethod
def from_yaml( cls, yaml_path ):
return cls.from_cli( [f'yaml="{yaml_path}"'] )
@classmethod
def from_cli(cls, args=sys.argv):
cli_cfg = OmegaConf.from_cli([s for s in args if cls._is_cfg_argv(s)])
# Replace argv to ensure there are no omegaconf options, for compatibility with argparse.
sys.argv = [s for s in sys.argv if not cls._is_cfg_argv(s)]
if cli_cfg.get("help"):
print(f"Configurable hyperparameters with their default values:")
print(json.dumps(asdict(cls()), indent=2, default=str))
exit()
if "yaml" in cli_cfg:
yaml_cfg = OmegaConf.load(cli_cfg.yaml)
yaml_path = Path(cli_cfg.yaml).absolute()
cfg_path = Path(*yaml_path.relative_to(Path.cwd()).parts[:-1])
cfg_path = cfg_path.with_suffix("")
cfg_path = f'./{cfg_path}'
yaml_cfg.setdefault("cfg_path", cfg_path)
cli_cfg.pop("yaml")
else:
yaml_cfg = {}
merged = OmegaConf.merge(yaml_cfg, cli_cfg)
return cls(**dict(merged))
def __repr__(self):
return str(self)
def __str__(self):
return self.dumps()
@dataclass
class Dataset:
    """Dataset settings: source directories, storage options and sampling limits."""

    # Source directories, one list per role.
    training: list[Path] = field(default_factory=list)
    validation: list[Path] = field(default_factory=list)
    noise: list[Path] = field(default_factory=list)
    temp: list[Path] = field(default_factory=list)

    # Source text of a callable; Config.get_spkr evaluates it with eval().
    speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"

    # HDF5 storage: file name (relative to cfg_path), on/off switch, open mode.
    hdf5_name: str = "data.h5"
    use_hdf5: bool = False
    use_metadata: bool = False
    hdf5_flag: str = "a"
    validate: bool = True

    workers: int = 8
    cache: bool = True

    # [min, max] pairs, exposed through the min_*/max_* properties below.
    phones_range: list[int] = field(default_factory=lambda: [4, 256])
    duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0])

    random_utterance: float = 1.0
    max_prompts: int = 3
    prompt_duration: float = 3.0

    sample_type: str = "path" # path | speaker

    tasks_list: list[str] = field(default_factory=lambda: ["tts"])

    @property
    def min_phones(self):
        """First entry of ``phones_range``."""
        return self.phones_range[0]

    @property
    def max_phones(self):
        """Second entry of ``phones_range``."""
        return self.phones_range[1]

    @property
    def min_duration(self):
        """First entry of ``duration_range``."""
        return self.duration_range[0]

    @property
    def max_duration(self):
        """Second entry of ``duration_range``."""
        return self.duration_range[1]
@dataclass()
class Model:
name: str = ""
size: str | float | dict = "full"
resp_levels: int = 1
prom_levels: int = 8
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
arch_type: str = "transformer"
training: bool = True
interleave: bool = False
frozen_params: list[str] = field(default_factory=lambda: [])
@property
def full_name(self):
name = [ self.name ]
if self.size != "full" and isinstance(self.size, str):
name.append(self.size)
if self.arch_type != "transformer":
name.append(self.arch_type.replace("/", "-"))
if self.interleave:
name.append("interleaved")
name.append(f'{cfg.models.prom_levels}')
return "-".join(name)
@property
def tokens(self):
if isinstance(self.size, dict) and hasattr(self.size, "tokens"):
return self.size['tokens']
return 1024
@property
def dim(self):
if isinstance(self.size, dict) and hasattr(self.size, "dim"):
return self.size['dim']
if isinstance(self.size, float):
return math.floor(1024 * self.size)
if self.size == "quarter":
return 256
if self.size == "half":
return 512
return 1024
@property
def heads(self):
if isinstance(self.size, dict) and hasattr(self.size, "heads"):
return self.size['heads']
if isinstance(self.size, float):
return math.floor(16 * self.size)
if self.size == "quarter":
return 4
if self.size == "half":
return 8
return 16
@property
def layers(self):
if isinstance(self.size, dict) and hasattr(self.size, "layers"):
return self.size['layers']
if self.size == "double":
return 24
return 12
@property
def activation_checkpointing(self):
return cfg.trainer.activation_checkpointing
@dataclass()
class Models:
    """Collection of model definitions plus values aggregated across them."""

    _max_levels: int = 0
    _prom_levels: int = 1

    # Entries are Model instances (the defaults below) or mappings of Model
    # fields (presumably what a loaded configuration provides - confirm).
    _models: list[Model] = field(default_factory=lambda: [
        Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True, interleave=False),
        Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True, interleave=False),
    ])

    @staticmethod
    def _as_model(entry):
        """Return ``entry`` itself when it is a Model, else build one from the mapping."""
        return entry if isinstance(entry, Model) else Model(**entry)

    def get(self, name=None):
        """Return every model as a ``Model`` list, or the single model called ``name``.

        Entries that are already ``Model`` instances are returned as-is;
        unpacking them with ``**`` raised ``TypeError`` for the default models.

        Raises:
            ValueError: if ``name`` is given and no model carries that name.
        """
        models = [ self._as_model(entry) for entry in self._models ]
        if not name:
            return models

        for model in models:
            if model.name == name:
                return model

        raise ValueError(f"no model named {name!r} is configured")

    @property
    def ar(self):
        """The model named ``"ar"``."""
        return self.get("ar")

    @property
    def ar_nar(self):
        """The model named ``"ar+nar"``."""
        return self.get("ar+nar")

    @property
    def nar(self):
        """The model named ``"nar"``."""
        return self.get("nar")

    @property
    def prom_levels(self):
        """Largest ``prom_levels`` over all models, never below ``_prom_levels``."""
        return max([ self._prom_levels, *(model.prom_levels for model in self._models) ])

    @property
    def tasks(self):
        """Largest ``tasks`` over all models, never below 1."""
        return max([ 1, *(model.tasks for model in self._models) ])

    @property
    def max_levels(self):
        """``_max_levels`` when positive, otherwise :attr:`prom_levels`."""
        return self._max_levels if self._max_levels > 0 else self.prom_levels
@dataclass
class Hyperparameters:
    """Optimisation hyperparameters; DeepSpeed.ds_cfg in this file reads all of them."""

    batch_size: int = 8
    gradient_accumulation_steps: int = 32
    gradient_clipping: int = 100

    # Optimizer type name; left out of the DeepSpeed config when torch_optimizer is set.
    optimizer: str = "Adamw"
    torch_optimizer: bool = False
    optimizer_params: dict = field(default_factory=dict)
    learning_rate: float = 3.25e-4

    # An empty scheduler_type means no scheduler section is emitted.
    scheduler_type: str = ""
    scheduler_params: dict = field(default_factory=dict)
@dataclass
class Evaluation:
    """Evaluation settings (plain data holder; nothing in this file reads them)."""

    batch_size: int = 64
    frequency: int = 250
    size: int = 64

    steps: int = 500
    ar_temperature: float = 1.0
    nar_temperature: float = 0.2

    load_disabled_engines: bool = True
@dataclass()
class DeepSpeed:
    """DeepSpeed options and the configuration dictionary built from them."""

    zero_optimization_level: int = 0
    use_compression_training: bool = False
    compression_bits: int = 8

    @cached_property
    def ds_cfg(self):
        """Build the DeepSpeed configuration dictionary (computed once, then cached).

        Values come from this object and from the module-level ``cfg``
        (``cfg.hyperparameters`` and ``cfg.trainer``). Sections that are
        disabled, and any other falsy top-level value, are removed. If
        ``./data/ds_config.json`` exists, its top-level keys override the
        generated ones.
        """
        # Copy so that adding total_num_steps does not touch the user's parameters.
        scheduler_params = { k: cfg.hyperparameters.scheduler_params[k] for k in cfg.hyperparameters.scheduler_params }

        if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
            scheduler_params['total_num_steps'] = cfg.trainer.iterations

        ds_cfg = {
            "train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
            "gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
            "optimizer": {
                "type": cfg.hyperparameters.optimizer,
                "params": {
                    "lr": cfg.hyperparameters.learning_rate,
                }
            } if not cfg.hyperparameters.torch_optimizer else None,
            "scheduler": {
                "type": cfg.hyperparameters.scheduler_type,
                "params": scheduler_params,
            } if cfg.hyperparameters.scheduler_type != "" else None,
            "gradient_clipping": cfg.hyperparameters.gradient_clipping,
            "fp16": {
                "enabled": True,
                "auto_cast": True,
            } if cfg.trainer.weight_dtype.lower() == "float16" and not cfg.trainer.amp else None,
            "bf16": {
                "enabled": cfg.trainer.weight_dtype.lower() == "bfloat16" and not cfg.trainer.amp
            },
            "compression_training": {
                "weight_quantization": {
                    "shared_parameters":{
                        "enabled": True,
                        "quantizer_kernel": True,
                        "schedule_offset": 0,
                        "quantize_groups": 64,
                        "quantize_verbose": True,
                        "quantization_type": "symmetric",
                        "rounding": "nearest",
                        "quantize_weight_in_forward": True,
                        "fp16_mixed_quantize":{
                            "enabled": False,
                            "quantize_change_ratio": 1
                        }
                    },
                    "different_groups": {
                        "wq1": {
                            "params": {
                                "start_bits": self.compression_bits,
                                "target_bits": self.compression_bits,
                                "quantization_period": 0
                            },
                            "modules": [
                                "blocks", # for transformer-based models
                                "retnet", # for RetNets-based models
                            ]
                        }
                    }
                },
            } if self.use_compression_training else None,
            "zero_optimization": {
                "stage": self.zero_optimization_level,
                "contiguous_gradients": True,
                "overlap_comm": True,
                "reduce_scatter": True,
                "reduce_bucket_size": 5e8,
                "allgather_bucket_size": 5e8,
                "sub_group_size": 5e8,
                "round_robin_gradients": True,
                "offload_optimizer": {
                    "device": "cpu",
                    "pin_memory": True
                },
                "offload_param": {
                    "device": "cpu",
                    "pin_memory": True
                },
                "zero_quantized_weights": self.use_compression_training,
                "zero_hpz_partition_size": world_size(),
                "zero_quantized_gradients": self.use_compression_training,
            } if self.zero_optimization_level > 0 else None,
            "comms_logger": {
                "enabled": False
            }
        }

        # Drop every top-level entry whose value is falsy (None for disabled sections).
        null_keys = [ k for k in ds_cfg if not ds_cfg[k] ]
        for k in null_keys:
            del ds_cfg[k]

        if os.path.exists("./data/ds_config.json"):
            # A context manager closes the override file; the bare open() leaked the handle.
            with open("./data/ds_config.json", "r", encoding="utf-8") as f:
                ds_cfg.update(json.load(f))

        return ds_cfg
@dataclass()
class Trainer:
    """Training-loop settings, including the nested DeepSpeed options."""

    iterations: int = 100_000

    save_tag: str = "step"
    load_tag: str | None = None

    save_on_oom: bool = True
    save_on_quit: bool = True
    export_on_save: bool = False
    export_on_quit: bool = False
    save_frequency: int = 100

    keep_last_checkpoints: int = 0

    load_state_dict: bool = False
    load_states: bool = True
    strict_loading: bool = True
    load_module_only: bool = False
    restart_step_count: bool = False

    activation_checkpointing: bool = True

    aggressive_optimizations: bool = False
    check_for_oom: bool = True
    gc_mode: str | None = None
    load_disabled_engines: bool = False

    # Name of the weight dtype; see the ``dtype`` property for the mapping.
    weight_dtype: str = "float16"
    amp: bool = False

    backend: str = "local"
    # The factory must produce an instance: ``lambda: DeepSpeed`` stored the
    # class itself, on which the ``ds_cfg`` cached property cannot be evaluated.
    deepspeed: DeepSpeed = field(default_factory=DeepSpeed)

    @cached_property
    def dtype(self):
        """Return the torch dtype for ``weight_dtype``; unknown names map to float32.

        Cached on first access, so later changes to ``weight_dtype`` are not reflected.
        """
        if self.weight_dtype == "float16":
            return torch.float16
        if self.weight_dtype == "bfloat16":
            return torch.bfloat16
        return torch.float32
@dataclass()
class Inference:
weight_dtype: str = "float32"
amp: bool = False
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
use_vocos: bool = True
recurrent_chunk_size: int = 0
recurrent_forward: bool = False
@cached_property
def dtype(self):
if self.weight_dtype == "float16":
return torch.float16
if self.weight_dtype == "bfloat16":
return torch.bfloat16
return torch.float32
@dataclass
class BitsAndBytes:
    """bitsandbytes switches (plain data holder; nothing in this file reads them)."""

    enabled: bool = False
    injects: bool = False

    linear: bool = True
    embedding: bool = True
@dataclass()
class Config(_Config):
    """Top-level configuration: one attribute per settings section."""

    device: str = "cuda"
    mode: str = "training" # "inferencing"

    # Each factory builds an *instance* of its section. Returning the class
    # itself (``lambda: Dataset``) left factory-backed fields such as
    # ``dataset.training`` undefined on an unconfigured section.
    dataset: Dataset = field(default_factory=Dataset)
    models: Models = field(default_factory=Models)
    hyperparameters: Hyperparameters = field(default_factory=Hyperparameters)
    evaluation: Evaluation = field(default_factory=Evaluation)
    trainer: Trainer = field(default_factory=Trainer)
    inference: Inference = field(default_factory=Inference)
    bitsandbytes: BitsAndBytes = field(default_factory=BitsAndBytes)

    @property
    def sample_rate(self):
        """Return the fixed sample rate, 24000."""
        return 24_000

    @property
    def distributed(self):
        """Return True when ``world_size()`` reports more than one process."""
        return world_size() > 1

    @cached_property
    def get_spkr(self):
        """Return the callable obtained by evaluating ``dataset.speaker_name_getter``."""
        # NOTE(review): eval() executes arbitrary code taken from the
        # configuration; only load configuration files you trust.
        return eval(self.dataset.speaker_name_getter)

    @cached_property
    def diskcache(self):
        """Return a memoize decorator factory.

        Backed by ``diskcache.Cache(cache_dir)`` when ``cfg_path`` is set and
        ``dataset.cache`` is enabled; otherwise a factory whose decorator
        returns the function unchanged.
        """
        if self.cfg_path is not None and self.dataset.cache:
            return diskcache.Cache(self.cache_dir).memoize
        return lambda: lambda x: x

    def load_yaml( self, config_path ):
        """Replace this object's state with the configuration loaded from ``config_path``."""
        tmp = Config.from_yaml( config_path )
        self.__dict__.update(tmp.__dict__)

    def load_hdf5( self, write=False ):
        """(Re)open ``<cfg_path>/<dataset.hdf5_name>`` as ``self.hdf5``.

        The file is opened in ``'a'`` mode when ``write`` is true, otherwise in
        ``dataset.hdf5_flag`` mode, which is forced to ``"r"`` when running
        distributed. On failure the error is printed and ``dataset.use_hdf5``
        is switched off.
        """
        if hasattr(self, 'hdf5'):
            self.hdf5.close()

        if self.distributed:
            self.dataset.hdf5_flag = "r"
        try:
            self.hdf5 = h5py.File(f'{self.cfg_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
        except Exception as e:
            print("Error while opening HDF5 file:", f'{self.cfg_path}/{self.dataset.hdf5_name}', str(e))
            self.dataset.use_hdf5 = False

    @staticmethod
    def _coerce(section_cls, value):
        """Return ``value`` as an instance of ``section_cls``.

        Accepts an existing instance (returned unchanged), the class itself
        (instantiated with its defaults), or a mapping of field values.
        """
        if isinstance(value, section_cls):
            return value
        if value is section_cls:
            return section_cls()
        return section_cls(**value)

    def format( self ):
        """Coerce every section into its dataclass and the dataset directories into ``Path``.

        Sections that are already instances are kept as they are, so a
        partially specified configuration no longer raises ``TypeError``.
        """
        self.dataset = self._coerce(Dataset, self.dataset)
        self.models = self._coerce(Models, self.models)
        self.hyperparameters = self._coerce(Hyperparameters, self.hyperparameters)
        self.evaluation = self._coerce(Evaluation, self.evaluation)
        self.trainer = self._coerce(Trainer, self.trainer)
        self.inference = self._coerce(Inference, self.inference)
        self.bitsandbytes = self._coerce(BitsAndBytes, self.bitsandbytes)

        self.trainer.deepspeed = self._coerce(DeepSpeed, self.trainer.deepspeed)

        self.dataset.training = [ Path(entry) for entry in self.dataset.training ]
        self.dataset.validation = [ Path(entry) for entry in self.dataset.validation ]
        self.dataset.noise = [ Path(entry) for entry in self.dataset.noise ]
# Module-level configuration, built once at import time from the command line.
cfg = Config.from_cli()

# OmegaConf might not coerce the dicts into the @dataclass decorated classes, so we (try to) coerce them ourselves
try:
    cfg.format()

    # cached_property stopped working...
    if cfg.dataset.use_hdf5:
        cfg.load_hdf5()
except Exception as e:
    # Still best effort (importing this module must not fail), but the failure
    # is reported instead of being silently discarded.
    print(f"Error while formatting the configuration: {e}", file=sys.stderr)

if __name__ == "__main__":
    print(cfg)