updated framework to use the saner framework that mrq/vall-e uses these days

This commit is contained in:
mrq 2024-09-04 15:48:29 -05:00
parent 5cb28a210e
commit 5610bb3bb3
28 changed files with 3438 additions and 736 deletions

0
.gitignore vendored Executable file → Normal file
View File

0
LICENSE Executable file → Normal file
View File

40
README.md Executable file → Normal file
View File

@ -1,10 +1,10 @@
# Tentative Title For A ResNet-Based Image Classifier
This is a simple ResNet based image classifier for """specific images""", using a similar training framework I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/).
This is a simple ResNet based image classifier for images, using a similar training framework I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/).
## Premise
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I faced.
This is by no ways state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
@ -16,44 +16,14 @@ This is by no ways state of the art, as it just leverages an existing ResNet arc
3. Install using `pip3 install -e ./image_classifier/`.
4. Train using `python3 -m image_classifier.train yaml='./data/config.yaml'`.
4. Train using `python3 -m image_classifier.train --yaml='./data/config.yaml'`.
5. Wait.
## Inferencing
Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" --yaml="./data/config.yaml"`
### Continuous Usage
If you're looking to continuously classifier trained images, use `python3 -m image_classifier --listen --port=7860 yaml="./data/config.yaml" --temp=1.0` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
## Known Issues
* Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
* Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
* Web server doesn't emit `content-type: application/json`, nor accepts JSON `POST`s at the moment.
## Strawmen
>\> UGH... Why *another* training framework!!! Just subjugate [DLAS](https://git.ecker.tech/mrq/DL-Art-School) even more!!!
I want my own code to own. The original VALL-E implementation had a rather nice and clean setup that *mostly* just made sense. DLAS was a nightmare to comb through for the gorillion amounts of models it attests.
>\> OK. But how do I use it for `[thing that isn't the specific usecase only I know/care about]`
Simply provide your own symmapping under `./image_classifier/data.py`, and, be sure to set the delimiter (where exactly is an exercise left to the reader).
Because this is for a ***very specific*** use-case. I don't really care right now to make this a *little* more generalized, despite most of the bits and bobs for it to generalize being there.
>\> ur `[a slur]` for using a ResNet... why not use `[CRNN / some other meme arch]`??
I don't care, I'd rather keep the copypasting from other people's code to a minimum. Lazily adapting my phoneme tokenizer from my VALL-E implementation into something practically fixed length by introducing start/stop tokens should be grounds for me to use a CRNN, or anything recurrent at the very least, but again, I don't care, it just works for my use case at the moment.
>\> UGH!!! What are you talking about """specific images"""???
[ひみつ](https://files.catbox.moe/csuh49.webm)
>\> NOOOO!!!! WHY AREN'T YOU USING `[cuck license]`???
:)
If you're looking to continuously classify images, use `python3 -m image_classifier --listen --port=7860 --yaml="./data/config.yaml"` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.

121
data/config.yaml Executable file → Normal file
View File

@ -1,85 +1,84 @@
dataset:
training: [
"./data/images/"
]
validation: []
use_hdf5: False
workers: 0
cache: True
weights_format: sft
models:
_models:
- name: "classifier"
tokens: 0
len: 6
dim: 512
resnet: 34
#loras:
#- name : "lora"
# rank: 128
# alpha: 128
# training: True
# rvq_levels: []
hyperparameters:
batch_size: 256
gradient_accumulation_steps: 64
gradient_clipping: 100
gradient_accumulation_steps: 1
gradient_clipping: 1.0
warmup_steps: 10
optimizer: Adamw
learning_rate: 1.0e-3
optimizer: Prodigy
learning_rate: 1.0
torch_optimizer: True
scheduler_type: ""
#scheduler_type: OneCycle
#scheduler_params:
# cycle_first_step_size: 10_000
# cycle_first_stair_count: 10_000
# cycle_second_step_size: 15_000
# cycle_second_stair_count: 15_000
# decay_step_size: 5_000
# cycle_min_lr: 2.5e-4 # 1.0e-5
# cycle_max_lr: 2.5e-4 # 1.0e-4
# decay_lr_rate: 0.0
# cycle_min_mom: 0.90
# cycle_max_mom: 0.99
# decay_mom_rate: 0.0
scheduler: "" # ScheduleFree
torch_scheduler: True
evaluation:
batch_size: 32
frequency: 250
size: 32
batch_size: 64
frequency: 100
size: 64
steps: 300
temperature: 1.0
steps: 450
temperature: 0.0
trainer:
iterations: 100_000
save_tag: step
save_on_oom: True
save_on_quit: True
iterations: 1_000_000
save_frequency: 100
aggressive_optimizations: False
keep_last_checkpoints: 32
check_for_oom: False
gradient_checkpointing: True
#load_tag: "9500"
#load_state_dict: True
#load_states: False
#strict_loading: False
#restart_step_count: True
weight_dtype: bfloat16
amp: True
gc_mode: None # "global_step"
weight_dtype: float32
backend: local
backend: deepspeed
deepspeed:
zero_optimization_level: 0
use_compression_training: True
inferencing: False
amp: False
inference:
use_vocos: True
backend: local
bitsandbytes:
enabled: false
weight_dtype: bfloat16
amp: True
optimizations:
injects: False
replace: True
linear: False
embedding: False
optimizers: True
bitsandbytes: False
dadaptation: False
bitnet: False
fp8: False
dataset:
use_hdf5: True
hdf5_flag: r
workers: 1
cache: True
training: [
"./data/images/"
]
validation: [
"./data/validation/"
]

View File

@ -12,35 +12,55 @@ def main():
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--listen", action='store_true')
parser.add_argument("--port", type=int, default=9090)
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--ckpt", type=Path, default=None)
parser.add_argument("--temp", type=float, default=1.0)
parser.add_argument("--device", default="cuda")
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default=None)
parser.add_argument("--temp", type=float, default=0.0)
args, unknown = parser.parse_known_args()
classifier = Classifier( config=args.yaml, ckpt=args.ckpt, device=args.device )
classifier = Classifier( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
if args.listen:
@route("/")
def inference( b64, temperature=1.0 ):
def inference( b64, temperature=args.temp ):
image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
return { "answer": classifier.inference( image=image, temperature=args.temp ) }
return { "answer": classifier.inference( image=image, temperature=temperature ) }
server.start(port=args.port)
else:
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--path", type=Path)
parser.add_argument("--base64", type=str)
parser.add_argument("--write", type=Path)
parser.add_argument("--temp", type=float, default=1.0)
args, unknown = parser.parse_known_args()
images = []
if args.path:
if args.path.is_dir():
for p in args.path.rglob("./*.jpg"):
image = Image.open(p).convert('RGB')
images.append(image)
for p in args.path.rglob("./*.png"):
image = Image.open(p).convert('RGB')
images.append(image)
else:
image = Image.open(args.path).convert('RGB')
images.append(image)
elif args.base64:
image = Image.open(BytesIO(base64.b64decode(args.base64))).convert("RGB")
images.append(image)
else:
raise "Specify a --path or --base64."
for image in images:
answer = classifier.inference( image=image, temperature=args.temp )
print("Answer:", answer)
if args.write:
args.write.mkdir(exist_ok=True)
image.save( args.write / f"{answer}.jpg")
if __name__ == "__main__":
main()

View File

@ -6,31 +6,61 @@ import os
import subprocess
import sys
import time
from dataclasses import asdict, dataclass
from dataclasses import dataclass, field
from functools import cached_property, cache
from pathlib import Path
from omegaconf import OmegaConf
import argparse
import yaml
import random
import logging
import torch
import numpy as np
from dataclasses import asdict, dataclass, field
from functools import cached_property
from pathlib import Path
from .utils.distributed import world_size
def set_seed(seed=None):
if not seed:
seed = time.time()
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
@dataclass()
class _Config:
cfg_path: str | None = None
class BaseConfig:
yaml_path: str | None = None # path passed in through --yaml
@property
def relpath(self):
def cfg_path(self):
return Path(self.yaml_path.parent) if self.yaml_path is not None else None
@property
def rel_path(self):
return Path(self.cfg_path)
@property
def cache_dir(self):
return self.rel_path / ".cache"
@property
def data_dir(self):
return self.rel_path / "data"
@property
def metadata_dir(self):
return self.rel_path / "metadata"
@property
def ckpt_dir(self):
return self.relpath / "ckpt"
return self.rel_path / "ckpt"
@property
def log_dir(self):
return self.relpath / "logs" / str(self.start_time)
return self.rel_path / "logs" / str(self.start_time)
@cached_property
def start_time(self):
@ -64,39 +94,28 @@ class _Config:
with open(path, "w") as f:
f.write(self.dumps())
@staticmethod
def _is_cfg_argv(s):
return "=" in s and "--" not in s
@classmethod
def from_yaml( cls, yaml_path ):
return cls.from_cli( [f'yaml="{yaml_path}"'] )
state = {}
state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
state.setdefault("yaml_path", yaml_path)
return cls(**state)
@classmethod
def from_cli(cls, args=sys.argv):
cli_cfg = OmegaConf.from_cli([s for s in args if cls._is_cfg_argv(s)])
# legacy support for yaml=`` format
for i, arg in enumerate(args):
if arg.startswith("yaml"):
args[i] = f'--{arg}'
# Replace argv to ensure there are no omegaconf options, for compatibility with argparse.
sys.argv = [s for s in sys.argv if not cls._is_cfg_argv(s)]
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
args, unknown = parser.parse_known_args(args=args)
if cli_cfg.get("help"):
print(f"Configurable hyperparameters with their default values:")
print(json.dumps(asdict(cls()), indent=2, default=str))
exit()
if args.yaml:
return cls.from_yaml( args.yaml )
if "yaml" in cli_cfg:
yaml_cfg = OmegaConf.load(cli_cfg.yaml)
yaml_path = Path(cli_cfg.yaml).absolute()
cfg_path = Path(*yaml_path.relative_to(Path.cwd()).parts[:-1])
cfg_path = cfg_path.with_suffix("")
cfg_path = f'./{cfg_path}'
yaml_cfg.setdefault("cfg_path", cfg_path)
cli_cfg.pop("yaml")
else:
yaml_cfg = {}
merged = OmegaConf.merge(yaml_cfg, cli_cfg)
return cls(**dict(merged))
return cls(**{})
def __repr__(self):
return str(self)
@ -106,104 +125,195 @@ class _Config:
@dataclass()
class Dataset:
training: list[Path] = field(default_factory=lambda: [])
validation: list[Path] = field(default_factory=lambda: [])
training: list[Path] = field(default_factory=lambda: []) # paths to load into the training dataset
validation: list[Path] = field(default_factory=lambda: []) # paths to load into the validation dataset
temp: list[Path] = field(default_factory=lambda: [])
hdf5_name: str = "data.h5" # file name to load the HDF5 dataset
use_hdf5: bool = False # whether to load from an HDF5 dataset
hdf5_flag: str = "a" # flag to load the HDF5 file, automatically adjusted anyways
# de-implemented, because the data isn't that large to facilitate HDF5
hdf5_name: str = "data.h5"
use_hdf5: bool = False
workers: int = 8
cache: bool = True
validate: bool = True # validate each utterance on wheter it can be included based on duration range caps
workers: int = 8 # number of dataloader workers to spawn
cache: bool = True # use diskcache to cache the dataset
# I really need to clean this up
@dataclass()
class Model:
name: str = ""
name: str = "classifier"
tokens: int = 0 # number of token types
len: int = 1 # how long a sequence can be
dim: int = 512
resnet: int = 18
width: int = 300
height: int = 80
version: int = 1
training: bool = True
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
@property
def full_name(self):
return self.name
@dataclass()
class Models:
_models: list[Model] = field(default_factory=lambda: [
Model(name="captcha"),
])
def get(self, name=None):
if not name:
return [ Model(**model) for model in self._models ]
return [ self ] if not name or self.name == name else []
for model in self._models:
if model.name == name:
return model
def loss_factor(self, k):
return self.loss_factors[k] if k in self.loss_factors else 1.0
raise ValueError
@property
# required for fp8 as the lengths needs to be divisible by 8
def input_alignment(self):
return 8 if cfg.optimizations.fp8 else 0
@property
def activation_checkpointing(self):
return cfg.trainer.activation_checkpointing
@property
def gradient_checkpointing(self):
return cfg.trainer.gradient_checkpointing
@property
def lora_policy(self):
include = ["model"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
exclude = []
if self.arch_type == "llama":
include = ["self_attn", "mlp"] # target only the attention + mlp
exclude = ["self_attn.k_proj"] # common literature says to ignore it
if self.arch_type == "retnet":
include = ["layers."] # target the core layers of the RetNet and ignore the auxiliary stuff
exclude = ["retention.k_proj"] # attention-based transformers ignore the K, so might as well ignore it for the retnet
return dict(include=include, exclude=exclude)
# should be renamed to Adapters
@dataclass()
class LoRA:
name: str = "lora" # vanity name
# to-do: find sane default values
rank: int = 128 # rank for the LoRA
alpha: int = 128 # rank for the LoRA
training: bool = True #
embeddings: bool = False # train the embedding too
parametrize: bool = False # whether to use the parameterized pathway for LoRAs or not
rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
@property
def full_name(self):
name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
return "-".join(name)
# actually not needed anymore
def active_level( self, level ):
if not self.rvq_levels:
return True
return level in self.rvq_levels
@dataclass()
class Hyperparameters:
batch_size: int = 8
gradient_accumulation_steps: int = 32
gradient_clipping: int = 100 # to be implemented in the local backend
batch_size: int = 8 # number of samples per training batch
gradient_accumulation_steps: int = 32 # number of steps to accumulate gradients before updating
gradient_clipping: int | float = 10 # largest size a gradient norm can be
optimizer: str = "Adamw"
learning_rate: float = 3.25e-4
optimizer: str = "Adamw" # optimizer to use, should be 'Prodigyopt" now
optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
scheduler_type: str = "" # to be implemented in the local backend
scheduler_params: dict = field(default_factory=lambda: {})
learning_rate: float = 3.25e-4 # should be 1.0 for ProdigyOpt
warmup_steps: int = 0 # number of steps to warm up the optimizer before performing updates, I think, this is just passed to deepspeed
scheduler: str = "" # scheduler to use, currently don't ever use one so this doesn't really matter
scheduler_type: str = "" # deprecated
scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
autotune: bool = False # to do deepspeed's autotuning
autotune_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed supplied
torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied
@dataclass()
class Evaluation:
batch_size: int = 64
frequency: int = 250
size: int = 64
batch_size: int = 64 # number of samples per batch during eval / val
frequency: int = 250 # do eval / val every X iterations
size: int = 64 # number of samples to generate during eval / val
steps: int = 500
temperature: float = 1.0
temperature: float = 1.0 # AR temp for inferencing
load_disabled_engines: bool = True # see the other load_disabled_engines
@dataclass()
class DeepSpeed:
zero_optimization_level: int = 0
use_compression_training: bool = False
zero_optimization_level: int = 0 # doesn't seem to work
use_compression_training: bool = False # cope
compression_bits: int = 8 # cope
inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
def get_ds_cfg(self, model):
weights = [ name[0] for name in model.named_parameters() ]
bits = 8
amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)
scheduler_params = {}
for k in cfg.hyperparameters.scheduler_params:
scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
@cached_property
def ds_cfg(self):
optimizer_params = cfg.hyperparameters.optimizer_params
if 'lr' not in optimizer_params:
optimizer_params["lr"] = cfg.hyperparameters.learning_rate,
scheduler_params = cfg.hyperparameters.scheduler_params
if 'warmup_num_steps' not in scheduler_params:
scheduler_params['warmup_num_steps'] = cfg.hyperparameters.warmup_steps
if 'total_num_steps' not in scheduler_params:
scheduler_params['total_num_steps'] = cfg.trainer.iterations
autotune_params = cfg.hyperparameters.autotune_params
if "enabled" not in autotune_params:
autotune_params['enabled'] = True
if "results_dir" not in autotune_params:
autotune_params['results_dir'] = str( cfg.rel_path / "autotune" / "results" )
if "exps_dir" not in autotune_params:
autotune_params['exps_dir'] = str( cfg.rel_path / "autotune" / "exps_" )
# DeepSpeed fp16 is incompatible with its AMP
if cfg.trainer.weight_dtype.lower() == "float16":
self.amp = False
# disable local AMP
if self.amp:
cfg.trainer.amp = False
ds_cfg = {
"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
"optimizer": {
"type": cfg.hyperparameters.optimizer,
"params": {
"lr": cfg.hyperparameters.learning_rate,
}
},
"params": optimizer_params,
} if not cfg.hyperparameters.torch_optimizer else None,
"scheduler": {
"type": cfg.hyperparameters.scheduler_type,
"type": cfg.hyperparameters.scheduler,
"params": scheduler_params,
} if cfg.hyperparameters.scheduler_type != "" else None,
} if not cfg.hyperparameters.torch_scheduler else None,
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
"fp16": {
"enabled": True,
"auto_cast": True,
} if cfg.trainer.weight_dtype.lower() == "float16" else None,
"bf16": {
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16"
"enabled": cfg.trainer.weight_dtype.lower() == "float16",
"auto_cast": True, # ???
"loss_scale": 0.0 if cfg.trainer.scale_loss else 1.0,
},
"bf16": {
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
},
"amp": {
"enabled": self.amp,
},
"autotuning": autotune_params if cfg.hyperparameters.autotune else None,
"compression_training": {
"weight_quantization": {
"shared_parameters":{
@ -214,7 +324,7 @@ class DeepSpeed:
"quantize_verbose": True,
"quantization_type": "symmetric",
"rounding": "nearest",
"quantize_weight_in_forward": True,
"quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
"fp16_mixed_quantize":{
"enabled": False,
"quantize_change_ratio": 1
@ -223,30 +333,38 @@ class DeepSpeed:
"different_groups": {
"wq1": {
"params": {
"start_bits": bits,
"target_bits": bits,
"start_bits": self.compression_bits,
"target_bits": self.compression_bits,
"quantization_period": 0
},
"modules": weights
"modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
}
}
},
"activation_quantization": {
"shared_parameters":{
"enabled": True,
"quantizer_kernel": True,
"schedule_offset": 0,
"quantize_groups": 64,
"quantize_verbose": True,
"quantization_type": "symmetric",
"range_calibration": "dynamic",
"schedule_offset": 0
"rounding": "nearest",
"quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
"fp16_mixed_quantize":{
"enabled": False,
"quantize_change_ratio": 1
}
},
"different_groups": {
"aq1": {
"params": {
"bits": bits
"bits": self.compression_bits,
},
"modules": weights
}
"modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
}
}
},
} if self.use_compression_training else None,
"zero_optimization": {
"stage": self.zero_optimization_level,
@ -264,7 +382,10 @@ class DeepSpeed:
"offload_param": {
"device": "cpu",
"pin_memory": True
}
},
"zero_quantized_weights": self.use_compression_training,
"zero_hpz_partition_size": world_size(),
"zero_quantized_gradients": self.use_compression_training,
} if self.zero_optimization_level > 0 else None,
"comms_logger": {
"enabled": False
@ -275,113 +396,314 @@ class DeepSpeed:
for k in null_keys:
del ds_cfg[k]
if os.path.exists("./config/ds_config.json"):
ds_cfg.update(json.load(open("./config/ds_config.json", "r", encoding="utf-8")))
if os.path.exists("./data/ds_config.json"):
ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
else:
ds_cfg.update(self.config)
return ds_cfg
@dataclass()
class Trainer:
iterations: int = 100_000
iterations: int = 1_000_000 # maximum iterations to train
save_tag: str = "step"
load_tag: str | None = None
save_tag: str = "step" # name to save checkpoints under, "step" will save as current step count
load_tag: str | None = None # tag to load checkpoint from; if None: will check against contents of `./ckpt/{model-name}/latest` for the checkpoint name
save_on_oom: bool = True
save_on_quit: bool = True
save_frequency: int = 100
save_on_oom: bool = True # save if an OOM error is raised
save_on_quit: bool = True # save when quitting training
load_state_dict: bool = False
load_states: bool = True
strict_loading: bool = True
restart_step_count: bool = False
export_on_save: bool = False # export weights to local `fp32.pth` state_dict on saving a checkpoint
export_on_quit: bool = False # export weights to local `fp32.pth` state_dict on quitting training
aggressive_optimizations: bool = False
check_for_oom: bool = True
save_frequency: int = 100 # frequency to save every X iterations
gc_mode: str | None = None
keep_last_checkpoints: int = 0 # number of checkpoints to keep, prunes oldest ones
weight_dtype: str = "float16"
load_state_dict: bool = False # loads `fp32.pth` state_dict, will automatically be done if a checkpoint is not found but `fp32.pth` exists
load_states: bool = True #
strict_loading: bool = False # sets strict_loading=True when loading the state dict
load_module_only: bool = False #
restart_step_count: bool = False # clears the training stats when loading a checkpoint
resize_modules: bool = False # automatically resizes
backend: str = "deepspeed"
activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing
gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
aggressive_optimizations: bool = False # deprecated
check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
gc_mode: str | None = None # deprecated, but marks when to do GC
load_disabled_engines: bool = False # deprecated, but signals to load engines not used for training for, for example, evaluation/validation
weight_dtype: str = "float16" # dtype to have the model under
amp: bool = False # automatic mixed precision
ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
#scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload)
load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when logger broke because of BitNet
backend: str = "local" # training backend to use. currently supports "local" | "deepspeed"
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings
@cached_property
def dtype(self):
if self.weight_dtype == "float16":
return torch.float16
if cfg.trainer.weight_dtype == "bfloat16":
if self.weight_dtype == "bfloat16":
return torch.bfloat16
if self.weight_dtype == "float8_e5m2":
return torch.float8_e5m2
if self.weight_dtype == "float8_e4m3fn":
return torch.float8_e4m3fn
return torch.float32
@cached_property
def scale_loss(self):
# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
return self.dtype == torch.float16
@dataclass()
class Inference:
use_vocos: bool = True # artifact from the VALL-E trainer
backend: str = "local" # backend to use when inferencing
weight_dtype: str = "float32" # dtype to load the model under
amp: bool = False # automatic mixed precision during inferencing
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
@property
def dtype(self):
if self.weight_dtype == "float16":
return torch.float16
if self.weight_dtype == "bfloat16":
return torch.bfloat16
if self.weight_dtype == "int8":
return torch.int8
if self.weight_dtype == "float8_e5m2":
return torch.float8_e5m2
if self.weight_dtype == "float8_e4m3fn":
return torch.float8_e4m3fn
return torch.float32
@dataclass()
class BitsAndBytes:
enabled: bool = False
injects: bool = False
class Optimizations:
injects: bool = False # overwrites default torch classes (not recommended)
replace: bool = False # replaces modules in place with the optimized version (recommended)
compile: bool | str = False # runs torch.compile on the model
linear: bool = False
embedding: bool = False
linear: bool = True # inject/replace linear for BnB
embedding: bool = True # inject/replace embedding for BnB
optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
bitsandbytes: bool = False # use bitsandbytes
dadaptation: bool = False # use dadaptation optimizer
bitnet: bool = False # use bitnet
fp8: bool = False # use fp8
model_offloading: dict | None = None # automatically splits the model over a list of devices
# example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} will have the GPU capped to 6GiB, and offload the remaining layers to CPU
# example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
# | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2
tensorrt: bool = False
@dataclass()
class Config(_Config):
device: str = "cuda"
class Config(BaseConfig):
device: str = "cuda" # target device
mode: str = "training" # "inferencing"
experimental: bool = False # Debug flag, unused now
dataset: Dataset = field(default_factory=lambda: Dataset)
models: Models = field(default_factory=lambda: Models)
models: dict | list | None = field(default_factory=lambda: [])
loras: dict | list | None = field(default_factory=lambda: [])
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
trainer: Trainer = field(default_factory=lambda: Trainer)
inference: Inference = field(default_factory=lambda: Inference)
bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
bitsandbytes: dict | list | None = None # deprecated
optimizations: Optimizations = field(default_factory=lambda: Optimizations)
def get_device(self):
return torch.cuda.current_device() if self.device == "cuda" else self.device
tokenizer: str | None = None # tokenizer class
tokenizer_path: str = "./tokenizer.json" # tokenizer path
weights_format: str = "pth" # "pth" | "sft"
supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
@property
def cache_dir(self):
return ".cache" / self.relpath
def model(self):
for i, model in enumerate(self.models):
if model.training:
return model
return self.models[0] if len(self.models) > 0 else None
# should be renamed to adapters
@property
def lora(self):
for i, lora in enumerate(self.loras):
if lora.training:
return lora
return self.loras[0] if len(self.loras) > 0 else None
@property
def distributed(self):
return world_size() > 1
@cached_property
def diskcache(self):
if self.dataset.cache:
if self.yaml_path is not None and self.dataset.cache:
return diskcache.Cache(self.cache_dir).memoize
return lambda: lambda x: x
# I don't remember why this is needed
def load_yaml( self, config_path ):
tmp = Config.from_yaml( config_path )
self.__dict__.update(tmp.__dict__)
def load_hdf5( self, write=False ):
if hasattr(self, 'hdf5'):
self.hdf5.close()
if self.distributed:
self.dataset.hdf5_flag = "r"
try:
self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
except Exception as e:
_logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}")
self.dataset.use_hdf5 = False
# to-do: prune unused keys
def format( self, training=True ):
if isinstance(self.dataset, type):
self.dataset = dict()
if isinstance(self.models, type):
self.models = dict()
if isinstance(self.loras, type):
self.loras = dict()
if isinstance(self.hyperparameters, type):
self.hyperparameters = dict()
if isinstance(self.evaluation, type):
self.evaluation = dict()
if isinstance(self.trainer, type):
self.trainer = dict()
if isinstance(self.inference, type):
self.inference = dict()
if isinstance(self.optimizations, type):
self.optimizations = dict()
self.dataset = Dataset(**self.dataset)
self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
self.models = [ Model(**model) for model in self.models ]
self.loras = [ LoRA(**lora) for lora in self.loras ]
if not self.models:
self.models = [ Model() ]
self.hyperparameters = Hyperparameters(**self.hyperparameters)
self.evaluation = Evaluation(**self.evaluation)
self.trainer = Trainer(**self.trainer)
if not isinstance(self.trainer.deepspeed, type):
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
self.inference = Inference(**self.inference)
if self.bitsandbytes is not None:
self.optimizations = Optimizations(**self.bitsandbytes)
else:
self.optimizations = Optimizations(**self.optimizations)
if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
self.hyperparameters.scheduler_type = ""
# do not combine the two
if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation:
self.hyperparameters.scheduler = ""
if self.hyperparameters.scheduler == "":
self.hyperparameters.torch_scheduler = True
if self.trainer.backend == "local" and self.distributed:
self.trainer.ddp = True
if self.trainer.activation_checkpointing is not None:
self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing
if not training:
self.dataset.use_hdf5 = False
# load our HDF5 file if requested here
if self.dataset.use_hdf5:
self.load_hdf5()
# load tokenizer
if cfg.tokenizer == "naive":
cfg.tokenizer = NaiveTokenizer()
else:
try:
from transformers import PreTrainedTokenizerFast
tokenizer_path = cfg.rel_path / cfg.tokenizer_path if cfg.yaml_path is not None else None
if tokenizer_path and not tokenizer_path.exists():
tokenizer_path = Path("./data/") / cfg.tokenizer_path
if tokenizer_path and tokenizer_path.exists():
cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
else:
cfg.tokenizer = NaiveTokenizer()
except Exception as e:
cfg.tokenizer = NaiveTokenizer()
_logger.warning(f"Error while parsing tokenizer: {str(e)}")
pass
# Preserves the old behavior
class NaiveTokenizer:
def get_vocab( self ):
"""
if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
return json.loads( cfg.hdf5['symmap'].asstr()[()] )
"""
return { " ": 0, "<s>": 1, "</s>": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 }
@cached_property
def _bos_token( self ):
return self.get_vocab()["<s>"]
@cached_property
def _eos_token( self ):
return self.get_vocab()["</s>"]
def encode( self, s ):
symmap = self.get_vocab()
s = s.replace("O", "0")
s = [f"<s>"] + [ p if p in symmap else " " for p in s ] + [f"</s>"]
return [*map(symmap.get, s)]
_logger = logging.getLogger(__name__)
cfg = Config.from_cli()
# OmegaConf doesn't actually coerce the dicts into the @dataclass decorated classes, for some god forsaken reason, so we coerce them ourselves
cfg.dataset = Dataset(**cfg.dataset)
cfg.models = Models(**cfg.models)
cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
cfg.evaluation = Evaluation(**cfg.evaluation)
cfg.trainer = Trainer(**cfg.trainer)
cfg.inference = Inference(**cfg.inference)
cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
# cached_property stopped working...
if cfg.dataset.use_hdf5:
# some safety for remapping deprecated formats and re-coercing uninitialized properties into actual types
try:
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
cfg.format()
except Exception as e:
print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
cfg.dataset.use_hdf5 = False
if not cfg.dataset.use_hdf5:
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
_logger.error(f"Error while parsing config YAML: {str(e)}")
raise e # throw an error because I'm tired of silent errors messing things up for me
if __name__ == "__main__":
print(cfg)

View File

@ -1,16 +1,19 @@
# todo: clean this mess up
import copy
# import h5py
import h5py
import json
import logging
#import numpy as np
import numpy as np
import os
import random
import torch
import math
import itertools
from .config import cfg
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
from .utils.distributed import global_rank, local_rank, world_size
from .utils.io import torch_save, torch_load
from collections import defaultdict
from functools import cache, cached_property
@ -20,23 +23,57 @@ from typing import Any
from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.utils.rnn import pad_sequence
from PIL import Image, ImageDraw
import torchvision.transforms as transforms
from tqdm.auto import tqdm
from PIL import Image
# torch.multiprocessing.set_sharing_strategy("file_system")
_logger = logging.getLogger(__name__)
@cache
# to-do: clean up this symmap mess
def get_symmap():
return { " ": 0, "<s>": 1, "</s>": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 }
return cfg.tokenizer.get_vocab()
@cache
def _get_symbols( content ):
content = content.replace("O", "0")
return [f"<s>"] + [ p for p in content ] + [f"</s>"]
def tokenize( s ):
if isinstance( s, list ):
s = "".join( s )
return cfg.tokenizer.encode( s )
"""
def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
def _get_hdf5_path(path):
# to-do: better validation
return str(path)
def _get_hdf5_paths( data_dir, type="training", validate=False ):
data_dir = str(data_dir)
key = f"/{type}/{_get_hdf5_path(data_dir)}"
return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items()] if key in cfg.hdf5 else []
def _get_paths_of_extensions( path, validate=False ):
if isinstance(path, str):
path = Path(path)
return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
def _interleaved_reorder(l, fn):
groups = defaultdict(list)
for e in l:
groups[fn(e)].append(e)
groups = {k: groups[k] for k in sorted(groups)}
for interleaved in zip_longest(*groups.values()):
for value in interleaved:
if value is not None:
yield value
"""
class Dataset(_Dataset):
def __init__(
@ -44,43 +81,90 @@ class Dataset(_Dataset):
paths,
width=300,
height=80,
stacks=0,
symmap=get_symmap(),
training=False,
):
super().__init__()
self._head = None
self.paths = paths
self.sampler = None
self.width = width
self.height = height
self.stacks = stacks
self.paths = paths
self.image_dtype = cfg.trainer.dtype
self.symmap = symmap
self.training = training
self.dataset_type = "training" if self.training else "validation"
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
self.transform = transforms.Compose([
#transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
@cached_property
def symbols(self):
return sorted(set().union(*[_get_symbols(path.stem) for path in self.paths]))
# to-do: do not do validation if there's nothing in the validation
# this just makes it be happy
if len(self.dataset) == 0:
self.dataset = cfg.dataset.training
# split dataset accordingly per GPU
if cfg.distributed and self.training:
self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]
if len(self.paths) == 0:
raise ValueError(f"No valid path is found for {self.dataset_type}")
@cached_property
def sampler_state_dict_path(self):
return cfg.rel_path / f"sampler.rank{global_rank()}.pt"
def save_state_dict(self, path = None):
"""
if path is None:
path = self.sampler_state_dict_path
if self.sampler is not None:
state_dict = self.sampler.get_state()
elif self.samplers is not None:
state_dict = {
"samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
}
torch_save(state_dict, path)
"""
return
def load_state_dict(self, path = None):
"""
if path is None:
path = self.sampler_state_dict_path
if not path.exists():
return
state_dict = torch_load(path)
if self.sampler is not None:
state_dict = self.sampler.set_state(state_dict)
else:
for name, sampler in state_dict["samplers"].items():
if name not in self.samplers:
continue
self.samplers[name].set_state( sampler )
"""
return
def __getitem__(self, index):
path = self.paths[index]
tokens = tokenize( path.stem.upper() )
text = torch.tensor( tokens ).to(dtype=torch.uint8)
# stupid try/except when the original VALL-E training framework was able to insert foreign symbols into the symmap, but that functionality isn't really necessary here
try:
text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
except Exception as e:
print("Invalid symbol:", _get_symbols(path.stem), [*map(self.symmap.get, _get_symbols(path.stem))], path.stem)
raise e
image = Image.open(path).convert('RGB')
width, height = image.size
image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype) # resnet has to be RGB
image = self.transform(image).to(dtype=self.image_dtype) # resnet has to be RGB
return dict(
index=index,
@ -98,11 +182,6 @@ class Dataset(_Dataset):
def __len__(self):
return min(len(self.paths), self._head or len(self.paths))
def pin_memory(self):
self.text = self.text.pin_memory()
self.image = self.image.pin_memory()
return self
def collate_fn(samples: list[dict]):
batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
@ -111,21 +190,28 @@ def collate_fn(samples: list[dict]):
def _seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
#np.random.seed(worker_seed)
np.random.seed(worker_seed)
random.seed(worker_seed)
def _create_dataloader(dataset, training):
kwargs = dict(
shuffle=True,
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
drop_last=training,
sampler=dataset.sampler if training else None,
) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
batch_sampler=dataset.sampler,
)
return DataLoader(
dataset=dataset,
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
shuffle=True, # training
drop_last=training,
num_workers=cfg.dataset.workers,
collate_fn=collate_fn,
persistent_workers=cfg.dataset.workers > 0,
pin_memory=False, # True,
persistent_workers=cfg.dataset.workers > 1,
pin_memory=False,
worker_init_fn=_seed_worker,
**kwargs,
)
def _load_train_val_paths( val_ratio=0.1 ):
@ -133,8 +219,8 @@ def _load_train_val_paths( val_ratio=0.1 ):
train_paths = []
val_paths = []
print(cfg.dataset.training)
for data_dir in cfg.dataset.training:
paths.extend(data_dir.rglob("*.jpg"))
paths.extend(data_dir.rglob("*.png"))
if len(paths) > 0:
@ -146,12 +232,13 @@ def _load_train_val_paths( val_ratio=0.1 ):
val_len = math.floor(len(train_paths) * val_ratio)
train_len = math.floor(len(train_paths) * (1 - val_ratio))
print(val_len, train_len)
val_paths = train_paths[:-val_len]
train_paths = train_paths[:train_len]
else:
paths = []
for data_dir in cfg.dataset.validation:
paths.extend(data_dir.rglob("*.jpg"))
paths.extend(data_dir.rglob("*.png"))
if len(paths) > 0:
@ -169,7 +256,6 @@ def _load_train_val_paths( val_ratio=0.1 ):
return train_paths, val_paths
@cfg.diskcache()
def create_datasets():
train_paths, val_paths = _load_train_val_paths()
@ -187,10 +273,10 @@ def create_datasets():
return train_dataset, val_dataset
def create_train_val_dataloader():
train_dataset, val_dataset = create_datasets()
# deepcopy is slow
subtrain_dataset = copy.deepcopy(train_dataset)
subtrain_dataset.head_(cfg.evaluation.size)
subtrain_dataset.training_(False)
@ -200,8 +286,6 @@ def create_train_val_dataloader():
subtrain_dl = _create_dataloader(subtrain_dataset, training=False)
_logger.info(str(train_dataset.symmap))
_logger.info(f"#samples (train): {len(train_dataset)}.")
_logger.info(f"#samples (val): {len(val_dataset)}.")
_logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")
@ -210,11 +294,305 @@ def create_train_val_dataloader():
return train_dl, subtrain_dl, val_dl
# parse dataset into better to sample metadata
"""
def create_dataset_metadata( skip_existing=True ):
symmap = get_symmap()
root = str(cfg.data_dir)
metadata_root = str(cfg.metadata_dir)
cfg.metadata_dir.mkdir(parents=True, exist_ok=True)
def add( dir, type="training", audios=True, texts=True ):
name = str(dir)
name = name.replace(root, "")
speaker_name = name
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
try:
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
except Exception as e:
metadata = {}
if not os.path.isdir(f'{root}/{name}/'):
return
# tqdm.write(f'{root}/{name}')
files = os.listdir(f'{root}/{name}/')
# grab IDs for every file
ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
wrote = False
for id in tqdm(ids, desc=f"Processing {name}"):
try:
quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}')
if audios and not quant_path.exists():
continue
key = f'{type}/{speaker_name}/{id}'
if skip_existing and id in metadata:
continue
wrote = True
if id not in metadata:
metadata[id] = {}
utterance_metadata = {}
if audios:
# ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
dac = np.load(quant_path, allow_pickle=True)[()]
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
if "text" in dac["metadata"]:
utterance_metadata["text"] = dac["metadata"]["text"]
if "phonemes" in dac["metadata"]:
utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
if "language" in dac["metadata"]:
utterance_metadata["language"] = dac["metadata"]["language"]
if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
for k, v in utterance_metadata.items():
metadata[id][k] = v
except Exception as e:
tqdm.write(f'Error while processing {id}: {e}')
if wrote:
with open(str(metadata_path), "w", encoding="utf-8") as f:
f.write( json.dumps( metadata ) )
# training
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
add( data_dir, type="training" )
# validation
for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
add( data_dir, type="validation" )
# noise
for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
add( data_dir, type="noise", texts=False )
# parse yaml to create an hdf5 file
def create_dataset_hdf5( skip_existing=True ):
cfg.dataset.use_hdf5 = True
cfg.load_hdf5(write=True)
hf = cfg.hdf5
symmap = get_symmap()
root = str(cfg.data_dir)
metadata_root = str(cfg.metadata_dir)
def add( dir, type="training", audios=True, texts=True ):
name = str(dir)
name = name.replace(root, "")
# yucky
speaker_name = name
if "LibriTTS-R" in speaker_name:
speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
if not os.path.isdir(f'{root}/{name}/'):
return
files = os.listdir(f'{root}/{name}/')
# grab IDs for every file
ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
for id in tqdm(ids, desc=f"Processing {name}"):
try:
quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
if not quant_exists:
continue
key = f'{type}/{speaker_name}/{id}'
if skip_existing and key in hf:
continue
group = hf.create_group(key) if key not in hf else hf[key]
if id not in metadata:
metadata[id] = {}
utterance_metadata = {}
# audio
if audios:
dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
if "text" in dac["metadata"]:
utterance_metadata["text"] = dac["metadata"]["text"]
if "phonemes" in dac["metadata"]:
utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
if "language" in dac["metadata"]:
utterance_metadata["language"] = dac["metadata"]["language"]
if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
if "audio" not in group:
group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')
# text
if texts:
if not utterance_metadata and text_exists:
utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
phn = "".join(utterance_metadata["phonemes"])
phn = cfg.tokenizer.encode(phn)
phn = np.array(phn).astype(np.uint8)
if "text" not in group:
group.create_dataset('text', data=phn, compression='lzf')
for k, v in utterance_metadata.items():
group.attrs[k] = v
metadata[id][k] = v
except Exception as e:
tqdm.write(f'Error while processing {id}: {e}')
with open(str(metadata_path), "w", encoding="utf-8") as f:
f.write( json.dumps( metadata ) )
# training
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
add( data_dir, type="training" )
# validation
for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
add( data_dir, type="validation" )
# noise
for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
add( data_dir, type="noise", texts=False )
# write symmap
if "symmap" in hf:
del hf['symmap']
hf.create_dataset('symmap', data=json.dumps(symmap))
hf.close()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("--action", type=str)
parser.add_argument("--tasks", type=str)
args, unknown = parser.parse_known_args()
task = args.action
cfg.dataset.workers = 1
if args.action == "hdf5":
create_dataset_hdf5()
elif args.action == "list-dataset":
dataset = []
for group in os.listdir(cfg.data_dir):
for name in os.listdir(cfg.data_dir / group):
if len(os.listdir(cfg.data_dir / group / name)) == 0:
continue
dataset.append(f'{group}/{name}')
_logger.info(json.dumps(dataset))
elif args.action == "metadata":
create_dataset_metadata()
elif args.action == "sample":
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
samples = {
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
#"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
}
Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
for k, v in samples.items():
for i in range(len(v)):
for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
try:
decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
except Exception as e:
_logger.info(f"Error while decoding prom {k}.{i}.{j}.wav: {str(e)}")
try:
decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
except Exception as e:
_logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}")
v[i]['proms'][j] = v[i]['proms'][j].shape
v[i]['resps'][j] = v[i]['resps'][j].shape
for k, v in samples.items():
for i in range(len(v)):
_logger.info(f'{k}[{i}]: {v[i]}')
elif args.action == "validate":
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
missing = set()
for i in range(len( train_dl.dataset )):
batch = train_dl.dataset[i]
text = batch['text']
phonemes = batch['metadata']['phonemes']
decoded = [ cfg.tokenizer.decode(token) for token in text[1:-1] ]
for i, token in enumerate(decoded):
if token != "<unk>":
continue
phone = phonemes[i]
_logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )
missing |= set([phone])
_logger.info( f"Missing tokens: {missing}" )
elif args.action == "tasks":
index = 0
cfg.dataset.tasks_list = args.tasks.split(",")
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
sample = train_dl.dataset[0]
print(sample)
batch = next(iter(train_dl))
for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
if task not in cfg.dataset.tasks_list:
continue
_logger.info( f'{text} {task} {cfg.model.resp_levels}')
_logger.info( f'{proms.shape} {resps.shape}' )
tokens = 0
tokens += sum([ text.shape[0] for text in batch["text"] ])
tokens += sum([ resps.shape[0] for resps in batch["resps"] ])
_logger.info( f'{tokens}' )
decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
break
"""

View File

@ -1,6 +1,6 @@
from ..config import cfg
from ..utils.distributed import fix_unset_envs
from ..utils.distributed import fix_unset_envs, ddp_model
fix_unset_envs()
if cfg.trainer.backend == "deepspeed":
@ -8,4 +8,211 @@ if cfg.trainer.backend == "deepspeed":
elif cfg.trainer.backend == "local":
from .base import Engine
from .base import Engines, TrainFeeder, default_feeder
from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
from ..models import get_models, get_model
from ..utils import wrapper as ml
from ..utils.io import torch_save, torch_load, pick_path
from ..models.lora import apply_lora, lora_load_state_dict
import torch
import re
import logging
_logger = logging.getLogger(__name__)
deepspeed_available = False
try:
import deepspeed
deepspeed_available = True
except Exception as e:
pass
from functools import cache
@cache
def load_engines(training=True, **model_kwargs):
models = get_models(cfg.models, training=training, **model_kwargs)
engines = dict()
for name, model in models.items():
state = None
stats = None
lora = None
inferencing = cfg.mode == "inferencing" or not model.config.training or not training
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
loads_state_dict = cfg.trainer.load_state_dict # or inferencing
checkpoint_path = cfg.ckpt_dir / name / "latest"
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
load_path = pick_path( cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
# actually use the lora-specific checkpoint if available
if cfg.lora is not None:
checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest"
# to handle the issue of training with deepspeed, but inferencing with local
if checkpoint_path.exists() and backend == "local":
tag = open(checkpoint_path).read()
checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
_logger.warning(f"Checkpoint missing, but weights found: {load_path}")
loads_state_dict = True
# load state early
if loads_state_dict:
state = torch_load(load_path, device=cfg.device)
# check if config is defined in state, and re-initialize the model
if "config" in state and False:
_logger.warning("Model config definition in weights, re-loading...")
config_state = state["config"]
model = get_model( config=cfg.model.__class__( *config_state ), training=training )
hyper_config = model.config
optimizer = None
lr_scheduler = None
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
ddp = cfg.trainer.ddp
engine_class = LocalEngine if backend == "local" else Engine
# apply model replacers
if cfg.optimizations.replace and cfg.optimizations.linear:
model.model = ml.replace_linear( model.model )
if cfg.optimizations.replace and cfg.optimizations.embedding:
model.model = ml.replace_embedding( model.model )
for lora in cfg.loras:
model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
if inferencing:
model.config.training = False
if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
optimizer_class = None
scheduler_class = None
params = {
"lr": cfg.hyperparameters.learning_rate,
}
if cfg.hyperparameters.optimizer.lower() == "adamw":
params["betas"] = (0.9, 0.96)
params["eps"] = 1e-07
params["weight_decay"] = 0.01
# for dadaptation since it has Adam only
if ml.AdamW == ml.Adam:
params["decouple"] = True
optimizer_class = ml.AdamW
elif cfg.hyperparameters.optimizer.lower() == "sgd":
optimizer = ml.SGD
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
optimizer_class = ml.Prodigy
params['d_coef'] = params['lr']
params['lr'] = 1.0
elif cfg.hyperparameters.optimizer.lower() == "adagrad":
optimizer_class = ml.Adagrad
else:
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
params.update(cfg.hyperparameters.optimizer_params)
optimizer = optimizer_class(
[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
**params,
)
if cfg.hyperparameters.scheduler.lower() == "schedulefree":
if cfg.hyperparameters.optimizer.lower() == "adamw":
scheduler_class = ml.schedulefree.AdamWScheduleFree
elif cfg.hyperparameters.optimizer.lower() == "sgd":
scheduler_class = ml.schedulefree.SGDScheduleFree
else:
raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
optimizer = scheduler_class(
[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
lr = params['lr'],
warmup_steps = cfg.hyperparameters.warmup_steps
)
"""
# set up our LR scheduler here
"""
if inferencing:
optimizer = None
lr_scheduler = None
# load state dict if requested / required
if loads_state_dict:
# state dict is not just the module, extract the extra trainer details
if "stats" in state:
stats = state["stats"]
# do not load stats if we're training a LoRA
if cfg.lora is not None or cfg.trainer.restart_step_count:
stats = None
if "module" in state:
state = state["module"]
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
# load lora weights if exists
if cfg.lora is not None:
lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
if lora_path.exists():
_logger.info( f"Loaded LoRA state dict: {lora_path}" )
state = torch_load(lora_path, device=cfg.device)
state = state['lora' if 'lora' in state else 'module']
lora_load_state_dict( model, state )
# wrap if DDP is requested
if ddp:
model = ddp_model(model)
# wrap optimization class
elif cfg.optimizations.compile:
model = ml.compile_model(model, backend=cfg.optimizations.compile)
# deepspeed inferencing
elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
engine_class = LocalEngine
model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
# use base engine if requested
engines[name] = engine_class(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
hyper_config=hyper_config,
stats=stats
)
engines = Engines(engines)
engines.setup()
# this might bite me in the ass since technically this doesn't handle one engine loading fine but another engine not
if not cfg.trainer.load_state_dict:
engines.load_checkpoint(training=not inferencing)
# freeze requested params
for name, engine in engines.items():
engine.freeze(freeze_all=False)
# split models over requested devices
if cfg.optimizations.model_offloading:
engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )
return engines

View File

@ -28,7 +28,9 @@ def default_feeder(engine, batch):
from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
from ..utils.distributed import init_distributed, distributed_initialized
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size, cleanup_distributed
from ..utils.io import torch_save, torch_load
from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
import logging
import time
@ -39,40 +41,65 @@ import os
from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol
from functools import cached_property
from .base import TrainFeeder
from ..utils import wrapper as ml
_logger = logging.getLogger(__name__)
if not distributed_initialized() and cfg.trainer.backend == "local":
def _nop():
...
fn = _nop if cfg.device == "cpu" else torch.distributed.init_process_group
init_distributed(fn)
if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1:
init_distributed(torch.distributed.init_process_group)
# A very naive engine implementation using barebones PyTorch
# to-do: implement lr_sheduling
class Engine():
def __init__(self, *args, **kwargs):
self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
if 'hyper_config' in kwargs:
self.hyper_config = kwargs['hyper_config']
kwargs.pop("hyper_config")
self.module = kwargs['model'].to(cfg.device).to(torch.float32 if cfg.trainer.amp else cfg.trainer.dtype)
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
self.global_steps = 0
self.micro_steps = 0
self.gradient_accumulation_steps = cfg.hyperparameters.gradient_accumulation_steps
self.global_steps = kwargs.pop("global_steps", 0)
self.micro_steps = kwargs.pop("micro_steps", 0)
self.global_samples = kwargs.pop("global_samples", 0)
self.tokens_processed = kwargs.pop("tokens_processed", 0)
def freeze(self):
for p in self.module.parameters():
if p.requires_grad:
p.requires_grad_(False)
self._frozen_params.add(p)
self._frozen_params = set()
self.max_nan_losses = 8
self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
self.current_batch_size = 0
self._global_grad_norm = None
def freeze(self, freeze_all=True):
# set to freeze
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
# freeze non-LoRA params if requested
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
return freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
for name, param in self.module.named_parameters():
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
param.requires_grad_(False)
self._frozen_params.add(param)
def unfreeze(self):
for p in self._frozen_params:
p.requires_grad_(True)
self._frozen_params.clear()
@property
def _training(self):
if not hasattr(self, "hyper_config"):
return True
return self.hyper_config.training
@property
def global_step(self):
return self.global_steps
@ -81,8 +108,17 @@ class Engine():
def micro_step(self):
return self.micro_steps
def train_batch_size(self):
return cfg.hyperparameters.batch_size
@property
def batch_size(self):
return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
@property
def gradient_accumulation_steps(self):
return cfg.hyperparameters.gradient_accumulation_steps
@property
def gradient_clipping(self):
return cfg.hyperparameters.gradient_clipping
def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs)
@ -91,42 +127,74 @@ class Engine():
return dispatch_attribute(self.module, *args, **kwargs)
def save_checkpoint(self, save_dir, tag ):
save_path = save_dir / tag / "state.pth"
if is_global_leader():
module = self.module.state_dict()
# if training lora
# this is a separate path to override saving the weights
lora = None
if cfg.lora is not None:
lora, module = lora_get_state_dict( module, split = True )
save_dir = cfg.ckpt_dir / cfg.lora.full_name
save_path = save_dir / tag / f"state.{cfg.weights_format}"
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save({
"global_step": self.global_step,
"micro_step": self.micro_step,
"module": self.module.state_dict(),
torch_save({
"module": module,
"lora": lora,
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
"stats": {
"global_step": self.global_step,
"micro_step": self.micro_step,
"global_samples": self.global_samples,
"tokens_processed": self.tokens_processed,
}
}, save_path)
open(save_dir / "latest", 'w').write( tag )
def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
torch.distributed.barrier()
def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
# override to load the lora instead
if cfg.lora is not None:
load_dir = cfg.ckpt_dir / cfg.lora.full_name
if tag is None:
tag_path = load_dir / "latest"
if not tag_path.exists():
return
tag = open(tag_path).read()
load_path = load_dir / tag / "state.pth"
load_path = load_dir / tag / f"state.{cfg.weights_format}"
if not load_path.exists():
return
state = torch.load(load_path)
self.global_steps = state['global_step']
self.micro_steps = state['micro_step']
self.module.load_state_dict(state['module'])
state = torch_load(load_path, device=cfg.device)
self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading)
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
if load_optimizer_states:
self.optimizer.load_state_dict(state['optimizer'])
self.optimizer.load_state_dict(state['optimizer']) #, device=cfg.device)
if load_lr_scheduler_states:
self.lr_scheduler.load_state_dict(state['lr_scheduler'])
self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, device=cfg.device)
if 'lora' in state:
lora_load_state_dict( self.module, state['lora'] )
def eval(self):
return self.module.eval()
@ -136,46 +204,80 @@ class Engine():
def to(self, *args, **kwargs):
self.module = self.module.to(*args, **kwargs)
return self.module
if self.optimizer:
self.optimizer = self.optimizer.to(*args, **kwargs)
return self
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
@cached_property
def device(self):
return next(self.module.parameters()).device
def forward(self, *args, **kwargs):
return self.module.forward(*args, **kwargs)
def backward(self, loss):
if self.loss_scaler is not None:
return self.loss_scaler.scale(loss / self.gradient_accumulation_steps).backward()
return (loss / self.gradient_accumulation_steps).backward()
def step(self):
with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
self.micro_steps += 1
self.global_samples += self.batch_size
if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)
self.global_steps += 1
if self.loss_scaler is not None:
self.loss_scaler.step(self.optimizer)
self.loss_scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad()
self._get_grad_norm()
def _get_grad_norm(self):
t = [ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]
self._global_grad_norm = torch.cat(t).norm().item() if len(t) else None
def get_lr(self):
lrs = []
for param_group in self.optimizer.param_groups:
if 'lr' in param_group:
if 'd_coeff' in param_group:
lrs.append(param_group['d_coeff'])
elif 'lr' in param_group:
lrs.append(param_group['lr'])
return lrs
def set_lr(self, lr):
for param_group in self.optimizer.param_groups:
if 'lr' in param_group:
if 'd_coeff' in param_group:
param_group['d_coeff'] = lr
elif 'lr' in param_group:
param_group['lr'] = lr
def get_global_grad_norm(self):
return 0.0
return self._global_grad_norm
def traverse(self, *args, **kwargs):
with ml.autocast():
self.forward(*args, **kwargs)
losses = self.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()
if torch.isnan(loss).any():
self.max_nan_losses = self.max_nan_losses - 1
if self.max_nan_losses < 0:
raise RuntimeError("Too many NaN losses detected.")
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= self.gather_attribute("scalar")
@ -194,6 +296,8 @@ class Engines(dict[str, Engine]):
def setup(self):
self._global_step = 0
self._micro_step = 0
self._batch_size = 0
self._global_samples = 0
@property
def global_step(self):
@ -203,6 +307,14 @@ class Engines(dict[str, Engine]):
def micro_step(self):
return self._micro_step
@property
def batch_size(self):
return self._batch_size
@property
def global_samples(self):
return self._global_samples
def gather_attribute(self, *args, **kwargs):
ret = {}
for engine in self.values():
@ -213,6 +325,50 @@ class Engines(dict[str, Engine]):
for engine in self.values():
engine.dispatch_attribute(*args, **kwargs)
def export(self, userdata={}, callback=None, dtype=None, format=None):
if not format:
format = cfg.weights_format
format = format.lower()
if dtype is None:
dtype = cfg.trainer.dtype
for name, engine in self.items():
module = engine.module.state_dict()
lora = None
save_path = cfg.ckpt_dir / name / f"fp32.{format}"
config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
# safety
for k, v in module.items():
module[k] = v.to(dtype)
if cfg.lora is not None:
lora, module = lora_get_state_dict( module, split = True )
save_path = cfg.ckpt_dir / cfg.lora.full_name / f"fp32.{format}"
state_dict = {
'module': module,
'lora': lora,
"stats": {
"global_step": engine.global_step,
"micro_step": engine.micro_step,
"global_samples": engine.global_samples,
"tokens_processed": engine.tokens_processed,
},
"userdata": userdata,
"config": config.__dict__
}
if lora is None:
del state_dict['lora']
if callback:
state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
torch_save(state_dict, save_path)
_logger.info(f"Exported {name} to {save_path}")
def save_checkpoint(self, tag=None):
if not tag:
tag = cfg.trainer.save_tag
@ -222,47 +378,67 @@ class Engines(dict[str, Engine]):
cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
for name, engine in self.items():
engine.save_checkpoint(cfg.ckpt_dir / name, tag=tag)
if not engine._training:
continue
def load_checkpoint(self, tag=None):
save_dir = cfg.ckpt_dir / name
try:
engine.save_checkpoint(save_dir, tag=tag)
except Exception as e:
_logger.warning(f'Failed to save checkpoint for engine {name}: {str(e)}')
# might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
checkpoints = [ d for d in list(save_dir.glob("*")) if d.is_dir() ]
checkpoints.sort(key=lambda x: x.stat().st_mtime)
checkpoints = checkpoints[:-cfg.trainer.keep_last_checkpoints]
for d in checkpoints:
if not d.is_dir() or not d.exists():
continue
_logger.info(f"Removing {d}")
for p in d.iterdir():
p.unlink()
d.rmdir()
def load_checkpoint(self, tag=None, training=True):
if not tag:
tag = cfg.trainer.load_tag
for name, engine in self.items():
load_dir = cfg.ckpt_dir / name
engine.load_checkpoint(
tag=tag,
load_dir=load_dir,
load_module_strict=cfg.trainer.strict_loading,
load_optimizer_states=cfg.trainer.load_states,
load_lr_scheduler_states=cfg.trainer.load_states,
load_optimizer_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
load_lr_scheduler_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
load_module_only=cfg.trainer.load_module_only,
)
if cfg.trainer.restart_step_count:
engine.global_steps = 0
engine.mocro_step = 0
engine.global_samples = 0
engine.tokens_processed = 0
# update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
if cfg.hyperparameters.scheduler_type == "":
self.set_lr(cfg.hyperparameters.learning_rate)
self._update_global_step()
self._update_micro_step()
self._update()
def set_lr(self, lr):
for engine in self.values():
if not engine._training:
continue
engine.set_lr(lr)
def _update_global_step(self):
def _update(self):
for engine in self.values():
self._global_step = max(self._global_step, engine.global_step)
def _update_micro_step(self):
for engine in self.values():
self._micro_step = max(self._micro_step, engine.micro_step)
def train_batch_size(self):
batch_size = 0
for engine in self.values():
batch_size = max(batch_size, engine.train_batch_size())
self._batch_size = max(self._batch_size, engine.batch_size)
self._global_samples = max(self._global_samples, engine.global_samples)
def eval(self):
for engine in self.values():
@ -279,7 +455,10 @@ class Engines(dict[str, Engine]):
stats.update(flatten_dict({ name.split("-")[0]: stat }))
return stats
def step(self, batch, feeder: TrainFeeder = default_feeder, device=cfg.get_device()):
def quit(self):
cleanup_distributed()
def step(self, batch, feeder: TrainFeeder = default_feeder):
total_elapsed_time = 0
stats: Any = dict()
@ -287,10 +466,11 @@ class Engines(dict[str, Engine]):
if cfg.trainer.gc_mode == 'step':
do_gc()
batch = to_device(batch, device)
for name, engine in self.items():
#torch.cuda.synchronize()
if not engine._training:
continue
device = engine.device
if cfg.trainer.gc_mode == 'substep':
do_gc()
@ -298,9 +478,8 @@ class Engines(dict[str, Engine]):
start_time = time.time()
tries = 4
n_ooms = torch.zeros([], device=cfg.device)
n_ooms = torch.zeros([], device=device)
if cfg.trainer.aggressive_optimizations:
batch = to_device(batch, device)
if not cfg.trainer.check_for_oom:
@ -311,7 +490,7 @@ class Engines(dict[str, Engine]):
res = feeder( engine=engine, batch=batch )
break
except RuntimeError as e:
print("Forward", str(e))
_logger.error(f"Forward: {str(e)}")
if "out of memory" not in str(e):
self.save_checkpoint()
@ -329,6 +508,7 @@ class Engines(dict[str, Engine]):
do_gc()
continue
if world_size() > 1:
all_reduce(n_ooms)
if n_ooms.item() > 0:
self.save_checkpoint()
@ -340,7 +520,7 @@ class Engines(dict[str, Engine]):
loss, engine_stats = res
engine_stats |= self.gather_attribute("scalar")
n_ooms = torch.zeros([], device=cfg.device)
n_ooms = torch.zeros([], device=device)
if cfg.trainer.aggressive_optimizations:
batch = to_device(batch, 'cpu')
@ -348,10 +528,11 @@ class Engines(dict[str, Engine]):
if not cfg.trainer.check_for_oom:
engine.backward(loss)
else:
# to-do: properly handle when one GPU throws an OOM because it just halts
try:
engine.backward(loss)
except RuntimeError as e:
print("Backwards:", str(e))
_logger.error(f"Backwards: {str(e)}")
if "out of memory" not in str(e):
self.save_checkpoint()
@ -359,9 +540,12 @@ class Engines(dict[str, Engine]):
n_ooms += 1
if world_size() > 1:
all_reduce(n_ooms)
if n_ooms.item() > 0:
self.save_checkpoint()
raise RuntimeError("Out of memory during backwards pass!")
engine.step()
@ -370,27 +554,36 @@ class Engines(dict[str, Engine]):
elapsed_time = time.time() - start_time
total_elapsed_time += elapsed_time
grad_norm = engine.get_global_grad_norm()
loss_scale = 1
if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
loss_scale = engine.optimizer.loss_scale
if grad_norm is not None:
grad_norm /= loss_scale
stats.update(
flatten_dict(
{
name.split("-")[0]: dict(
loss=loss.item(),
**engine_stats,
lr=engine.get_lr()[0],
grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
loss_scale=loss_scale if loss_scale != 1 else None,
elapsed_time=elapsed_time,
engine_step=engine.global_step,
**engine_stats,
samples_processed=engine.global_samples,
tokens_processed=engine.tokens_processed,
)
}
),
)
self._update_global_step()
self._update_micro_step()
stats["batch_size"] = self.train_batch_size() # len(batch["text"])
self._update()
if len(self.keys()) > 1:
stats["elapsed_time"] = total_elapsed_time
stats["wall_time"] = time.time()
stats["global_step"] = self.global_step
stats["it"] = self.global_step
return stats

View File

@ -25,29 +25,72 @@ from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distr
from deepspeed.accelerator import get_accelerator
from ..utils.distributed import init_distributed, distributed_initialized
from ..utils import wrapper as ml
from ..models.lora import freeze_non_lora_weights
if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
init_distributed(init_deepspeed_dist)
class Engine(DeepSpeedEngine):
def __init__(self, *args, **kwargs):
kwargs['config'] = cfg.trainer.deepspeed.get_ds_cfg(model=kwargs['model'])
self.hyper_config = None
if 'hyper_config' in kwargs:
self.hyper_config = kwargs['hyper_config']
kwargs.pop("hyper_config")
kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
stats = {
"global_step": 0,
"micro_step": 0,
"global_samples": 0,
"tokens_processed": 0,
}
# kwargs['stats'] = None will return None when popped
maybe_stats = kwargs.pop('stats', stats)
if maybe_stats is not None:
stats = maybe_stats
super().__init__(None, *args, **kwargs)
self._frozen_params = set()
def freeze(self):
for p in self.module.parameters():
if p.requires_grad:
p.requires_grad_(False)
self._frozen_params.add(p)
self.global_steps = stats["global_step"]
self.micro_steps = stats["micro_step"]
self.global_samples = stats["global_samples"]
self.tokens_processed = stats["tokens_processed"]
self.max_nan_losses = 8
self.current_batch_size = 0
def freeze(self, freeze_all=True):
# freeze non-LoRA params if requested
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
frozen_params = freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
for param in frozen_params:
self._frozen_params.add( param )
return
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
for name, param in self.module.named_parameters():
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
param.requires_grad_(False)
self._frozen_params.add(param)
def unfreeze(self):
for p in self._frozen_params:
p.requires_grad_(True)
for param in self._frozen_params:
param.requires_grad_(True)
self._frozen_params.clear()
@property
def _training(self):
return self.hyper_config.training
@property
def global_step(self):
return self.global_steps
@ -56,6 +99,10 @@ class Engine(DeepSpeedEngine):
def micro_step(self):
return self.micro_steps
@property
def batch_size(self):
return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs)
@ -66,17 +113,40 @@ class Engine(DeepSpeedEngine):
try:
if hasattr(self.optimizer, 'param_groups'):
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
param_group["d_coeff" if "d_coeff" in param_group else "lr"] = lr
else:
self.optimizer.set_lr(lr)
except Exception as e:
print(str(e))
_logger.warning(str(e))
# we'll just have to live with the LoRA weights living within our main weights
# they're easy to extract anyways
def load_checkpoint(self, load_dir, **kwargs ):
# override to load the lora instead
if cfg.lora is not None:
load_dir = cfg.ckpt_dir / cfg.lora.full_name
return super().load_checkpoint( load_dir, **kwargs )
def save_checkpoint(self, save_dir, **kwargs ):
# override to save the lora instead
if cfg.lora is not None:
save_dir = cfg.ckpt_dir / cfg.lora.full_name
return super().save_checkpoint( save_dir, **kwargs )
def traverse(self, *args, **kwargs):
with ml.autocast():
self.forward(*args, **kwargs)
losses = self.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()
if torch.isnan(loss).any():
self.max_nan_losses = self.max_nan_losses - 1
if self.max_nan_losses < 0:
raise RuntimeError("Too many NaN losses detected.")
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= self.gather_attribute("scalar")

View File

@ -1,31 +1,67 @@
import argparse
import torch
import torch.nn
from .data import get_symmap
from .train import load_engines
from .engines import load_engines
from .config import cfg
from .models.lora import lora_get_state_dict
from .utils.io import torch_save, torch_load
def load_models():
models = {}
engines = load_engines()
for name in engines:
model = engines[name].module.cpu()
models[name] = model
# yanks a LoRA from the training checkpoint
def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
if dtype is None:
dtype = cfg.inference.dtype
return models
format = save_path.stem[1:]
lora = state_dict["lora"] if "lora" in state_dict else None
# should always be included, but just in case
if lora is None and "module" in state_dict:
lora, module = lora_get_state_dict( state_dict["module"], split = True )
state_dict["module"] = module
if "lora" in state_dict:
state_dict["lora"] = None
# should raise an exception since there's nothing to extract, or at least a warning
if not lora:
return state_dict
# save lora specifically
# should probably export other attributes, similar to what SD LoRAs do
save_path = save_path.parent / f"lora.{format}"
torch_save( {
"module": lora,
"config": cfg.lora.__dict__ if cfg.lora is not None else None,
}, save_path )
return state_dict
def main():
parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("path")
args = parser.parse_args()
parser.add_argument("--module-only", action='store_true')
parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
parser.add_argument("--format", type=str, default="pth") # set target format to export weights under
args, unknown = parser.parse_known_args()
models = load_models()
for name in models:
model = models[name]
if args.format.lower() not in ["sft", "safetensors", "pt", "pth"]:
raise Exception(f"Unknown requested format: {args.format}")
outpath = f'{args.path}/{name}.pt'
torch.save(model, outpath)
print(f"Exported {name} to {outpath}")
if args.module_only:
cfg.trainer.load_module_only = True
if args.dtype != "auto":
cfg.trainer.weight_dtype = args.dtype
# necessary to ensure we are actually exporting the weights right
cfg.inference.backend = cfg.trainer.backend
engines = load_engines(training=False) # to ignore loading optimizer state
callback = None
engines.export(userdata={"symmap": get_symmap()}, callback=callback, format=args.format)
if __name__ == "__main__":
main()

View File

@ -1,53 +1,103 @@
import torch
import torchaudio
import soundfile
import time
import logging
_logger = logging.getLogger(__name__)
from torch import Tensor
from einops import rearrange
from pathlib import Path
from .utils import to_device, set_seed, wrapper as ml
from PIL import Image, ImageDraw
import torchvision.transforms as transforms
from .config import cfg
from .export import load_models
from .data import get_symmap, _get_symbols
from .config import cfg, Config
from .models import get_models
from .engines import load_engines, deepspeed_available
from .data import get_symmap, tokenize
if deepspeed_available:
import deepspeed
class Classifier():
def __init__( self, width=300, height=80, config=None, ckpt=None, device=cfg.get_device(), dtype="float32" ):
def __init__( self, config=None, device=None, amp=None, dtype=None, attention=None ):
self.loading = True
# yes I can just grab **kwargs and forward them here
self.load_config( config=config, device=device, amp=amp, dtype=dtype, attention=attention )
self.load_model()
self.loading = False
def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
if config:
_logger.info(f"Loading YAML: {config}")
cfg.load_yaml( config )
self.loading = True
try:
cfg.format( training=False )
cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
except Exception as e:
raise e # throw an error because I'm tired of silent errors messing things up for me
if amp is None:
amp = cfg.inference.amp
if dtype is None or dtype == "auto":
dtype = cfg.inference.weight_dtype
if device is None:
device = cfg.device
cfg.device = device
cfg.mode = "inferencing"
cfg.trainer.backend = cfg.inference.backend
cfg.trainer.weight_dtype = dtype
cfg.inference.weight_dtype = dtype
self.device = device
self.dtype = cfg.inference.dtype
self.amp = amp
if ckpt:
self.load_model_from_ckpt( ckpt )
else:
self.load_model_from_cfg( config )
self.model_kwargs = {}
self.model.eval()
def load_model( self ):
load_engines.cache_clear()
self.width = width
self.height = height
self.engines = load_engines(training=False, **self.model_kwargs)
for name, engine in self.engines.items():
if self.dtype != torch.int8:
engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.engines.eval()
self.symmap = get_symmap()
self.width = 300
self.height = 80
self.transform = transforms.Compose([
transforms.Resize((self.height, self.width)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
self.loading = False
_logger.info("Loaded model")
def load_model_from_ckpt( self, ckpt ):
self.ckpt = ckpt
@torch.inference_mode()
def inference( self, image, temperature=1.0 ):
model = None
self.model = torch.load(self.ckpt).to(self.device)
def load_model_from_cfg( self, config_path ):
models = load_models()
for name in models:
model = models[name]
self.model = model.to(self.device)
for name, engine in self.engines.items():
model = engine.module
break
def inference( self, image, temperature=1.0 ):
image = self.transform(image).to(self.device)
image = self.transform(image).to(self.device).to(self.dtype)
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
answer = model( image=[image], sampling_temperature=temperature )
answer = [ "".join(answer) ]
answer = self.model( image=[image], sampling_temperature=temperature )
answer = answer[0].replace('<s>', "").replace("</s>", "") # it would be better to just slice between these, but I can't be assed
return answer

9
image_classifier/models/__init__.py Executable file → Normal file
View File

@ -1,18 +1,19 @@
from .base import Model
def get_model(cfg):
def get_model(cfg, training=False):
name = cfg.name
model = Model(
n_tokens=cfg.tokens,
n_len=cfg.len,
d_model=cfg.dim,
d_resnet=cfg.resnet,
)
model._cfg = cfg
model.config = cfg
print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
return model
def get_models(models):
return { model.full_name: get_model(model) for model in models }
def get_models(models, training=False):
return { model.full_name: get_model(model, training=training) for model in models }

View File

@ -12,7 +12,7 @@ from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
from torchvision.models import resnet18
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
from ..data import get_symmap
@ -20,12 +20,12 @@ class Model(nn.Module):
def __init__(
self,
n_tokens: int = 0, # number of token types
n_len: int = 6, # how long a sequence can be
n_len: int = 12, # how long a sequence can be
d_model: int = 512,
d_resnet: int = 18,
):
super().__init__()
_symmap = get_symmap()
self.symmap = { f'{v}': k for k, v in _symmap.items() }
self.symmap['0'] = ""
@ -36,8 +36,26 @@ class Model(nn.Module):
self.n_tokens = n_tokens
self.n_len = n_len + 2 # start/stop tokens
self.d_model = d_model
self.d_resnet = d_resnet
self.resnet = resnet18(pretrained=False)
ResNet = resnet18
if d_resnet == 18:
print("Using resnet18")
ResNet = resnet18
elif d_resnet == 34:
print("Using resnet34")
ResNet = resnet34
elif d_resnet == 50:
print("Using resnet50")
ResNet = resnet50
elif d_resnet == 101:
print("Using resnet101")
ResNet = resnet101
elif d_resnet == 152:
print("Using resnet152")
ResNet = resnet152
self.resnet = ResNet(pretrained=False)
self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len )
self.accuracy_metric = MulticlassAccuracy(
@ -61,33 +79,29 @@ class Model(nn.Module):
sampling_temperature: float = 1.0,
):
x_list = torch.stack( image, dim=0 )
logits = self.resnet( torch.stack( image, dim=0 ) )
logits = logits.view(logits.size(0), self.n_len, self.n_tokens).permute(1, 0, 2)
x = self.resnet( x_list )
y = x.view(x.size(0), self.n_len, self.n_tokens)
# either of these should do, but my VALL-E forward pass uses this, so might as well keep to it
# pred = y.argmax(dim=2)
pred = Categorical(logits=y / sampling_temperature).sample()
answer = [ "".join([ self.symmap[f'{x.item()}'] for x in t ]) for t in pred ]
pred = logits.argmax(dim=2)
if text is not None:
y_list = rearrange(pad_sequence(text), "t b -> b t")
loss = 0
labels = rearrange(pad_sequence(text), "t b -> b t").permute(1, 0)
loss = []
for i in range(self.n_len):
if i >= y_list.shape[1]:
if i >= labels.shape[0]:
break
loss += F.cross_entropy( y[:, i], y_list[:, i] )
loss.append( F.cross_entropy(logits[i], labels[i]) )
self.loss = dict(
nll=loss
nll = sum( loss ) / len( loss ),
)
self.stats = dict(
acc = self.accuracy_metric( pred, y_list ),
precision = self.precision_metric( pred, y_list ),
acc = self.accuracy_metric( pred, labels ),
precision = self.precision_metric( pred, labels ),
)
answer = [ "".join([ self.symmap[f'{x.item()}'] for x in t ]) for t in pred ]
return answer

View File

@ -0,0 +1,214 @@
# Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
from functools import partial
import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from transformers.pytorch_utils import Conv1D
from torch import Tensor, nn
import math
from typing import Optional, List
from ..utils import passes_policy
# LoRA Linear for replacement
# Pros: simple, just needs to reuse the replace_linear and copy weights
# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
class LoRALinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
rank: int = 4,
alpha: int = 1,
dropout: float = 0.1,
merge_weights: bool = False,
**kwargs,
):
super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs)
self.rank = rank
self.alpha = alpha
self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
self.merge_weights = merge_weights
self.merged = False
self.enabled = True
self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
self.scaling = self.alpha / self.rank
self.weight.requires_grad = False
self.reset_parameters()
def reset_parameters(self):
super().reset_parameters()
# super silly but necessary because nn.Linear's constructor calls this
if hasattr(self, 'lora_A'):
nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
nn.init.zeros_( self.lora_B )
def train(self, mode: bool = True):
super().train(mode)
# training, separate lora from base weights
if mode and self.merge_weights and self.merged:
self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
self.merged = False
# not training, merge lora to base weights
if not mode and self.merge_weights and not self.merged:
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
self.merged = True
def forward(self, x: torch.Tensor):
if not self.merged and self.enabled:
result = F.linear(x, self.weight, bias=self.bias)
result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
return result
return F.linear(x, self.weight, bias=self.bias)
@classmethod
def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
if device is None:
device = layer.weight.device
if dtype is None:
dtype = layer.weight.dtype
return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
# Uses parametrization to inject LoRA weights
# Pros: should work with any Linears
# Cons: TBD
class ParameterizedLoRA(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
rank: int = 4,
alpha: int = 1,
dropout: float = 0.1,
device = None,
dtype = None
):
super().__init__()
self.rank = rank
self.alpha = alpha
self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
self.lora_B = nn.Parameter( torch.zeros( (out_features, rank) ) ).to( device=device, dtype=dtype )
self.lora_A = nn.Parameter( torch.zeros( (rank, in_features) ) ).to( device=device, dtype=dtype )
self.scaling = self.alpha / self.rank
self.enabled = True
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
nn.init.zeros_( self.lora_B )
def forward(self, x: torch.Tensor):
if self.enabled:
return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
return x
@classmethod
def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
if device is None:
device = layer.weight.device
if dtype is None:
dtype = layer.weight.dtype
# swap because we're feeding the output as our input
# M$'s LoRA class arranges things to where this isn't necessary
return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
@classmethod
def from_conv1d( cls, layer, device = None, dtype = None, **kwargs ):
if device is None:
device = layer.weight.device
if dtype is None:
dtype = layer.weight.dtype
in_channels, out_channels = layer.weight.shape
# swap because we're feeding the output as our input
# M$'s LoRA class arranges things to where this isn't necessary
return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [ k.split('.') for k, m in model.named_modules() if passes_policy( policy, k ) ]
for *parent, k in modules:
name = '.'.join(parent)
layer = getattr( model.get_submodule(name), k )
if isinstance( layer, nn.Linear ):
target = nn.Linear
klass = ParameterizedLoRA if use_parametrize else LoRALinear
replacer = klass.from_linear
elif isinstance( layer, nn.Conv1d ):
target = nn.Conv1d
klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
replacer = klass.from_conv1d
elif isinstance( layer, Conv1D ):
target = Conv1D
klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
replacer = klass.from_conv1d
else:
continue
replacement = replacer( layer, device=device, dtype=dtype, **kwargs )
if use_parametrize:
parametrize.register_parametrization( layer, "weight", replacement )
else:
setattr( model.get_submodule(name), k, replacement )
return enable_lora( model )
def enable_lora( model, mode = True ):
for name, module in model.named_modules():
if not isinstance( module, ParameterizedLoRA ) and not isinstance( module, LoRALinear ):
continue
module.enabled = mode
return model
def disable_lora( model ):
return enable_lora( model, False )
def freeze_non_lora_weights( model, embeddings = False ):
frozen_params = []
for name, param in model.named_parameters():
should = 'lora_' in name or (embeddings and "_emb" in name)
param.requires_grad_(should)
if not should:
frozen_params.append( param )
return frozen_params
def lora_get_state_dict( state_dict, split = True ):
lora = { name: param for name, param in state_dict.items() if "lora_" in name }
if not split:
return lora
return lora, { name: param for name, param in state_dict.items() if "lora_" not in name }
def lora_load_state_dict( model, state_dict ):
return model.load_state_dict( state_dict, strict = False )

120
image_classifier/plot.py Normal file
View File

@ -0,0 +1,120 @@
#!/usr/bin/env python3
import argparse
import json
import re
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
from .config import cfg
def plot(paths, args):
dfs = []
for path in paths:
with open(path, "r") as f:
text = f.read()
rows = []
pattern = r"(\{.+?\})\.\n"
for row in re.findall(pattern, text, re.DOTALL):
try:
row = json.loads(row)
except Exception as e:
continue
for model in args.models:
if f'{model.name}.{args.xs}' not in row:
continue
rows.append(row)
break
df = pd.DataFrame(rows)
if "name" in df:
df["name"] = df["name"].fillna("train")
else:
df["name"] = "train"
df["group"] = str(path.parents[args.group_level])
df["group"] = df["group"] + "/" + df["name"]
dfs.append(df)
df = pd.concat(dfs)
if args.min_x is not None:
for model in args.models:
df = df[args.min_x < df[f'{model.name}.{args.xs}']]
if args.max_x is not None:
for model in args.models:
df = df[df[f'{model.name}.{args.xs}'] < args.max_x]
for gtag, gdf in sorted(
df.groupby("group"),
key=lambda p: (p[0].split("/")[-1], p[0]),
):
for model in args.models:
x = f'{model.name}.{args.xs}'
for ys in args.ys:
y = f'{model.name}.{ys}'
if gdf[y].isna().all():
continue
if args.min_y is not None:
gdf = gdf[args.min_y < gdf[y]]
if args.max_y is not None:
gdf = gdf[gdf[y] < args.max_y]
gdf[y] = gdf[y].ewm(10).mean()
gdf.plot(
x=x,
y=y,
label=f"{y}",
ax=plt.gca(),
marker="x" if len(gdf) < 100 else None,
alpha=0.7,
)
plt.gca().legend(
loc="center left",
fancybox=True,
shadow=True,
bbox_to_anchor=(1.04, 0.5),
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--xs", default="engine_step")
parser.add_argument("--ys", nargs="+", default="")
parser.add_argument("--model", nargs="+", default="*")
parser.add_argument("--min-x", type=float, default=-float("inf"))
parser.add_argument("--min-y", type=float, default=-float("inf"))
parser.add_argument("--max-x", type=float, default=float("inf"))
parser.add_argument("--max-y", type=float, default=float("inf"))
parser.add_argument("--filename", default="log.txt")
parser.add_argument("--group-level", default=1)
args, unknown = parser.parse_known_args()
path = cfg.rel_path / "logs"
paths = path.rglob(f"./*/{args.filename}")
args.models = [ model for model in cfg.model.get() if model.training and (args.model == "*" or model.name in args.model) ]
if args.ys == "":
args.ys = ["loss"]
plot(paths, args)
out_path = cfg.rel_path / "metrics.png"
plt.savefig(out_path, bbox_inches="tight")

View File

@ -0,0 +1,204 @@
import math
import torch
import torch.nn.functional as F
import numpy as np
from torch import Tensor, einsum, nn
# Simple filter to modify a token's probability if it shows up in the past
# `one_time` will only apply the penalty once
# `decay` is a factor that will exponentially apply to how far away it is
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ):
if factor == 1.0 or previous is None:
return logits
unique = set()
priors = reversed(previous)
for distance, token in enumerate(priors):
# skip if we're only applying the decay once
if one_time and token in unique:
continue
distance += 1
logits[:, token] /= factor * (distance ** decay)
# add to set if we care about it
if one_time:
unique.add(token)
return logits
# Simple "filter" that modifies the logit for the stop token, based on the sequence length
# `length` is the length of the sequence currently
# `factor` is the power the length is raised to, so values > 0 will yield longer sequences, values < 0 will yield shorter sequences
# `token` is the stop token.
def length_penalize( logits, length, factor=0.0, token=-1 ):
if factor == 0.0:
return logits
logits[:, token] /= (length ** factor)
return logits
# Simple way to ban tokens
def ban_tokens( logits, tokens ):
for token in tokens:
# token not in logits
if logits.shape[-1] >= token:
continue
logits[:, token] = -float("inf")
return logits
# Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens per batch example in the output
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens > 1:
# Keep at least min_tokens (set to min_tokens-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
# credit to https://github.com/LostRuins/koboldcpp/pull/464 // https://github.com/kalomaze/koboldcpp/tree/dynamic-temp
def dynamic_temperature( logits, temperature=1.0, min_temperature = 0.0, k = 10, sigmoidCenterPoint = 0.5 ):
# loop over logits[:], as the NAR will have logits.shape[0] > 1
for i in range(logits.shape[0]):
sum_exp = 0.0
maximum = torch.max( logits[i] )
for logit in logits[i]:
sum_exp += math.exp( logit - maximum )
prob_max_token_before_temp = 1.0 / sum_exp
dynamic_temperature = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (prob_max_token_before_temp - sigmoidCenterPoint)))
logits[i] /= dynamic_temperature
return logits
# picks the top K tokens amongst a batch of logits
# logits: [Tensor] list of logits
# candidates: [(batch, token)] list, where batch indicates the index of the logits the given token is from
def top_k_logits_list( logits_list, k ):
# ( batch, tokens ) => ( batch x tokens )
logits = torch.cat( logits_list )
candidates = list(torch.topk(logits.flatten(), k).indices.tolist()) # perform top-k across all logits
for i, index in enumerate(candidates):
t = []
N = np.prod(logits.size())
for n in logits.size():
N //= n
t.append(index // N)
index %= N
candidates[i] = tuple(t)
return candidates
# Credit to: https://github.com/basusourya/mirostat/
# performs mirostat-based sampling
# logits: Tensor of logit probabilities
# state: the mirostat state
def mirostat_sample( logits, state = None ):
def compute_k(prob, n, tau):
num = 0
den = 0
for i in range(100):
b = prob[i]/prob[i+1]
t = (i+2)/(i+1)
num += math.log(b)*math.log(t)
den += math.log(t)**2
s = num/den
eps = s-1
k = ((eps*(2**(tau)))/(1-n**(-eps)))**(1/s)
k = round(k)
return k
if "max_surprise" not in state:
state["max_surprise"] = state["tau"] * 2
if "error_surprise" not in state:
state["error_surprise"] = 0
if "running_total_surprise" not in state:
state["running_total_surprise"] = 0
sorted_logits, sorted_indices = torch.sort( logits[-1, :], descending=True )
prob_original = torch.softmax( sorted_logits, dim=-1 ).tolist()
k = compute_k(prob_original, state["n"], state["max_surprise"]) + 1
sorted_logits = sorted_logits[0:k]
sorted_indices = sorted_indices[0:k]
prob_topk = torch.softmax(sorted_logits, dim = 0)
prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)
state["index_surprise"] = math.log2(1/prob_original[prev_i])
state["running_total_surprise"] += state["index_surprise"]
state["error_surprise"] = state["index_surprise"] - state["tau"]
state["max_surprise"] -= state["eta"] * state["error_surprise"]
state["token"] = sorted_indices[prev_i]
return state
# Credits to: https://github.com/oobabooga/text-generation-webui/pull/5677
# performs DRY sampling
# * (honestly it looks close to rep pen anyways but what do I know)
# `logits` are the scores used to sample against
# `previous` are the prior tokens to penalize with
# `factor` is the scalar multiplier
# `base` is the base number to raise to the (length - allowed_length)th power
# `allowed_length` limits the range to apply DRY to
def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2 ):
if factor == 0.0 or previous is None:
return logits
lengths = {}
for i, token in enumerate( previous ):
length = 1
while length < max(allowed_length, 50):
j = i - length
# Start of input reached.
if j < 0:
break
# Start of match reached.
if previous[j] != previous[-length-1]:
break
length += 1
lengths[token] = max(length, lengths[token]) if token in lengths else length
for token, length in lengths.items():
if length < allowed_length:
break
logits[:, token] -= factor * base ** (length - allowed_length)
return logits

View File

@ -4,7 +4,7 @@ from .config import cfg
from .data import create_train_val_dataloader
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
from .utils.trainer import load_engines
from .utils.distributed import is_global_leader
import json
import logging
@ -12,14 +12,22 @@ import random
import torch
import torch.nn.functional as F
import traceback
import shutil
from collections import defaultdict
from PIL import Image
from tqdm import tqdm
import argparse
from PIL import Image, ImageDraw
_logger = logging.getLogger(__name__)
def train_feeder(engine, batch):
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
batch_size = len(batch["text"])
engine.current_batch_size = batch_size
engine( image=batch["image"], text=batch["text"] )
losses = engine.gather_attribute("loss")
@ -31,34 +39,16 @@ def train_feeder(engine, batch):
stats |= {k: v.item() for k, v in losses.items()}
stats |= {k: v.item() for k, v in stat.items()}
engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
return loss, stats
@torch.inference_mode()
def run_eval(engines, eval_name, dl):
engines_stats = {
'eval': eval_name
}
model = None
names = []
for name, engine in engines.items():
names.append(name)
model = engine
break
stats = defaultdict(list)
stats['loss'] = []
def process( name, batch, resps_list ):
for path, ref, hyp in zip(batch["path"], batch["text"], hyp):
continue
for batch in tqdm(dl):
batch: dict = to_device(batch, cfg.device)
res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
def process( name, batch, res, loss ):
for path, ref, hyp in zip(batch["path"], batch["text"], res):
hyp = hyp.replace('<s>', "").replace("</s>", "")
hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / hyp).with_suffix(".png")
@ -67,36 +57,74 @@ def run_eval(engines, eval_name, dl):
image = Image.open(path).convert('RGB')
image.save(hyp_path)
stats['loss'].append(loss)
processed = 0
while processed < cfg.evaluation.size:
batch = to_device(next(iter(dl)), cfg.device)
# limit to eval batch size in the event we somehow have a weird dataloader
for key in batch.keys():
batch[key] = batch[key][:cfg.evaluation.batch_size]
processed += len(batch["text"])
for name in engines:
engine = engines[name]
res = engine( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
losses = engine.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum().item()
stats['loss'].append(loss)
process( name, batch, res, loss )
stats = {k: sum(v) / len(v) for k, v in stats.items()}
engines_stats.update(flatten_dict({ name: stats }))
iteration = engines.global_step
engines_stats['it'] = iteration
engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
engines_stats = {
f'{name}.{eval_name}': stats,
"it": engines.global_step,
}
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
def main():
def train():
parser = argparse.ArgumentParser("ResNet Image Classifier")
parser.add_argument("--eval", action="store_true", default=None)
args, unknown = parser.parse_known_args()
# create log folder
setup_logging(cfg.log_dir)
# copy config yaml to backup
if cfg.yaml_path is not None and is_global_leader():
shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
def eval_fn(engines):
do_gc()
engines.eval()
# wrapped in a try block because it's sometimes prone to breaking
try:
run_eval(engines, "subtrain", subtrain_dl)
run_eval(engines, "val", val_dl)
except Exception as e:
print("Error occurred while performing eval:", str(e))
print(traceback.format_exc())
_logger.warning(f"Error occurred while performing eval: {str(e)}")
_logger.warning(traceback.format_exc())
engines.train()
do_gc()
if args.eval:
return eval_fn(engines=trainer.load_engines())
"""
if cfg.trainer.load_webui:
from .webui import start
start(lock=False)
"""
trainer.train(
train_dl=train_dl,
train_feeder=train_feeder,
@ -104,4 +132,5 @@ def main():
)
if __name__ == "__main__":
main()
# to-do: for DDP, spawn multiprocess instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"`
train()

View File

@ -7,4 +7,7 @@ from .utils import (
to_device,
tree_map,
do_gc,
set_seed,
passes_policy,
get_devices
)

View File

@ -8,6 +8,10 @@ import socket
from functools import cache, wraps
from typing import Callable
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def get_free_port():
sock = socket.socket()
sock.bind(("", 0))
@ -15,13 +19,18 @@ def get_free_port():
_distributed_initialized = False
def init_distributed( fn ):
fn()
def init_distributed( fn, *args, **kwargs ):
torch.cuda.set_device(local_rank())
fn(*args, **kwargs)
_distributed_initialized = True
def distributed_initialized():
return _distributed_initialized
def cleanup_distributed():
dist.barrier()
dist.destroy_process_group()
@cache
def fix_unset_envs():
envs = dict(
@ -44,10 +53,12 @@ def fix_unset_envs():
def local_rank():
return int(os.getenv("LOCAL_RANK", 0))
def global_rank():
return int(os.getenv("RANK", 0))
def world_size():
return int(os.getenv("WORLD_SIZE", 1))
def is_local_leader():
return local_rank() == 0
@ -87,3 +98,6 @@ def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable:
return wrapper
return wrapper(fn)
def ddp_model(model):
return DDP(model.to(device='cuda'), [local_rank()], find_unused_parameters=True)

View File

@ -0,0 +1,88 @@
import torch
import json
from pathlib import Path
from safetensors import safe_open as sft_load
from safetensors.torch import save_file as sft_save
def coerce_path( path ):
return path if isinstance( path, Path ) else Path(path)
def pick_path( path, *suffixes ):
suffixes = [*suffixes]
for suffix in suffixes:
p = path.with_suffix( suffix )
if p.exists():
return p
return path
def is_dict_of( d, t ):
if not isinstance( d, dict ):
return False
return all([ isinstance(v, torch.Tensor) for v in d.values() ])
# handles converting the usual pth state_dict into just the dict with the tensors + a dict of JSON strings, for safetensors
def state_dict_to_tensor_metadata( data: dict, module_key=None ):
metadata = None
# is a state_dict, no need to coerce
if is_dict_of( data, torch.Tensor ):
return data, metadata
# is maybe a dict with a state dict + metadata, coerce it
metadata = {}
target = module_key
if not target:
for k, v in data.items():
# is a dict of tensors, our target
if is_dict_of( v, torch.Tensor ):
target = k
continue # continue to iterate to grab other metadata
# not a dict of tensors, put it as metadata
try:
metadata[k] = json.dumps(v)
except Exception as e:
pass
if not target:
raise Exception(f'Requesting to save safetensors of a state dict, but state dict contains no key of torch.Tensor: {path}')
return data[target], metadata
def torch_save( data, path, module_key=None ):
path = coerce_path(path)
ext = path.suffix
if ext in [".safetensor", ".sft"]:
data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key )
return sft_save( data, path, metadata )
return torch.save( data, path )
def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=True, module_key="module" ):
path = coerce_path(path)
ext = path.suffix
if ext in [".safetensor", ".sft"]:
state_dict = {}
with sft_load(path, framework=framework, device=device) as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if load_metadata:
metadata = f.metadata()
for k, v in metadata.items():
try:
metadata[k] = json.loads( v )
except Exception as e:
pass
state_dict = { module_key: state_dict } | metadata
return state_dict
return torch.load( path, map_location=torch.device(device), weights_only=not unsafe )

188
image_classifier/utils/sampler.py Executable file → Normal file
View File

@ -1,48 +1,164 @@
"""
A sampler that balances data by key_fns.
MIT License
Copyright (c) 2023 Zhe Niu
niuzhe.nz@outlook.com
"""
from dataclasses import dataclass
from typing import Any
import random
import torch
from torch.utils.data import Sampler
class Sampler:
def __init__(self, l, key_fns):
self.tree = self._build(l, key_fns)
from .distributed import global_rank, local_rank, world_size
def _build(self, l, key_fns) -> dict[dict, list]:
if not key_fns:
return l
# Randomly picks an index from an array of indices
class PoolSampler():
def __init__( self, pool = [], keep_all = False, shuffle = False ):
self.length = len(pool)
self.shuffle = shuffle
self.global_pool = pool if keep_all else None
self.global_indices = [ i for i in range(self.length) ]
self.reset()
tree = {}
def reset(self):
self.current_pool = [ i for i in self.global_indices ]
if self.shuffle:
random.shuffle(self.current_pool)
key_fn, *key_fns = key_fns
def sample(self, pool = None):
if pool is None:
pool = self.global_pool
# check if we need to reset
index = random.choice( self.current_pool )
# remove from pool
self.current_pool.remove(index)
# reset if needed
if len(self.current_pool) == 0:
self.reset()
# map indices to our real values
return pool[index] if pool is not None else index
for x in l:
k = key_fn(x)
def __len__(self):
return self.length # len(self.current_pool)
if k in tree:
tree[k].append(x)
else:
tree[k] = [x]
def __iter__(self):
while len(self.current_pool) > 0:
yield self.sample()
for k in tree:
tree[k] = self._build(tree[k], key_fns)
def __call__(self, *args, **kwargs):
return self.sample(*args, **kwargs)
return tree
def get_state(self):
return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
def _sample(self, tree: dict | list):
if isinstance(tree, list):
ret = random.choice(tree)
else:
key = random.choice([*tree.keys()])
ret = self._sample(tree[key])
return ret
def set_state(self, state):
self.length = state["length"]
self.global_pool = state["global_pool"]
self.global_indices = state["global_indices"]
self.current_pool = state["current_pool"]
def sample(self):
return self._sample(self.tree)
# "Samples" through a fixed sequence from 0 to length
# Necessary for our "shuffle+sort by duration+interleave" sampling method
# Allows saving and loading state
class OrderedSampler(Sampler):
def __init__( self, length ):
self.position = 0
self.length = length
def __len__(self):
return self.length
def __iter__(self):
if self.position >= self.length:
self.position = 0
while self.position < self.length:
yield self.position
self.position += 1
def get_state(self):
return { "position": self.position, "length": self.length }
def set_state(self, state):
self.position = state["position"]
self.length = state["length"]
# Like the above, but will batch based on token count
class BatchedOrderedSampler(Sampler):
def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
self.position = 0
self.batches = []
self.shuffle = shuffle
assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
current_batch = []
current_size = 0
current_index = 0
for key, bucket in buckets.items():
for path, duration in bucket:
# flush
should_flush = False
if max_duration > 0 and current_size + duration > max_duration:
should_flush = True
elif max_batch_size > 0 and len(current_batch) >= max_batch_size:
should_flush = True
if should_flush and len(current_batch) > 0:
self.batches.append( current_batch )
current_batch = []
current_size = 0
current_batch.append( current_index )
current_index += 1
current_size += duration
if self.shuffle:
random.shuffle(self.batches)
def __len__(self):
return len(self.batches)
def __iter__(self):
if self.position >= len(self.batches):
self.position = 0
if self.shuffle:
random.shuffle(self.batches)
while self.position < len(self.batches):
yield self.batches[self.position]
self.position += 1
def get_state(self):
return { "position": self.position, "batches": self.batches }
def set_state(self, state):
self.position = state["position"]
self.batches = state["batches"]
# Randomly samples indices from a given sequence from 0 to length
# Allows saving and loading state
class RandomSampler(Sampler):
def __init__( self, length ):
self.position = 0
self.length = length
self.generator = torch.Generator()
self.perm = torch.randperm(self.length, generator=self.generator)
def __len__(self):
return self.length
def __iter__(self):
if self.position >= self.length:
self.position = 0
self.perm = torch.randperm(self.length, generator=self.generator)
while self.position < self.length:
yield self.perm[self.position]
self.position += 1
def get_state(self):
return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }
def set_state(self, state):
self.position = state["position"]
self.length = state["length"]
self.perm = state["perm"]
self.generator.set_state(state["generator"])

View File

@ -4,12 +4,13 @@
import humanize
import json
import os
import logging
import numpy as np
import random
import selectors
import sys
import torch
import os
from functools import cache
from torch.distributed import broadcast_object_list
@ -18,9 +19,10 @@ from tqdm import tqdm
from typing import Protocol
from ..config import cfg
from .distributed import init_distributed, distributed_initialized
from .distributed import (
fix_unset_envs,
init_distributed,
distributed_initialized,
world_size,
global_leader_only,
global_rank,
is_global_leader,
@ -28,73 +30,15 @@ from .distributed import (
local_leader_only,
)
from ..engines import Engine, Engines, TrainFeeder, default_feeder
from ..models import get_models
from ..engines import Engine, Engines, TrainFeeder, default_feeder, load_engines
from .utils import to_device, do_gc
from .utils import to_device, do_gc, truncate_json
from ..utils import wrapper as ml
from ..data import get_symmap # should decouple from this trainer script
_logger = logging.getLogger(__name__)
_engines: Engines
_command: str
def get_global_step():
try:
return _engines.global_step
except:
return None
def get_micro_step():
try:
return _engines.micro_step
except:
return None
def get_cmd():
try:
return _command
except:
raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
get_iteration = get_global_step
def load_engines():
models = get_models(cfg.models.get())
engines = dict()
for name in models:
model = models[name]
optimizer = None
lr_scheduler = None
if cfg.hyperparameters.optimizer.lower() == "adamw":
optimizer = ml.AdamW(
model.parameters(),
lr=cfg.hyperparameters.learning_rate,
betas=(0.9, 0.96),
eps=1e-07,
weight_decay=0.01,
)
if cfg.trainer.load_state_dict:
load_path = cfg.ckpt_dir / name / "fp32.pth"
model.load_state_dict(torch.load(load_path))
engines[name] = Engine(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
engines = Engines(engines)
engines.setup()
if not cfg.trainer.load_state_dict:
engines.load_checkpoint()
return engines
class EvalFn(Protocol):
def __call__(self, *, engines: Engines):
@ -151,17 +95,16 @@ def _non_blocking_input():
l[0] = s
if distributed_initialized():
if world_size() > 1:
broadcast_object_list(l, src=0)
_command = l[0]
return _command
def _make_infinite_epochs(dl):
while True:
_logger.info("New epoch starts.")
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
#_logger.info("New epoch starts.")
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
@local_leader_only(default=None)
@ -172,30 +115,32 @@ def logger(data):
def seed(seed):
# Set up random seeds, after fork()
random.seed(seed + global_rank())
#np.random.seed(seed + global_rank())
np.random.seed(seed + global_rank())
torch.manual_seed(seed + global_rank())
def train(
train_dl: DataLoader,
train_feeder: TrainFeeder = default_feeder,
eval_fn: EvalFn = lambda x: ...,
logger: Logger = logger,
):
fix_unset_envs()
engines = load_engines()
# validate if there's at least one model to train
found = False
for name, engine in engines.items():
if engine.training:
found = True
break
if not found:
raise Exception('Training, but no model loaded set to train...')
"""
if is_local_leader():
cfg.dump()
_logger.info(cfg)
"""
# Setup global engines
global _engines
_engines = engines
events = []
eval_fn = global_leader_only(eval_fn)
@ -203,15 +148,20 @@ def train(
# Pre-loop command
command = _non_blocking_input()
if command in ["eval", "eval_quit"]:
engines.eval()
eval_fn(engines=engines)
engines.train()
if command in ["quit", "eval_quit"]:
engines.quit()
return
last_save_step = engines.global_step
last_eval_step = 0
"""
if cfg.distributed:
train_dl.sampler.set_epoch(int(engines.global_samples / len(train_dl.dataset.paths)))
"""
# Training loop
for batch in _make_infinite_epochs(train_dl):
if engines.global_step >= cfg.trainer.iterations:
@ -219,17 +169,15 @@ def train(
#batch = to_device(batch, torch.cuda.current_device())
stats = engines.step(batch=batch, feeder=train_feeder)
iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
stats['it'] = iteration
stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
del stats['batch_size']
del stats['wall_time']
del stats['global_step']
stats['epoch'] = engines.global_samples / (len(train_dl.dataset.paths) * world_size())
elapsed_time = stats.get("elapsed_time", 0)
_logger.info(f"Training Metrics: {json.dumps(stats)}.")
try:
metrics = json.dumps(stats)
except Exception as e:
metrics = str(stats)
_logger.info(f"Training Metrics: {truncate_json(metrics)}.")
command = _non_blocking_input()
@ -267,29 +215,48 @@ def train(
if "lr" in command:
rate = float(command.split(" ")[-1])
try:
engines.set_lr(rate)
print("Updating LR to:", rate)
_logger.info(f"Updating LR to: {rate}")
except Exception as e:
_logger.warning(f"Failed to set LR rate to: {rate}, {str(e)}")
if "export" in command:
train_dl.dataset.save_state_dict()
engines.save_checkpoint()
last_save_step = engines.global_step
if is_global_leader():
engines.export(userdata={"symmap": get_symmap()})
save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
saving_commands = ["save"]
export_commands = ["export"]
if cfg.trainer.save_on_quit:
saving_commands.append("quit")
if cfg.trainer.export_on_quit:
export_commands.append("quit")
if cfg.trainer.export_on_save:
export_commands.append("save")
if engines.global_step != last_save_step:
if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
train_dl.dataset.save_state_dict()
engines.save_checkpoint()
last_save_step = engines.global_step
if command in export_commands and is_global_leader():
engines.export(userdata={"symmap": get_symmap()})
if engines.global_step != last_eval_step:
if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
do_gc()
engines.eval()
eval_fn(engines=engines)
engines.train()
last_eval_step = engines.global_step
eval_fn(engines=engines)
if command in ["quit"]:
engines.quit()
return

View File

@ -7,8 +7,16 @@ from .distributed import global_rank, local_rank, global_leader_only
import gc
import logging
import pandas as pd
import numpy as np
import re
import torch
import random
import time
import psutil
import math
import logging
_logger = logging.getLogger(__name__)
from coloredlogs import ColoredFormatter
from logging import StreamHandler
@ -16,9 +24,16 @@ from pathlib import Path
from torch import Tensor, nn
from tqdm.auto import tqdm
from typing import Callable, TypeVar, overload
from contextlib import contextmanager
T = TypeVar("T")
def truncate_json( str ):
def fun( match ):
return "{:.4f}".format(float(match.group()))
return re.sub(r"\d+\.\d{8,}", fun, str)
def do_gc():
gc.collect()
torch.cuda.empty_cache()
@ -28,6 +43,14 @@ def flatten_dict(d):
return records[0] if records else {}
def set_seed(seed=None):
if not seed:
seed = int(time.time())
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def _get_named_modules(module, attrname):
for name, module in module.named_modules():
if hasattr(module, attrname):
@ -155,5 +178,363 @@ def tree_map(fn: Callable, x):
return x
def to_device(x: T, device) -> T:
return tree_map(lambda t: t.to(device), x)
def to_device(x: T | None, *args, **kwargs) -> T:
if x is None:
return
return tree_map(lambda t: t.to(*args, **kwargs), x)
def coalese( *arg, return_last=True ):
return [ x for x in arg if x is not None ][-1 if return_last else 0]
# checks if a module name is within a given whitelist/blacklist policy dict
def passes_policy( policy, name ):
if policy is None:
return True
if "exclude" in policy:
for term in policy["exclude"]:
if term in name:
return False
if "include" in policy:
for term in policy["include"]:
if term in name:
return True
return False
# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
@contextmanager
def autocast(input, from_dtype, to_dtype):
if input.dtype == from_dtype:
input = input.to(to_dtype)
yield input
input = input.to(from_dtype)
else:
yield input
@contextmanager
def autocasts(input, from_dtype, to_dtype):
if input.dtype in from_dtype:
from_dtype = input.dtype
input = input.to(to_dtype)
yield input
input = input.to(from_dtype)
else:
yield input
# handles temporarily upcasting 'index tensors' so torch will stop bitching
def autocast_forward( func ):
def wrapper( self, input, *args, **kwargs ):
with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
return func( self, k, *args, **kwargs )
return wrapper
# handles migrating an input tensor to a given devicve
def auto_align_inputs_forward( module, device=None, name = None ):
func = module.forward
if device is None:
if hasattr( module, 'device' ):
device = module.device
else:
try:
device = next(module.parameters() if [*module.parameters()] else module.buffers()).device
except Exception as e:
return func
def wrapper( *args, **kwargs ):
args = [*args]
# search through args and kwargs for any Tensor arguments
for i, arg in enumerate(args):
if not isinstance( arg, torch.Tensor ):
continue
args[i] = arg.to( device=device )
for k, v in kwargs.items():
if not isinstance( v, torch.Tensor ):
continue
kwargs[k] = v.to( device=device )
# disgusting patch
if "position_embeddings" in kwargs:
kwargs["position_embeddings"] = tuple([ t.to(device=device) for t in kwargs["position_embeddings"] ])
return func( *args, **kwargs )
return wrapper
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
# generalizing this would be super sugoi but the there's no catch all for arguments
def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ):
bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in modules:
name = '.'.join(parent)
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
kwargs = dict(
in_features = m.in_features,
out_features = m.out_features,
bias = m.bias is not None,
) if not bnb else dict(
input_features=m.in_features,
output_features=m.out_features,
bias=m.bias is not None,
)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
_logger.info(f"Replacing {name}.{k} to: {klass}")
return model
def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in modules:
name = '.'.join(parent)
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
kwargs = dict(
num_embeddings=m.num_embeddings,
embedding_dim=m.embedding_dim,
padding_idx=m.padding_idx,
max_norm=m.max_norm,
norm_type=m.norm_type,
scale_grad_by_freq=m.scale_grad_by_freq,
sparse=m.sparse,
)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
_logger.info(f"Replacing {name}.{k} to: {klass}")
return model
# cannot feasibly do default arguments here sad
def replace_attention( model, klass, target, mode="math", verbose=False ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in modules:
name = '.'.join(parent)
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
kwargs = dict(
config = m.config,
layer_idx = m.layer_idx,
mode = mode,
)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
_logger.info(f"Replacing {name}.{k} to: {klass}")
return model
# trim/expand a tensor (for example, in a state dict)
def resize_weight( weight, target, dim=0, random=True ):
# trim
if target < weight.shape[dim]:
return weight[:target]
# expand
if target > weight.shape[dim]:
fn = torch.rand if random else torch.zeros
return torch.stack(
[ x for x in weight ] +
[ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ]
)
return weight
def get_devices():
return [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu']
# grabs the memory properties of a given device
def get_device_properties( device ):
if 'cuda' in device:
props = torch.cuda.get_device_properties(device)
free, total = torch.cuda.mem_get_info(device)
else:
props = psutil.virtual_memory()
free, total = props.available, props.total
return {"name": device, "total": total, "free": free, "props": props}
# gets the rough size for a given module's parameters
def get_module_size( module ):
param_size = sum([p.nelement() * p.element_size() for p in module.parameters()])
buffer_size = sum([b.nelement() * b.element_size() for b in module.buffers()])
return param_size + buffer_size
# to-do: rewrite all this shit, I don't know what I was thinking when implementing it this way
# it'd be better to just attach to layers itself rather than every single module
# assigns modules to requested devices for a given policy
def get_model_offload_policy(module, policy=None):
# handle any other weird values this is set to
if not isinstance(policy, dict):
policy = {}
# default to only include the core model, and not the other modules (embeddings) in the splitting policy
if "include" not in policy:
policy["include"] = ["model"]
if "limits" not in policy:
policy["limits"] = []
if "assign" not in policy:
policy["assign"] = []
if "devices" not in policy:
policy["devices"] = get_devices() # + cpu to spill the remainder on CPU if overbudget
# create initial device info
devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ]
modules = [ (name, get_module_size(module)) for name, module in module.named_modules() if not [*module.named_children()] and passes_policy( policy, name ) ]
# filter
modules = [ (name, size) for name, size in modules if name and size ]
total_size = sum([size for name, size in modules])
# set caps if requested in the policy
for i, cap in enumerate(policy["limits"]):
# no limit, skip
if cap <= 0:
continue
# is fractional, scale to total size
if cap < 1:
cap = math.floor(total_size * cap)
# available space is below cap, don't set
if devices[i]["free"] < cap:
continue
# cap to requested size
devices[i]["free"] = cap
# assign if specific parts of the model are requested for assignment
if policy["assign"]:
discarded = []
# yuck, there has to be a better way
for device_index, includes in enumerate( policy["assign"] ):
device = devices[device_index]
buffered_modules = []
buffered_size = device["free"]
# iterate through list of modules to compare against includes
for name, size in modules:
# doesn't pass policy
if not passes_policy( {"include": includes}, name ):
continue
# check if within budget
if buffered_size - size >= 0:
# add to buffer
buffered_modules.append( (name, size) )
buffered_size -= size
# budget exceeded, flush buffer
else:
discarded += buffered_modules
buffered_modules = []
buffered_size = 0
break
if buffered_modules and buffered_size:
device["modules"] += [ name for name, size in buffered_modules ]
device["free"] = buffered_size
modules = discarded
device_index = 0
module_index = 0
# assign modules to each device
while module_index < len(modules) and device_index < len(devices):
device = devices[device_index]
name, size = modules[module_index]
# fits within budget
if device["free"] - size >= 0:
device["modules"].append( name )
device["free"] -= size
module_index += 1
# does not fit in budget, increase device index
else:
device_index += 1
_logger.info(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")
# to-do: check that all modules are exhausted
assert module_index >= len(modules)
# only return devices with modules assigned
return [ device for device in devices if device["modules"] ]
# handles naively splitting a model's layers across multiple devices
# this apparently works for training too? the test trainer seemed fine with it split between GPU and CPU
def offload_model( model, policy=None ):
policy = get_model_offload_policy(model, policy=policy)
# move modules to respective devices
for i, device in enumerate( policy ):
# nothing assigned, skip
if not device["modules"]:
continue
for name in device["modules"]:
module = model.get_submodule(name)
module = module.to( device["name"] )
module.device = device['name']
# wrap modules with forward to ensure all inputs are matched to its device
for name, module in model.named_modules():
if not hasattr( module, 'forward' ):
continue
module.forward = auto_align_inputs_forward(module)
"""
# Validate that the layers are all in the right spot
for name, module in model.named_modules():
if not not [*module.named_children()]:
continue
try:
_logger.info( name, next(module.parameters()).device )
except Exception as e:
_logger.info( name, "?" )
pass
"""
return model

View File

@ -1,20 +1,39 @@
from contextlib import contextmanager
import math
import torch
import torch.nn.functional as F
import logging
from ..config import cfg
_logger = logging.getLogger(__name__)
Embedding = torch.nn.Embedding
Linear = torch.nn.Linear
if cfg.bitsandbytes.enabled:
Adam = torch.optim.Adam
AdamW = torch.optim.AdamW
SGD = torch.optim.SGD
Adagrad = torch.optim.Adagrad
# https://github.com/kyegomez/BitNet
if cfg.optimizations.bitnet:
from bitnet import BitLinear
if cfg.optimizations.bitsandbytes:
import bitsandbytes as bnb
if cfg.bitsandbytes.linear:
if cfg.optimizations.linear:
if cfg.optimizations.bitnet:
Linear = BitLinear
else:
Linear = bnb.nn.Linear8bitLt
if cfg.bitsandbytes.embedding:
Embedding = bnb.nn.StableEmbedding
if cfg.optimizations.embedding:
Embedding = bnb.nn.modules.Embedding
"""
Embedding.forward = lambda self, input: ( self.norm(F.embedding(
input,
self.weight,
@ -24,52 +43,101 @@ if cfg.bitsandbytes.enabled:
self.scale_grad_by_freq,
self.sparse,
)).to(self.weight.dtype) )
"""
Adam = torch.optim.Adam
AdamW = torch.optim.AdamW
if cfg.optimizations.optimizers:
Adam = bnb.optim.Adam8bit
AdamW = bnb.optim.AdamW8bit
SGD = bnb.optim.SGD8bit
Adagrad = bnb.optim.Adagrad8bit
if cfg.bitsandbytes.enabled:
import bitsandbytes as bnb
elif cfg.optimizations.dadaptation:
import dadaptation
Adam = bnb.optim.Adam
AdamW = bnb.optim.AdamW
if cfg.optimizations.optimizers:
Adam = dadaptation.DAdaptAdam
AdamW = dadaptation.DAdaptAdam
SGD = dadaptation.DAdaptSGD
AdaGrad = dadaptation.DAdaptAdaGrad
# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
@contextmanager
def autocast(input, from_dtype, to_dtype):
if input.dtype == from_dtype:
input = input.to(to_dtype)
yield input
input = input.to(from_dtype)
else:
yield input
if cfg.optimizations.fp8:
import transformer_engine.pytorch as te
Linear = te.Linear
@contextmanager
def autocasts(input, from_dtype, to_dtype):
if input.dtype in from_dtype:
from_dtype = input.dtype
input = input.to(to_dtype)
yield input
input = input.to(from_dtype)
def autocast():
yield te.fp8_autocast(enabled=True)
else:
yield input
@contextmanager
def autocast():
yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)
# handles temporarily upcasting 'index tensors' so torch will stop bitching
def autocast_forward( func ):
def wrapper( self, input, *args, **kwargs ):
with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k:
return func( self, k, *args, **kwargs )
"""
if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8:
return func( self, input.to(torch.int32), *args, **kwargs )
return func( self, input, *args, **kwargs )
"""
return wrapper
Embedding.forward = autocast_forward(Embedding.forward)
if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
if cfg.optimizations.injects:
if cfg.optimizations.linear:
torch.nn.Linear = Linear
if cfg.optimizations.embedding:
torch.nn.Embedding = Embedding
if cfg.optimizations.optimizers:
torch.optim.Adam = Adam
torch.optim.AdamW = AdamW
torch.optim.SGD = SGD
AVAILABLE_COMPILE_BACKENDS = []
try:
AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
except Exception as e:
pass
if cfg.optimizations.tensorrt:
try:
import torch_tensorrt
AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
except Exception as e:
_logger.warning(f'Error while importing TensorRT: {str(e)}')
pass
def compile_model(model, backend="auto"):
if not backend or backend == "auto":
backend = AVAILABLE_COMPILE_BACKENDS[0]
if backend not in AVAILABLE_COMPILE_BACKENDS:
return torch.compile(model)
return torch.compile(model, backend=backend)
# https://github.com/konstmish/prodigy
try:
from prodigyopt import Prodigy
except Exception as e:
_logger.warning(f'Error while importing Prodigyopt: {str(e)}')
pass
# https://github.com/facebookresearch/schedule_free/
try:
import schedulefree
except Exception as e:
_logger.warning(f'Error while importing Schedule_Free: {str(e)}')
pass
# backwards compat
from .utils import (
autocast_forward,
replace_linear as replace_linear_old,
replace_embedding as replace_embedding_old,
replace_attention,
resize_weight,
offload_model,
)
# wrapped here so we can maintain default args
def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
return replace_linear_old( model, klass, target, verbose )
def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
return replace_embedding_old( model, klass, target, verbose )
Embedding.forward = autocast_forward(Embedding.forward)

220
image_classifier/webui.py Normal file
View File

@ -0,0 +1,220 @@
import os
import re
import argparse
import random
import tempfile
import functools
from datetime import datetime
import gradio as gr
from time import perf_counter
from pathlib import Path
from PIL import Image
from .inference import Classifier, cfg
from .train import train
from .utils import get_devices
classifier = None
layout = {}
layout["inference"] = {}
layout["training"] = {}
layout["settings"] = {}
for k in layout.keys():
layout[k]["inputs"] = { "progress": None }
layout[k]["outputs"] = {}
layout[k]["buttons"] = {}
# there's got to be a better way to go about this
def gradio_wrapper(inputs):
def decorated(fun):
@functools.wraps(fun)
def wrapped_function(*args, **kwargs):
for i, key in enumerate(inputs):
kwargs[key] = args[i]
try:
return fun(**kwargs)
except Exception as e:
raise gr.Error(str(e))
return wrapped_function
return decorated
class timer:
def __init__(self, msg="Elapsed time:"):
self.msg = msg
def __enter__(self):
self.start = perf_counter()
return self
def __exit__(self, type, value, traceback):
msg = f'{self.msg} {(perf_counter() - self.start):.3f}s'
gr.Info(msg)
print(f'[{datetime.now().isoformat()}] {msg}')
# returns a list of models, assuming the models are placed under ./training/ or ./models/
def get_model_paths( paths=[Path("./data/"), Path("./training/"), Path("./models/")] ):
yamls = []
for path in paths:
if not path.exists():
continue
for yaml in path.glob("**/*.yaml"):
if "/logs/" in str(yaml):
continue
yamls.append( yaml )
return yamls
def get_dtypes():
return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
#@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
def load_model( yaml, device, dtype ):
gr.Info(f"Loading: {yaml}")
try:
init_classifier( yaml=Path(yaml), restart=True, device=device, dtype=dtype )
except Exception as e:
raise gr.Error(e)
gr.Info(f"Loaded model")
def init_classifier(yaml=None, restart=False, device="cuda", dtype="auto"):
global classifier
if classifier is not None:
if not restart:
return classifier
del classifier
classifier = None
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--device", type=str, default=device)
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default=dtype)
args, unknown = parser.parse_known_args()
classifier = Classifier( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
return classifier
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
if not cfg.yaml_path:
raise Exception("No YAML loaded.")
parser = argparse.ArgumentParser(allow_abbrev=False)
# I'm very sure I can procedurally generate this list
parser.add_argument("--image", type=str, default=kwargs["image"])
parser.add_argument("--temp", type=float, default=kwargs["temp"])
args, unknown = parser.parse_known_args()
classifier = init_classifier()
args.image = Image.open(args.image).convert('RGB')
gr.Info("Inferencing...")
with timer("Inferenced in") as t:
answer = classifier.inference(
image=args.image,
temperature=args.temp,
)
return answer
# setup args
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
parser.add_argument("--share", action="store_true")
parser.add_argument("--render_markdown", action="store_true", default="CLASSIFIER_YAML" in os.environ)
args, unknown = parser.parse_known_args()
args.listen_host = None
args.listen_port = None
args.listen_path = None
if args.listen:
try:
match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]
args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
args.listen_port = match[1] if match[1] != "" else None
args.listen_path = match[2] if match[2] != "" else "/"
except Exception as e:
pass
if args.listen_port is not None:
args.listen_port = int(args.listen_port)
if args.listen_port == 0:
args.listen_port = None
# setup gradio
ui = gr.Blocks()
with ui:
with gr.Tab("Inference"):
with gr.Row():
with gr.Column(scale=4):
layout["inference"]["inputs"]["image"] = gr.Image(label="Input Image", sources=["upload"], type="filepath")
layout["inference"]["outputs"]["output"] = gr.Textbox(label="Output")
with gr.Column(scale=4):
with gr.Row():
layout["inference"]["inputs"]["temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature", info="Modifies the randomness from the samples. (0 to greedy sample)")
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
layout["inference"]["buttons"]["inference"].click(
fn=do_inference,
inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
)
"""
with gr.Tab("Training"):
with gr.Row():
with gr.Column(scale=1):
layout["training"]["outputs"]["console"] = gr.Textbox(lines=8, label="Console Log")
with gr.Row():
with gr.Column(scale=1):
layout["training"]["buttons"]["train"] = gr.Button(value="Train")
layout["training"]["buttons"]["train"].click(
fn=do_training,
outputs=[ x for x in layout["training"]["outputs"].values() if x is not None],
)
"""
with gr.Tab("Settings"):
with gr.Row():
with gr.Column(scale=7):
with gr.Row():
layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
with gr.Column(scale=1):
layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
layout["settings"]["buttons"]["load"].click(
fn=load_model,
inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
)
if os.path.exists("README.md") and args.render_markdown:
md = open("README.md", "r", encoding="utf-8").read()
# remove HF's metadata
if md.startswith("---\n"):
md = "".join(md.split("---")[2:])
gr.Markdown(md)
def start( lock=True ):
ui.queue(max_size=8)
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)
if __name__ == "__main__":
start()

0
scripts/run.sh Executable file → Normal file
View File

View File

@ -1,5 +1,5 @@
import subprocess
import sys
from pathlib import Path
from datetime import datetime
from setuptools import setup, find_packages
@ -8,7 +8,6 @@ def shell(*args):
out = subprocess.check_output(args)
return out.decode("ascii").strip()
def write_version(version_core, pre_release=True):
if pre_release:
time = shell("git", "log", "-1", "--format=%cd", "--date=iso")
@ -23,8 +22,7 @@ def write_version(version_core, pre_release=True):
return version
with open("README.md", "r", encoding="utf-8") as f:
with open("README.md", "r") as f:
long_description = f.read()
setup(
@ -37,17 +35,37 @@ setup(
long_description=long_description,
long_description_content_type="text/markdown",
packages=find_packages(),
install_requires=[
install_requires=(
# training backends
["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else [])
+ [
# logging niceties
"coloredlogs>=15.0.1",
"humanize>=4.4.0",
"matplotlib>=3.6.0",
"pandas>=1.5.0",
# boiler plate niceties
"diskcache>=5.4.0",
"einops>=0.6.0",
"omegaconf==2.0.6",
"tqdm>=4.64.1",
"humanize>=4.4.0",
"tqdm",
"pandas>=1.5.0",
# HF bloat
"tokenizers",
"transformers",
"safetensors",
# training bloat
"h5py",
"prodigyopt @ git+https://github.com/konstmish/prodigy",
# practically the reason to use python
"numpy",
"torch>=1.13.0",
"torchmetrics",
"simple_http_server",
"pillow"
],
url="https://git.ecker.tech/mrq/resnet-classifier",
)