updated the framework to the saner one that mrq/vall-e uses these days

mrq 2024-09-04 15:48:29 -05:00
parent 5cb28a210e
commit 5610bb3bb3
28 changed files with 3438 additions and 736 deletions

0
.gitignore vendored Executable file → Normal file

0
LICENSE Executable file → Normal file

40
README.md Executable file → Normal file

@@ -1,10 +1,10 @@
# Tentative Title For A ResNet-Based Image Classifier
This is a simple ResNet based image classifier for """specific images""", using a similar training framework I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/).
This is a simple ResNet based image classifier for images, using a similar training framework I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/).
## Premise
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I faced.
This is by no means state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
@@ -16,44 +16,14 @@ This is by no means state of the art, as it just leverages an existing ResNet arc
3. Install using `pip3 install -e ./image_classifier/`.
4. Train using `python3 -m image_classifier.train yaml='./data/config.yaml'`.
4. Train using `python3 -m image_classifier.train --yaml='./data/config.yaml'`.
5. Wait.
## Inferencing
Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" --yaml="./data/config.yaml"`
### Continuous Usage
If you're looking to continuously classifier trained images, use `python3 -m image_classifier --listen --port=7860 yaml="./data/config.yaml" --temp=1.0` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
## Known Issues
* Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
* Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
* Web server doesn't emit `content-type: application/json`, nor accepts JSON `POST`s at the moment.
## Strawmen
>\> UGH... Why *another* training framework!!! Just subjugate [DLAS](https://git.ecker.tech/mrq/DL-Art-School) even more!!!
I want my own code to own. The original VALL-E implementation had a rather nice and clean setup that *mostly* just made sense. DLAS was a nightmare to comb through for the gorillion amounts of models it attests.
>\> OK. But how do I use it for `[thing that isn't the specific usecase only I know/care about]`
Simply provide your own symmapping under `./image_classifier/data.py`, and, be sure to set the delimiter (where exactly is an exercise left to the reader).
Because this is for a ***very specific*** use-case. I don't really care right now to make this a *little* more generalized, despite most of the bits and bobs for it to generalize being there.
>\> ur `[a slur]` for using a ResNet... why not use `[CRNN / some other meme arch]`??
I don't care, I'd rather keep the copypasting from other people's code to a minimum. Lazily adapting my phoneme tokenizer from my VALL-E implementation into something practically fixed length by introducing start/stop tokens should be grounds for me to use a CRNN, or anything recurrent at the very least, but again, I don't care, it just works for my use case at the moment.
>\> UGH!!! What are you talking about """specific images"""???
[ひみつ](https://files.catbox.moe/csuh49.webm)
>\> NOOOO!!!! WHY AREN'T YOU USING `[cuck license]`???
:)
If you're looking to continuously classify images, use `python3 -m image_classifier --listen --port=7860 --yaml="./data/config.yaml"` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
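For reference, a minimal client sketch for the endpoint described above (assumes the third-party `requests` package and a server already started with the `--listen` command; the `?b64=` query parameter and the `{"answer": ...}` JSON shape follow the README and `__main__.py` in this commit):

```python
import base64
import requests  # third-party HTTP client; any other client works the same way

# Encode the image file the same way the endpoint expects it.
with open("./data/path-to-your-image.png", "rb") as f:
    b64 = base64.b64encode(f.read()).decode("utf-8")

# requests URL-encodes the base64 payload for the ?b64= query parameter.
response = requests.get("http://127.0.0.1:7860/", params={"b64": b64})
print(response.json()["answer"])  # the classified label
```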

131
data/config.yaml Executable file → Normal file

@@ -1,85 +1,84 @@
dataset:
training: [
"./data/images/"
]
validation: []
use_hdf5: False
workers: 0
cache: True
weights_format: sft
models:
_models:
- name: "classifier"
tokens: 0
len: 6
- name: "classifier"
tokens: 0
len: 6
dim: 512
resnet: 34
#loras:
#- name : "lora"
# rank: 128
# alpha: 128
# training: True
# rvq_levels: []
hyperparameters:
batch_size: 256
gradient_accumulation_steps: 64
gradient_clipping: 100
gradient_accumulation_steps: 1
gradient_clipping: 1.0
warmup_steps: 10
optimizer: Prodigy
learning_rate: 1.0
torch_optimizer: True
optimizer: Adamw
learning_rate: 1.0e-3
scheduler_type: ""
#scheduler_type: OneCycle
#scheduler_params:
# cycle_first_step_size: 10_000
# cycle_first_stair_count: 10_000
# cycle_second_step_size: 15_000
# cycle_second_stair_count: 15_000
# decay_step_size: 5_000
# cycle_min_lr: 2.5e-4 # 1.0e-5
# cycle_max_lr: 2.5e-4 # 1.0e-4
# decay_lr_rate: 0.0
# cycle_min_mom: 0.90
# cycle_max_mom: 0.99
# decay_mom_rate: 0.0
scheduler: "" # ScheduleFree
torch_scheduler: True
evaluation:
batch_size: 32
frequency: 250
size: 32
batch_size: 64
frequency: 100
size: 64
steps: 300
temperature: 1.0
steps: 450
temperature: 0.0
trainer:
iterations: 100_000
save_tag: step
save_on_oom: True
save_on_quit: True
iterations: 1_000_000
save_frequency: 100
aggressive_optimizations: False
check_for_oom: False
#load_tag: "9500"
#load_state_dict: True
#load_states: False
#strict_loading: False
#restart_step_count: True
keep_last_checkpoints: 32
gc_mode: None # "global_step"
check_for_oom: False
gradient_checkpointing: True
weight_dtype: float32
weight_dtype: bfloat16
amp: True
backend: local
backend: deepspeed
deepspeed:
zero_optimization_level: 0
use_compression_training: True
inferencing: False
amp: False
inference:
use_vocos: True
backend: local
bitsandbytes:
enabled: false
weight_dtype: bfloat16
amp: True
optimizations:
injects: False
replace: True
linear: False
embedding: False
optimizers: True
bitsandbytes: False
dadaptation: False
bitnet: False
fp8: False
dataset:
use_hdf5: True
hdf5_flag: r
workers: 1
cache: True
training: [
"./data/images/"
]
validation: [
"./data/validation/"
]
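As a rough sketch of how this YAML is consumed by the new loader (built from the `Config.from_yaml` and `Config.format` code in the `image_classifier/config.py` diff further down; calling them by hand like this, outside the train/inference entry points, is an assumption):

```python
# Hedged sketch: load the config shown above and poke at the coerced dataclasses.
from pathlib import Path
from image_classifier.config import Config  # package name per `python3 -m image_classifier.train`

cfg = Config.from_yaml(Path("./data/config.yaml"))  # yaml.safe_load + Config(**state)
cfg.format()  # coerces the plain dicts into the Dataset/Model/Hyperparameters/... dataclasses

print(cfg.model.resnet)                   # 34, per the `models` block above
print(cfg.hyperparameters.learning_rate)  # 1.0e-3 for AdamW in the updated config
print(cfg.trainer.backend)                # "deepspeed"

# Note: importing the module already builds and formats a module-level `cfg`
# from --yaml (or VALLE_YAML), so in normal use none of this is called manually.
```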

image_classifier/__main__.py

@@ -12,35 +12,55 @@ def main():
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--listen", action='store_true')
parser.add_argument("--port", type=int, default=9090)
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--ckpt", type=Path, default=None)
parser.add_argument("--temp", type=float, default=1.0)
parser.add_argument("--device", default="cuda")
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default=None)
parser.add_argument("--temp", type=float, default=0.0)
args, unknown = parser.parse_known_args()
classifier = Classifier( config=args.yaml, ckpt=args.ckpt, device=args.device )
classifier = Classifier( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
if args.listen:
@route("/")
def inference( b64, temperature=1.0 ):
def inference( b64, temperature=args.temp ):
image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
return { "answer": classifier.inference( image=image, temperature=args.temp ) }
return { "answer": classifier.inference( image=image, temperature=temperature ) }
server.start(port=args.port)
else:
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--path", type=Path)
parser.add_argument("--base64", type=str)
parser.add_argument("--write", type=Path)
parser.add_argument("--temp", type=float, default=1.0)
args, unknown = parser.parse_known_args()
args, unknown = parser.parse_known_args()
images = []
if args.path:
image = Image.open(args.path).convert('RGB')
if args.path.is_dir():
for p in args.path.rglob("./*.jpg"):
image = Image.open(p).convert('RGB')
images.append(image)
for p in args.path.rglob("./*.png"):
image = Image.open(p).convert('RGB')
images.append(image)
else:
image = Image.open(args.path).convert('RGB')
images.append(image)
elif args.base64:
image = Image.open(BytesIO(base64.b64decode(args.base64))).convert("RGB")
images.append(image)
else:
raise "Specify a --path or --base64."
answer = classifier.inference( image=image, temperature=args.temp )
print("Answer:", answer)
for image in images:
answer = classifier.inference( image=image, temperature=args.temp )
print("Answer:", answer)
if args.write:
args.write.mkdir(exist_ok=True)
image.save( args.write / f"{answer}.jpg")
if __name__ == "__main__":
main()
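The classifier can also be driven programmatically, mirroring the updated construction above (a sketch: the `image_classifier.inference` import path, and anything beyond the `config=`/`device=`/`dtype=`/`amp=` keywords visible in this diff, are assumptions):

```python
from pathlib import Path
from PIL import Image
from image_classifier.inference import Classifier  # import path is an assumption

# Mirrors `Classifier( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )` above,
# using the same defaults the argparse flags fall back to.
classifier = Classifier(config=Path("./data/config.yaml"), device=None, dtype=None, amp=False)

image = Image.open("./data/path-to-your-image.png").convert("RGB")
print(classifier.inference(image=image, temperature=0.0))  # 0.0 matches the new --temp default
```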

image_classifier/config.py

@@ -6,31 +6,61 @@ import os
import subprocess
import sys
import time
from dataclasses import asdict, dataclass
from dataclasses import dataclass, field
from functools import cached_property, cache
from pathlib import Path
from omegaconf import OmegaConf
import argparse
import yaml
import random
import logging
import torch
import numpy as np
from dataclasses import asdict, dataclass, field
from functools import cached_property
from pathlib import Path
from .utils.distributed import world_size
def set_seed(seed=None):
if not seed:
seed = time.time()
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
@dataclass()
class _Config:
cfg_path: str | None = None
class BaseConfig:
yaml_path: str | None = None # path passed in through --yaml
@property
def relpath(self):
def cfg_path(self):
return Path(self.yaml_path.parent) if self.yaml_path is not None else None
@property
def rel_path(self):
return Path(self.cfg_path)
@property
def cache_dir(self):
return self.rel_path / ".cache"
@property
def data_dir(self):
return self.rel_path / "data"
@property
def metadata_dir(self):
return self.rel_path / "metadata"
@property
def ckpt_dir(self):
return self.relpath / "ckpt"
return self.rel_path / "ckpt"
@property
def log_dir(self):
return self.relpath / "logs" / str(self.start_time)
return self.rel_path / "logs" / str(self.start_time)
@cached_property
def start_time(self):
@@ -64,39 +94,28 @@ class _Config:
with open(path, "w") as f:
f.write(self.dumps())
@staticmethod
def _is_cfg_argv(s):
return "=" in s and "--" not in s
@classmethod
def from_yaml( cls, yaml_path ):
return cls.from_cli( [f'yaml="{yaml_path}"'] )
state = {}
state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
state.setdefault("yaml_path", yaml_path)
return cls(**state)
@classmethod
def from_cli(cls, args=sys.argv):
cli_cfg = OmegaConf.from_cli([s for s in args if cls._is_cfg_argv(s)])
# legacy support for yaml=`` format
for i, arg in enumerate(args):
if arg.startswith("yaml"):
args[i] = f'--{arg}'
# Replace argv to ensure there are no omegaconf options, for compatibility with argparse.
sys.argv = [s for s in sys.argv if not cls._is_cfg_argv(s)]
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
args, unknown = parser.parse_known_args(args=args)
if cli_cfg.get("help"):
print(f"Configurable hyperparameters with their default values:")
print(json.dumps(asdict(cls()), indent=2, default=str))
exit()
if args.yaml:
return cls.from_yaml( args.yaml )
if "yaml" in cli_cfg:
yaml_cfg = OmegaConf.load(cli_cfg.yaml)
yaml_path = Path(cli_cfg.yaml).absolute()
cfg_path = Path(*yaml_path.relative_to(Path.cwd()).parts[:-1])
cfg_path = cfg_path.with_suffix("")
cfg_path = f'./{cfg_path}'
yaml_cfg.setdefault("cfg_path", cfg_path)
cli_cfg.pop("yaml")
else:
yaml_cfg = {}
merged = OmegaConf.merge(yaml_cfg, cli_cfg)
return cls(**dict(merged))
return cls(**{})
def __repr__(self):
return str(self)
@@ -106,104 +125,195 @@ class _Config:
@dataclass()
class Dataset:
training: list[Path] = field(default_factory=lambda: [])
validation: list[Path] = field(default_factory=lambda: [])
temp: list[Path] = field(default_factory=lambda: [])
# de-implemented, because the data isn't that large to facilitate HDF5
hdf5_name: str = "data.h5"
use_hdf5: bool = False
training: list[Path] = field(default_factory=lambda: []) # paths to load into the training dataset
validation: list[Path] = field(default_factory=lambda: []) # paths to load into the validation dataset
workers: int = 8
cache: bool = True
hdf5_name: str = "data.h5" # file name to load the HDF5 dataset
use_hdf5: bool = False # whether to load from an HDF5 dataset
hdf5_flag: str = "a" # flag to load the HDF5 file, automatically adjusted anyways
validate: bool = True # validate each utterance on wheter it can be included based on duration range caps
workers: int = 8 # number of dataloader workers to spawn
cache: bool = True # use diskcache to cache the dataset
# I really need to clean this up
@dataclass()
class Model:
name: str = ""
name: str = "classifier"
tokens: int = 0 # number of token types
len: int = 1 # how long a sequence can be
dim: int = 512
resnet: int = 18
width: int = 300
height: int = 80
version: int = 1
training: bool = True
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
@property
def full_name(self):
return self.name
@dataclass()
class Models:
_models: list[Model] = field(default_factory=lambda: [
Model(name="captcha"),
])
def get(self, name=None):
if not name:
return [ Model(**model) for model in self._models ]
return [ self ] if not name or self.name == name else []
def loss_factor(self, k):
return self.loss_factors[k] if k in self.loss_factors else 1.0
for model in self._models:
if model.name == name:
return model
@property
# required for fp8 as the lengths needs to be divisible by 8
def input_alignment(self):
return 8 if cfg.optimizations.fp8 else 0
raise ValueError
@property
def activation_checkpointing(self):
return cfg.trainer.activation_checkpointing
@property
def gradient_checkpointing(self):
return cfg.trainer.gradient_checkpointing
@property
def lora_policy(self):
include = ["model"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
exclude = []
if self.arch_type == "llama":
include = ["self_attn", "mlp"] # target only the attention + mlp
exclude = ["self_attn.k_proj"] # common literature says to ignore it
if self.arch_type == "retnet":
include = ["layers."] # target the core layers of the RetNet and ignore the auxiliary stuff
exclude = ["retention.k_proj"] # attention-based transformers ignore the K, so might as well ignore it for the retnet
return dict(include=include, exclude=exclude)
# should be renamed to Adapters
@dataclass()
class LoRA:
name: str = "lora" # vanity name
# to-do: find sane default values
rank: int = 128 # rank for the LoRA
alpha: int = 128 # alpha (scaling factor) for the LoRA
training: bool = True #
embeddings: bool = False # train the embedding too
parametrize: bool = False # whether to use the parameterized pathway for LoRAs or not
rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
@property
def full_name(self):
name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
return "-".join(name)
# actually not needed anymore
def active_level( self, level ):
if not self.rvq_levels:
return True
return level in self.rvq_levels
@dataclass()
class Hyperparameters:
batch_size: int = 8
gradient_accumulation_steps: int = 32
gradient_clipping: int = 100 # to be implemented in the local backend
batch_size: int = 8 # number of samples per training batch
gradient_accumulation_steps: int = 32 # number of steps to accumulate gradients before updating
gradient_clipping: int | float = 10 # largest size a gradient norm can be
optimizer: str = "Adamw"
learning_rate: float = 3.25e-4
optimizer: str = "Adamw" # optimizer to use, should be 'Prodigyopt" now
optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
learning_rate: float = 3.25e-4 # should be 1.0 for ProdigyOpt
warmup_steps: int = 0 # number of steps to warm up the optimizer before performing updates, I think, this is just passed to deepspeed
scheduler_type: str = "" # to be implemented in the local backend
scheduler_params: dict = field(default_factory=lambda: {})
scheduler: str = "" # scheduler to use, currently don't ever use one so this doesn't really matter
scheduler_type: str = "" # deprecated
scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
autotune: bool = False # to do deepspeed's autotuning
autotune_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed supplied
torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied
@dataclass()
class Evaluation:
batch_size: int = 64
frequency: int = 250
size: int = 64
batch_size: int = 64 # number of samples per batch during eval / val
frequency: int = 250 # do eval / val every X iterations
size: int = 64 # number of samples to generate during eval / val
steps: int = 500
temperature: float = 1.0
temperature: float = 1.0 # AR temp for inferencing
load_disabled_engines: bool = True # see the other load_disabled_engines
@dataclass()
class DeepSpeed:
zero_optimization_level: int = 0
use_compression_training: bool = False
zero_optimization_level: int = 0 # doesn't seem to work
use_compression_training: bool = False # cope
compression_bits: int = 8 # cope
inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)
def get_ds_cfg(self, model):
weights = [ name[0] for name in model.named_parameters() ]
bits = 8
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
scheduler_params = {}
for k in cfg.hyperparameters.scheduler_params:
scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
@cached_property
def ds_cfg(self):
optimizer_params = cfg.hyperparameters.optimizer_params
if 'lr' not in optimizer_params:
optimizer_params["lr"] = cfg.hyperparameters.learning_rate,
if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
scheduler_params = cfg.hyperparameters.scheduler_params
if 'warmup_num_steps' not in scheduler_params:
scheduler_params['warmup_num_steps'] = cfg.hyperparameters.warmup_steps
if 'total_num_steps' not in scheduler_params:
scheduler_params['total_num_steps'] = cfg.trainer.iterations
autotune_params = cfg.hyperparameters.autotune_params
if "enabled" not in autotune_params:
autotune_params['enabled'] = True
if "results_dir" not in autotune_params:
autotune_params['results_dir'] = str( cfg.rel_path / "autotune" / "results" )
if "exps_dir" not in autotune_params:
autotune_params['exps_dir'] = str( cfg.rel_path / "autotune" / "exps_" )
# DeepSpeed fp16 is incompatible with its AMP
if cfg.trainer.weight_dtype.lower() == "float16":
self.amp = False
# disable local AMP
if self.amp:
cfg.trainer.amp = False
ds_cfg = {
"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
"optimizer": {
"type": cfg.hyperparameters.optimizer,
"params": {
"lr": cfg.hyperparameters.learning_rate,
}
},
"params": optimizer_params,
} if not cfg.hyperparameters.torch_optimizer else None,
"scheduler": {
"type": cfg.hyperparameters.scheduler_type,
"type": cfg.hyperparameters.scheduler,
"params": scheduler_params,
} if cfg.hyperparameters.scheduler_type != "" else None,
} if not cfg.hyperparameters.torch_scheduler else None,
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
"fp16": {
"enabled": True,
"auto_cast": True,
} if cfg.trainer.weight_dtype.lower() == "float16" else None,
"bf16": {
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16"
"enabled": cfg.trainer.weight_dtype.lower() == "float16",
"auto_cast": True, # ???
"loss_scale": 0.0 if cfg.trainer.scale_loss else 1.0,
},
"bf16": {
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
},
"amp": {
"enabled": self.amp,
},
"autotuning": autotune_params if cfg.hyperparameters.autotune else None,
"compression_training": {
"weight_quantization": {
"shared_parameters":{
@@ -214,7 +324,7 @@ class DeepSpeed:
"quantize_verbose": True,
"quantization_type": "symmetric",
"rounding": "nearest",
"quantize_weight_in_forward": True,
"quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
"fp16_mixed_quantize":{
"enabled": False,
"quantize_change_ratio": 1
@@ -223,30 +333,38 @@ class DeepSpeed:
"different_groups": {
"wq1": {
"params": {
"start_bits": bits,
"target_bits": bits,
"start_bits": self.compression_bits,
"target_bits": self.compression_bits,
"quantization_period": 0
},
"modules": weights
"modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
}
}
},
"activation_quantization": {
"shared_parameters":{
"enabled": True,
"quantizer_kernel": True,
"schedule_offset": 0,
"quantize_groups": 64,
"quantize_verbose": True,
"quantization_type": "symmetric",
"range_calibration": "dynamic",
"schedule_offset": 0
"rounding": "nearest",
"quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
"fp16_mixed_quantize":{
"enabled": False,
"quantize_change_ratio": 1
}
},
"different_groups": {
"aq1": {
"params": {
"bits": bits
"bits": self.compression_bits,
},
"modules": weights
"modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
}
}
}
},
} if self.use_compression_training else None,
"zero_optimization": {
"stage": self.zero_optimization_level,
@@ -264,7 +382,10 @@ class DeepSpeed:
"offload_param": {
"device": "cpu",
"pin_memory": True
}
},
"zero_quantized_weights": self.use_compression_training,
"zero_hpz_partition_size": world_size(),
"zero_quantized_gradients": self.use_compression_training,
} if self.zero_optimization_level > 0 else None,
"comms_logger": {
"enabled": False
@@ -275,113 +396,314 @@ class DeepSpeed:
for k in null_keys:
del ds_cfg[k]
if os.path.exists("./config/ds_config.json"):
ds_cfg.update(json.load(open("./config/ds_config.json", "r", encoding="utf-8")))
if os.path.exists("./data/ds_config.json"):
ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
else:
ds_cfg.update(self.config)
return ds_cfg
@dataclass()
class Trainer:
iterations: int = 100_000
iterations: int = 1_000_000 # maximum iterations to train
save_tag: str = "step"
load_tag: str | None = None
save_tag: str = "step" # name to save checkpoints under, "step" will save as current step count
load_tag: str | None = None # tag to load checkpoint from; if None: will check against contents of `./ckpt/{model-name}/latest` for the checkpoint name
save_on_oom: bool = True
save_on_quit: bool = True
save_frequency: int = 100
save_on_oom: bool = True # save if an OOM error is raised
save_on_quit: bool = True # save when quitting training
export_on_save: bool = False # export weights to local `fp32.pth` state_dict on saving a checkpoint
export_on_quit: bool = False # export weights to local `fp32.pth` state_dict on quitting training
save_frequency: int = 100 # frequency to save every X iterations
load_state_dict: bool = False
load_states: bool = True
strict_loading: bool = True
restart_step_count: bool = False
keep_last_checkpoints: int = 0 # number of checkpoints to keep, prunes oldest ones
aggressive_optimizations: bool = False
check_for_oom: bool = True
load_state_dict: bool = False # loads `fp32.pth` state_dict, will automatically be done if a checkpoint is not found but `fp32.pth` exists
load_states: bool = True #
strict_loading: bool = False # sets strict_loading=True when loading the state dict
load_module_only: bool = False #
restart_step_count: bool = False # clears the training stats when loading a checkpoint
resize_modules: bool = False # automatically resizes mismatched modules when loading weights
gc_mode: str | None = None
activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing
gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
weight_dtype: str = "float16"
aggressive_optimizations: bool = False # deprecated
check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
gc_mode: str | None = None # deprecated, but marks when to do GC
load_disabled_engines: bool = False # deprecated, but signals to load engines not used for training for, for example, evaluation/validation
backend: str = "deepspeed"
weight_dtype: str = "float16" # dtype to have the model under
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
amp: bool = False # automatic mixed precision
ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
#scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload)
load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when logger broke because of BitNet
backend: str = "local" # training backend to use. currently supports "local" | "deepspeed"
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings
@cached_property
def dtype(self):
if self.weight_dtype == "float16":
return torch.float16
if cfg.trainer.weight_dtype == "bfloat16":
if self.weight_dtype == "bfloat16":
return torch.bfloat16
if self.weight_dtype == "float8_e5m2":
return torch.float8_e5m2
if self.weight_dtype == "float8_e4m3fn":
return torch.float8_e4m3fn
return torch.float32
@cached_property
def scale_loss(self):
# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
return self.dtype == torch.float16
@dataclass()
class Inference:
use_vocos: bool = True # artifact from the VALL-E trainer
backend: str = "local" # backend to use when inferencing
weight_dtype: str = "float32" # dtype to load the model under
amp: bool = False # automatic mixed precision during inferencing
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
@property
def dtype(self):
if self.weight_dtype == "float16":
return torch.float16
if self.weight_dtype == "bfloat16":
return torch.bfloat16
if self.weight_dtype == "int8":
return torch.int8
if self.weight_dtype == "float8_e5m2":
return torch.float8_e5m2
if self.weight_dtype == "float8_e4m3fn":
return torch.float8_e4m3fn
return torch.float32
@dataclass()
class BitsAndBytes:
enabled: bool = False
injects: bool = False
class Optimizations:
injects: bool = False # overwrites default torch classes (not recommended)
replace: bool = False # replaces modules in place with the optimized version (recommended)
compile: bool | str = False # runs torch.compile on the model
linear: bool = False
embedding: bool = False
linear: bool = True # inject/replace linear for BnB
embedding: bool = True # inject/replace embedding for BnB
optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
bitsandbytes: bool = False # use bitsandbytes
dadaptation: bool = False # use dadaptation optimizer
bitnet: bool = False # use bitnet
fp8: bool = False # use fp8
model_offloading: dict | None = None # automatically splits the model over a list of devices
# example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} will have the GPU capped to 6GiB, and offload the remaining layers to CPU
# example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
# | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2
tensorrt: bool = False
@dataclass()
class Config(_Config):
device: str = "cuda"
class Config(BaseConfig):
device: str = "cuda" # target device
mode: str = "training" # "inferencing"
experimental: bool = False # Debug flag, unused now
dataset: Dataset = field(default_factory=lambda: Dataset)
models: Models = field(default_factory=lambda: Models)
models: dict | list | None = field(default_factory=lambda: [])
loras: dict | list | None = field(default_factory=lambda: [])
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
trainer: Trainer = field(default_factory=lambda: Trainer)
inference: Inference = field(default_factory=lambda: Inference)
bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
bitsandbytes: dict | list | None = None # deprecated
optimizations: Optimizations = field(default_factory=lambda: Optimizations)
def get_device(self):
return torch.cuda.current_device() if self.device == "cuda" else self.device
tokenizer: str | None = None # tokenizer class
tokenizer_path: str = "./tokenizer.json" # tokenizer path
weights_format: str = "pth" # "pth" | "sft"
supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
@property
def cache_dir(self):
return ".cache" / self.relpath
def model(self):
for i, model in enumerate(self.models):
if model.training:
return model
return self.models[0] if len(self.models) > 0 else None
# should be renamed to adapters
@property
def lora(self):
for i, lora in enumerate(self.loras):
if lora.training:
return lora
return self.loras[0] if len(self.loras) > 0 else None
@property
def distributed(self):
return world_size() > 1
@cached_property
def diskcache(self):
if self.dataset.cache:
if self.yaml_path is not None and self.dataset.cache:
return diskcache.Cache(self.cache_dir).memoize
return lambda: lambda x: x
# I don't remember why this is needed
def load_yaml( self, config_path ):
tmp = Config.from_yaml( config_path )
self.__dict__.update(tmp.__dict__)
def load_hdf5( self, write=False ):
if hasattr(self, 'hdf5'):
self.hdf5.close()
if self.distributed:
self.dataset.hdf5_flag = "r"
try:
self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
except Exception as e:
_logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}")
self.dataset.use_hdf5 = False
# to-do: prune unused keys
def format( self, training=True ):
if isinstance(self.dataset, type):
self.dataset = dict()
if isinstance(self.models, type):
self.models = dict()
if isinstance(self.loras, type):
self.loras = dict()
if isinstance(self.hyperparameters, type):
self.hyperparameters = dict()
if isinstance(self.evaluation, type):
self.evaluation = dict()
if isinstance(self.trainer, type):
self.trainer = dict()
if isinstance(self.inference, type):
self.inference = dict()
if isinstance(self.optimizations, type):
self.optimizations = dict()
self.dataset = Dataset(**self.dataset)
self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
self.models = [ Model(**model) for model in self.models ]
self.loras = [ LoRA(**lora) for lora in self.loras ]
if not self.models:
self.models = [ Model() ]
self.hyperparameters = Hyperparameters(**self.hyperparameters)
self.evaluation = Evaluation(**self.evaluation)
self.trainer = Trainer(**self.trainer)
if not isinstance(self.trainer.deepspeed, type):
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
self.inference = Inference(**self.inference)
if self.bitsandbytes is not None:
self.optimizations = Optimizations(**self.bitsandbytes)
else:
self.optimizations = Optimizations(**self.optimizations)
if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
self.hyperparameters.scheduler_type = ""
# do not combine the two
if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation:
self.hyperparameters.scheduler = ""
if self.hyperparameters.scheduler == "":
self.hyperparameters.torch_scheduler = True
if self.trainer.backend == "local" and self.distributed:
self.trainer.ddp = True
if self.trainer.activation_checkpointing is not None:
self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing
if not training:
self.dataset.use_hdf5 = False
# load our HDF5 file if requested here
if self.dataset.use_hdf5:
self.load_hdf5()
# load tokenizer
if cfg.tokenizer == "naive":
cfg.tokenizer = NaiveTokenizer()
else:
try:
from transformers import PreTrainedTokenizerFast
tokenizer_path = cfg.rel_path / cfg.tokenizer_path if cfg.yaml_path is not None else None
if tokenizer_path and not tokenizer_path.exists():
tokenizer_path = Path("./data/") / cfg.tokenizer_path
if tokenizer_path and tokenizer_path.exists():
cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
else:
cfg.tokenizer = NaiveTokenizer()
except Exception as e:
cfg.tokenizer = NaiveTokenizer()
_logger.warning(f"Error while parsing tokenizer: {str(e)}")
pass
# Preserves the old behavior
class NaiveTokenizer:
def get_vocab( self ):
"""
if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
return json.loads( cfg.hdf5['symmap'].asstr()[()] )
"""
return { " ": 0, "<s>": 1, "</s>": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 }
@cached_property
def _bos_token( self ):
return self.get_vocab()["<s>"]
@cached_property
def _eos_token( self ):
return self.get_vocab()["</s>"]
def encode( self, s ):
symmap = self.get_vocab()
s = s.replace("O", "0")
s = [f"<s>"] + [ p if p in symmap else " " for p in s ] + [f"</s>"]
return [*map(symmap.get, s)]
_logger = logging.getLogger(__name__)
cfg = Config.from_cli()
# OmegaConf doesn't actually coerce the dicts into the @dataclass decorated classes, for some god forsaken reason, so we coerce them ourselves
cfg.dataset = Dataset(**cfg.dataset)
cfg.models = Models(**cfg.models)
cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
cfg.evaluation = Evaluation(**cfg.evaluation)
cfg.trainer = Trainer(**cfg.trainer)
cfg.inference = Inference(**cfg.inference)
cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
# cached_property stopped working...
if cfg.dataset.use_hdf5:
try:
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
except Exception as e:
print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
cfg.dataset.use_hdf5 = False
if not cfg.dataset.use_hdf5:
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
# some safety for remapping deprecated formats and re-coercing uninitialized properties into actual types
try:
cfg.format()
except Exception as e:
_logger.error(f"Error while parsing config YAML: {str(e)}")
raise e # throw an error because I'm tired of silent errors messing things up for me
if __name__ == "__main__":
print(cfg)
print(cfg)
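As a worked example of the `NaiveTokenizer` fallback above, using the vocabulary returned by `get_vocab()` (the import path follows the module name assumed for this diff):

```python
from image_classifier.config import NaiveTokenizer  # defined in the diff above

tokenizer = NaiveTokenizer()

# "O" is folded to "0", the string is wrapped in <s>/</s>, and any symbol
# outside the vocab falls back to " " (id 0).
print(tokenizer.encode("W8OK"))
# <s> W 8 0 K </s>  ->  [1, 20, 6, 3, 12, 2]
```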

image_classifier/data.py

@@ -1,16 +1,19 @@
# todo: clean this mess up
import copy
# import h5py
import h5py
import json
import logging
#import numpy as np
import numpy as np
import os
import random
import torch
import math
import itertools
from .config import cfg
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
from .utils.distributed import global_rank, local_rank, world_size
from .utils.io import torch_save, torch_load
from collections import defaultdict
from functools import cache, cached_property
@@ -20,23 +23,57 @@ from typing import Any
from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.utils.rnn import pad_sequence
from PIL import Image, ImageDraw
import torchvision.transforms as transforms
from tqdm.auto import tqdm
from PIL import Image
# torch.multiprocessing.set_sharing_strategy("file_system")
_logger = logging.getLogger(__name__)
@cache
# to-do: clean up this symmap mess
def get_symmap():
return { " ": 0, "<s>": 1, "</s>": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 }
return cfg.tokenizer.get_vocab()
@cache
def _get_symbols( content ):
content = content.replace("O", "0")
return [f"<s>"] + [ p for p in content ] + [f"</s>"]
def tokenize( s ):
if isinstance( s, list ):
s = "".join( s )
return cfg.tokenizer.encode( s )
"""
def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
def _get_hdf5_path(path):
# to-do: better validation
return str(path)
def _get_hdf5_paths( data_dir, type="training", validate=False ):
data_dir = str(data_dir)
key = f"/{type}/{_get_hdf5_path(data_dir)}"
return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items()] if key in cfg.hdf5 else []
def _get_paths_of_extensions( path, validate=False ):
if isinstance(path, str):
path = Path(path)
return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
def _interleaved_reorder(l, fn):
groups = defaultdict(list)
for e in l:
groups[fn(e)].append(e)
groups = {k: groups[k] for k in sorted(groups)}
for interleaved in zip_longest(*groups.values()):
for value in interleaved:
if value is not None:
yield value
"""
class Dataset(_Dataset):
def __init__(
@@ -44,43 +81,90 @@ class Dataset(_Dataset):
paths,
width=300,
height=80,
stacks=0,
symmap=get_symmap(),
training=False,
):
super().__init__()
self._head = None
self.paths = paths
self.sampler = None
self.width = width
self.height = height
self.stacks = stacks
self.paths = paths
self.image_dtype = cfg.trainer.dtype
self.symmap = symmap
self.training = training
self.dataset_type = "training" if self.training else "validation"
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
self.transform = transforms.Compose([
#transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
@cached_property
def symbols(self):
return sorted(set().union(*[_get_symbols(path.stem) for path in self.paths]))
# to-do: do not do validation if there's nothing in the validation
# this just makes it be happy
if len(self.dataset) == 0:
self.dataset = cfg.dataset.training
# split dataset accordingly per GPU
if cfg.distributed and self.training:
self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == global_rank() ]
if len(self.paths) == 0:
raise ValueError(f"No valid path is found for {self.dataset_type}")
@cached_property
def sampler_state_dict_path(self):
return cfg.rel_path / f"sampler.rank{global_rank()}.pt"
def save_state_dict(self, path = None):
"""
if path is None:
path = self.sampler_state_dict_path
if self.sampler is not None:
state_dict = self.sampler.get_state()
elif self.samplers is not None:
state_dict = {
"samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
}
torch_save(state_dict, path)
"""
return
def load_state_dict(self, path = None):
"""
if path is None:
path = self.sampler_state_dict_path
if not path.exists():
return
state_dict = torch_load(path)
if self.sampler is not None:
state_dict = self.sampler.set_state(state_dict)
else:
for name, sampler in state_dict["samplers"].items():
if name not in self.samplers:
continue
self.samplers[name].set_state( sampler )
"""
return
def __getitem__(self, index):
path = self.paths[index]
tokens = tokenize( path.stem.upper() )
text = torch.tensor( tokens ).to(dtype=torch.uint8)
# stupid try/except when the original VALL-E training framework was able to insert foreign symbols into the symmap, but that functionality isn't really necessary here
try:
text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
except Exception as e:
print("Invalid symbol:", _get_symbols(path.stem), [*map(self.symmap.get, _get_symbols(path.stem))], path.stem)
raise e
image = Image.open(path).convert('RGB')
width, height = image.size
image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype) # resnet has to be RGB
image = self.transform(image).to(dtype=self.image_dtype) # resnet has to be RGB
return dict(
index=index,
@@ -98,11 +182,6 @@ class Dataset(_Dataset):
def __len__(self):
return min(len(self.paths), self._head or len(self.paths))
def pin_memory(self):
self.text = self.text.pin_memory()
self.image = self.image.pin_memory()
return self
def collate_fn(samples: list[dict]):
batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
@@ -111,21 +190,28 @@ def collate_fn(samples: list[dict]):
def _seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
#np.random.seed(worker_seed)
np.random.seed(worker_seed)
random.seed(worker_seed)
def _create_dataloader(dataset, training):
kwargs = dict(
shuffle=True,
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
drop_last=training,
sampler=dataset.sampler if training else None,
) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
batch_sampler=dataset.sampler,
)
return DataLoader(
dataset=dataset,
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
shuffle=True, # training
drop_last=training,
num_workers=cfg.dataset.workers,
collate_fn=collate_fn,
persistent_workers=cfg.dataset.workers > 0,
pin_memory=False, # True,
persistent_workers=cfg.dataset.workers > 1,
pin_memory=False,
worker_init_fn=_seed_worker,
**kwargs,
)
def _load_train_val_paths( val_ratio=0.1 ):
@@ -133,8 +219,8 @@ def _load_train_val_paths( val_ratio=0.1 ):
train_paths = []
val_paths = []
print(cfg.dataset.training)
for data_dir in cfg.dataset.training:
paths.extend(data_dir.rglob("*.jpg"))
paths.extend(data_dir.rglob("*.png"))
if len(paths) > 0:
@@ -146,12 +232,13 @@ def _load_train_val_paths( val_ratio=0.1 ):
val_len = math.floor(len(train_paths) * val_ratio)
train_len = math.floor(len(train_paths) * (1 - val_ratio))
print(val_len, train_len)
val_paths = train_paths[:-val_len]
train_paths = train_paths[:train_len]
else:
paths = []
for data_dir in cfg.dataset.validation:
paths.extend(data_dir.rglob("*.jpg"))
paths.extend(data_dir.rglob("*.png"))
if len(paths) > 0:
@@ -169,7 +256,6 @@ def _load_train_val_paths( val_ratio=0.1 ):
return train_paths, val_paths
@cfg.diskcache()
def create_datasets():
train_paths, val_paths = _load_train_val_paths()
@@ -187,10 +273,10 @@ def create_datasets():
return train_dataset, val_dataset
def create_train_val_dataloader():
train_dataset, val_dataset = create_datasets()
# deepcopy is slow
subtrain_dataset = copy.deepcopy(train_dataset)
subtrain_dataset.head_(cfg.evaluation.size)
subtrain_dataset.training_(False)
@@ -200,8 +286,6 @@ def create_train_val_dataloader():
subtrain_dl = _create_dataloader(subtrain_dataset, training=False)
_logger.info(str(train_dataset.symmap))