diff --git a/.gitignore b/.gitignore
old mode 100755
new mode 100644
diff --git a/LICENSE b/LICENSE
old mode 100755
new mode 100644
diff --git a/README.md b/README.md
old mode 100755
new mode 100644
index 347e629..a744721
--- a/README.md
+++ b/README.md
@@ -1,10 +1,10 @@
# Tentative Title For A ResNet-Based Image Classifier
-This is a simple ResNet based image classifier for """specific images""", using a similar training framework I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/).
+This is a simple ResNet-based image classifier for images, using a training framework similar to the one I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/).
## Premise
-This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
+This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I faced.
This is by no means state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
@@ -16,44 +16,14 @@ This is by no means state of the art, as it just leverages an existing ResNet arc
3. Install using `pip3 install -e ./image_classifier/`.
-4. Train using `python3 -m image_classifier.train yaml='./data/config.yaml'`.
+4. Train using `python3 -m image_classifier.train --yaml='./data/config.yaml'`.
5. Wait.
## Inferencing
-Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
+Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" --yaml="./data/config.yaml"`
### Continuous Usage
-If you're looking to continuously classifier trained images, use `python3 -m image_classifier --listen --port=7860 yaml="./data/config.yaml" --temp=1.0` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
-
-## Known Issues
-
-* Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
-* Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
-* Web server doesn't emit `content-type: application/json`, nor accepts JSON `POST`s at the moment.
-
-## Strawmen
-
->\> UGH... Why *another* training framework!!! Just subjugate [DLAS](https://git.ecker.tech/mrq/DL-Art-School) even more!!!
-
-I want my own code to own. The original VALL-E implementation had a rather nice and clean setup that *mostly* just made sense. DLAS was a nightmare to comb through for the gorillion amounts of models it attests.
-
->\> OK. But how do I use it for `[thing that isn't the specific usecase only I know/care about]`
-
-Simply provide your own symmapping under `./image_classifier/data.py`, and, be sure to set the delimiter (where exactly is an exercise left to the reader).
-
-Because this is for a ***very specific*** use-case. I don't really care right now to make this a *little* more generalized, despite most of the bits and bobs for it to generalize being there.
-
->\> ur `[a slur]` for using a ResNet... why not use `[CRNN / some other meme arch]`??
-
-I don't care, I'd rather keep the copypasting from other people's code to a minimum. Lazily adapting my phoneme tokenizer from my VALL-E implementation into something practically fixed length by introducing start/stop tokens should be grounds for me to use a CRNN, or anything recurrent at the very least, but again, I don't care, it just works for my use case at the moment.
-
->\> UGH!!! What are you talking about """specific images"""???
-
-[ひみつ](https://files.catbox.moe/csuh49.webm)
-
->\> NOOOO!!!! WHY AREN'T YOU USING `[cuck license]`???
-
-:)
\ No newline at end of file
+If you're looking to continuously classify images, use `python3 -m image_classifier --listen --port=7860 --yaml="./data/config.yaml"` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
\ No newline at end of file
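For reference, a minimal client for the endpoint described above might look like this (a sketch assuming the server is running locally on the default port; the base64 payload is URL-quoted since it may contain `+` and `/`):

```python
import base64
import urllib.parse
import urllib.request

# read and base64-encode the image to classify
with open("./data/path-to-your-image.png", "rb") as f:
    b64 = base64.b64encode(f.read()).decode("utf-8")

# query the classifier's webserver; the reply is the JSON response described above
url = f"http://127.0.0.1:7860/?b64={urllib.parse.quote(b64)}"
with urllib.request.urlopen(url) as response:
    print(response.read().decode("utf-8"))  # e.g. {"answer": "..."}
```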
diff --git a/data/config.yaml b/data/config.yaml
old mode 100755
new mode 100644
index 5e51278..3bebbe0
--- a/data/config.yaml
+++ b/data/config.yaml
@@ -1,85 +1,84 @@
-dataset:
- training: [
- "./data/images/"
- ]
-
- validation: []
-
- use_hdf5: False
-
- workers: 0
- cache: True
+weights_format: sft
models:
- _models:
- - name: "classifier"
- tokens: 0
- len: 6
+- name: "classifier"
+ tokens: 0
+ len: 6
+ dim: 512
+ resnet: 34
+#loras:
+#- name : "lora"
+# rank: 128
+# alpha: 128
+# training: True
+# rvq_levels: []
hyperparameters:
batch_size: 256
- gradient_accumulation_steps: 64
- gradient_clipping: 100
+ gradient_accumulation_steps: 1
+ gradient_clipping: 1.0
+ warmup_steps: 10
+
+ optimizer: Prodigy
+ learning_rate: 1.0
+ torch_optimizer: True
- optimizer: Adamw
- learning_rate: 1.0e-3
-
- scheduler_type: ""
- #scheduler_type: OneCycle
- #scheduler_params:
- # cycle_first_step_size: 10_000
- # cycle_first_stair_count: 10_000
-
- # cycle_second_step_size: 15_000
- # cycle_second_stair_count: 15_000
-
- # decay_step_size: 5_000
-
- # cycle_min_lr: 2.5e-4 # 1.0e-5
- # cycle_max_lr: 2.5e-4 # 1.0e-4
- # decay_lr_rate: 0.0
-
- # cycle_min_mom: 0.90
- # cycle_max_mom: 0.99
- # decay_mom_rate: 0.0
+ scheduler: "" # ScheduleFree
+ torch_scheduler: True
evaluation:
- batch_size: 32
- frequency: 250
- size: 32
+ batch_size: 64
+ frequency: 100
+ size: 64
- steps: 300
- temperature: 1.0
+ steps: 450
+ temperature: 0.0
trainer:
- iterations: 100_000
-
- save_tag: step
- save_on_oom: True
- save_on_quit: True
+ iterations: 1_000_000
save_frequency: 100
-
- aggressive_optimizations: False
-
- check_for_oom: False
-
- #load_tag: "9500"
- #load_state_dict: True
- #load_states: False
- #strict_loading: False
- #restart_step_count: True
+ keep_last_checkpoints: 32
- gc_mode: None # "global_step"
+ check_for_oom: False
+ gradient_checkpointing: True
- weight_dtype: float32
+ weight_dtype: bfloat16
+ amp: True
- backend: local
+ backend: deepspeed
deepspeed:
- zero_optimization_level: 0
- use_compression_training: True
+ inferencing: False
+ amp: False
inference:
- use_vocos: True
+ backend: local
-bitsandbytes:
- enabled: false
\ No newline at end of file
+ weight_dtype: bfloat16
+ amp: True
+
+optimizations:
+ injects: False
+ replace: True
+
+ linear: False
+ embedding: False
+ optimizers: True
+
+ bitsandbytes: False
+ dadaptation: False
+ bitnet: False
+ fp8: False
+
+dataset:
+ use_hdf5: True
+ hdf5_flag: r
+
+ workers: 1
+ cache: True
+
+ training: [
+ "./data/images/"
+ ]
+ validation: [
+ "./data/validation/"
+ ]
\ No newline at end of file
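A note on the `optimizer: Prodigy` / `learning_rate: 1.0` pairing above: Prodigy estimates its own step size, so the configured learning rate acts as a multiplier on that estimate and is conventionally left at 1.0. A minimal sketch, assuming the `prodigyopt` package (which `torch_optimizer: True` flags as torch-derived rather than DeepSpeed-supplied):

```python
import torch
from prodigyopt import Prodigy

model = torch.nn.Linear(512, 23)  # stand-in for the actual classifier
optimizer = Prodigy(model.parameters(), lr=1.0)  # lr scales the adaptive estimate
```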
diff --git a/image_classifier/__main__.py b/image_classifier/__main__.py
index d9907fe..3916455 100755
--- a/image_classifier/__main__.py
+++ b/image_classifier/__main__.py
@@ -12,35 +12,55 @@ def main():
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--listen", action='store_true')
parser.add_argument("--port", type=int, default=9090)
+
parser.add_argument("--yaml", type=Path, default=None)
- parser.add_argument("--ckpt", type=Path, default=None)
- parser.add_argument("--temp", type=float, default=1.0)
- parser.add_argument("--device", default="cuda")
+ parser.add_argument("--device", type=str, default=None)
+ parser.add_argument("--amp", action="store_true")
+ parser.add_argument("--dtype", type=str, default=None)
+
+ parser.add_argument("--temp", type=float, default=0.0)
+
args, unknown = parser.parse_known_args()
- classifier = Classifier( config=args.yaml, ckpt=args.ckpt, device=args.device )
+ classifier = Classifier( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
if args.listen:
@route("/")
- def inference( b64, temperature=1.0 ):
+ def inference( b64, temperature=args.temp ):
image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
- return { "answer": classifier.inference( image=image, temperature=args.temp ) }
+ return { "answer": classifier.inference( image=image, temperature=temperature ) }
server.start(port=args.port)
else:
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--path", type=Path)
parser.add_argument("--base64", type=str)
+ parser.add_argument("--write", type=Path)
parser.add_argument("--temp", type=float, default=1.0)
- args, unknown = parser.parse_known_args()
+ args, unknown = parser.parse_known_args()
+ images = []
if args.path:
- image = Image.open(args.path).convert('RGB')
+ if args.path.is_dir():
+ for p in args.path.rglob("./*.jpg"):
+ image = Image.open(p).convert('RGB')
+ images.append(image)
+ for p in args.path.rglob("./*.png"):
+ image = Image.open(p).convert('RGB')
+ images.append(image)
+ else:
+ image = Image.open(args.path).convert('RGB')
+ images.append(image)
elif args.base64:
image = Image.open(BytesIO(base64.b64decode(args.base64))).convert("RGB")
+ images.append(image)
else:
raise "Specify a --path or --base64."
- answer = classifier.inference( image=image, temperature=args.temp )
- print("Answer:", answer)
+ for image in images:
+ answer = classifier.inference( image=image, temperature=args.temp )
+ print("Answer:", answer)
+ if args.write:
+ args.write.mkdir(exist_ok=True)
+ image.save( args.write / f"{answer}.jpg")
if __name__ == "__main__":
main()
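The same flow can also be driven programmatically. A minimal sketch using only the calls exercised in `__main__.py` above; the `Classifier` import path is an assumption, since only the symbol itself appears in this diff:

```python
from pathlib import Path
from PIL import Image
from image_classifier import Classifier  # assumed import location

# mirror the CLI defaults: device/dtype are resolved by the classifier itself
classifier = Classifier(config=Path("./data/config.yaml"), device=None, dtype=None, amp=False)
image = Image.open("./data/path-to-your-image.png").convert("RGB")
print("Answer:", classifier.inference(image=image, temperature=0.0))
```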
diff --git a/image_classifier/config.py b/image_classifier/config.py
index e39a22e..8adfef8 100755
--- a/image_classifier/config.py
+++ b/image_classifier/config.py
@@ -6,31 +6,61 @@ import os
import subprocess
import sys
import time
-
-from dataclasses import asdict, dataclass
-from dataclasses import dataclass, field
-
-from functools import cached_property, cache
-from pathlib import Path
-from omegaconf import OmegaConf
+import argparse
+import yaml
+import random
+import logging
import torch
+import numpy as np
+
+from dataclasses import asdict, dataclass, field
+
+from functools import cached_property
+from pathlib import Path
+
+from .utils.distributed import world_size
+
+
+def set_seed(seed=None):
+ if not seed:
+ seed = int(time.time()) # np.random.seed requires an int seed
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
@dataclass()
-class _Config:
- cfg_path: str | None = None
+class BaseConfig:
+ yaml_path: str | None = None # path passed in through --yaml
@property
- def relpath(self):
+ def cfg_path(self):
+ return Path(self.yaml_path.parent) if self.yaml_path is not None else None
+
+ @property
+ def rel_path(self):
return Path(self.cfg_path)
+ @property
+ def cache_dir(self):
+ return self.rel_path / ".cache"
+
+ @property
+ def data_dir(self):
+ return self.rel_path / "data"
+
+ @property
+ def metadata_dir(self):
+ return self.rel_path / "metadata"
+
@property
def ckpt_dir(self):
- return self.relpath / "ckpt"
+ return self.rel_path / "ckpt"
@property
def log_dir(self):
- return self.relpath / "logs" / str(self.start_time)
+ return self.rel_path / "logs" / str(self.start_time)
@cached_property
def start_time(self):
@@ -64,39 +94,28 @@ class _Config:
with open(path, "w") as f:
f.write(self.dumps())
- @staticmethod
- def _is_cfg_argv(s):
- return "=" in s and "--" not in s
-
@classmethod
def from_yaml( cls, yaml_path ):
- return cls.from_cli( [f'yaml="{yaml_path}"'] )
+ state = {}
+ state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
+ state.setdefault("yaml_path", yaml_path)
+ return cls(**state)
@classmethod
def from_cli(cls, args=sys.argv):
- cli_cfg = OmegaConf.from_cli([s for s in args if cls._is_cfg_argv(s)])
+ # legacy support for yaml=`` format
+ for i, arg in enumerate(args):
+ if arg.startswith("yaml"):
+ args[i] = f'--{arg}'
- # Replace argv to ensure there are no omegaconf options, for compatibility with argparse.
- sys.argv = [s for s in sys.argv if not cls._is_cfg_argv(s)]
+ parser = argparse.ArgumentParser(allow_abbrev=False)
+ parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
+ args, unknown = parser.parse_known_args(args=args)
- if cli_cfg.get("help"):
- print(f"Configurable hyperparameters with their default values:")
- print(json.dumps(asdict(cls()), indent=2, default=str))
- exit()
+ if args.yaml:
+ return cls.from_yaml( args.yaml )
- if "yaml" in cli_cfg:
- yaml_cfg = OmegaConf.load(cli_cfg.yaml)
- yaml_path = Path(cli_cfg.yaml).absolute()
- cfg_path = Path(*yaml_path.relative_to(Path.cwd()).parts[:-1])
- cfg_path = cfg_path.with_suffix("")
- cfg_path = f'./{cfg_path}'
-
- yaml_cfg.setdefault("cfg_path", cfg_path)
- cli_cfg.pop("yaml")
- else:
- yaml_cfg = {}
- merged = OmegaConf.merge(yaml_cfg, cli_cfg)
- return cls(**dict(merged))
+ return cls(**{})
def __repr__(self):
return str(self)
@@ -106,104 +125,195 @@ class _Config:
@dataclass()
class Dataset:
- training: list[Path] = field(default_factory=lambda: [])
- validation: list[Path] = field(default_factory=lambda: [])
-
- temp: list[Path] = field(default_factory=lambda: [])
-
- # de-implemented, because the data isn't that large to facilitate HDF5
- hdf5_name: str = "data.h5"
- use_hdf5: bool = False
+ training: list[Path] = field(default_factory=lambda: []) # paths to load into the training dataset
+ validation: list[Path] = field(default_factory=lambda: []) # paths to load into the validation dataset
- workers: int = 8
- cache: bool = True
+ hdf5_name: str = "data.h5" # file name to load the HDF5 dataset
+ use_hdf5: bool = False # whether to load from an HDF5 dataset
+ hdf5_flag: str = "a" # flag to load the HDF5 file, automatically adjusted anyways
+
+ validate: bool = True # validate each utterance on whether it can be included based on duration range caps
+ workers: int = 8 # number of dataloader workers to spawn
+ cache: bool = True # use diskcache to cache the dataset
+# I really need to clean this up
@dataclass()
class Model:
- name: str = ""
+ name: str = "classifier"
tokens: int = 0 # number of token types
len: int = 1 # how long a sequence can be
dim: int = 512
+ resnet: int = 18
+
+ width: int = 300
+ height: int = 80
+
+ version: int = 1
+ training: bool = True
+ frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
@property
def full_name(self):
return self.name
-@dataclass()
-class Models:
- _models: list[Model] = field(default_factory=lambda: [
- Model(name="captcha"),
- ])
-
def get(self, name=None):
- if not name:
- return [ Model(**model) for model in self._models ]
+ return [ self ] if not name or self.name == name else []
+
+ def loss_factor(self, k):
+ return self.loss_factors[k] if k in self.loss_factors else 1.0
- for model in self._models:
- if model.name == name:
- return model
+ @property
+ # required for fp8 as the lengths need to be divisible by 8
+ def input_alignment(self):
+ return 8 if cfg.optimizations.fp8 else 0
- raise ValueError
+ @property
+ def activation_checkpointing(self):
+ return cfg.trainer.activation_checkpointing
+
+ @property
+ def gradient_checkpointing(self):
+ return cfg.trainer.gradient_checkpointing
+ @property
+ def lora_policy(self):
+ include = ["model"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
+ exclude = []
+
+ if self.arch_type == "llama":
+ include = ["self_attn", "mlp"] # target only the attention + mlp
+ exclude = ["self_attn.k_proj"] # common literature says to ignore it
+ if self.arch_type == "retnet":
+ include = ["layers."] # target the core layers of the RetNet and ignore the auxiliary stuff
+ exclude = ["retention.k_proj"] # attention-based transformers ignore the K, so might as well ignore it for the retnet
+
+ return dict(include=include, exclude=exclude)
+
+# should be renamed to Adapters
+@dataclass()
+class LoRA:
+ name: str = "lora" # vanity name
+ # to-do: find sane default values
+ rank: int = 128 # rank for the LoRA
+ alpha: int = 128 # alpha for the LoRA
+ training: bool = True #
+ embeddings: bool = False # train the embedding too
+ parametrize: bool = False # whether to use the parameterized pathway for LoRAs or not
+ rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
+
+ @property
+ def full_name(self):
+ name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
+ return "-".join(name)
+
+ # actually not needed anymore
+ def active_level( self, level ):
+ if not self.rvq_levels:
+ return True
+ return level in self.rvq_levels
+
@dataclass()
class Hyperparameters:
- batch_size: int = 8
- gradient_accumulation_steps: int = 32
- gradient_clipping: int = 100 # to be implemented in the local backend
+ batch_size: int = 8 # number of samples per training batch
+ gradient_accumulation_steps: int = 32 # number of steps to accumulate gradients before updating
+ gradient_clipping: int | float = 10 # largest size a gradient norm can be
- optimizer: str = "Adamw"
- learning_rate: float = 3.25e-4
+ optimizer: str = "Adamw" # optimizer to use, should be 'Prodigyopt" now
+ optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
+
+ learning_rate: float = 3.25e-4 # should be 1.0 for ProdigyOpt
+ warmup_steps: int = 0 # number of steps to warm up the optimizer before performing updates (I think; this is just passed through to DeepSpeed)
- scheduler_type: str = "" # to be implemented in the local backend
- scheduler_params: dict = field(default_factory=lambda: {})
+ scheduler: str = "" # scheduler to use, currently don't ever use one so this doesn't really matter
+ scheduler_type: str = "" # deprecated
+ scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
+
+ autotune: bool = False # to do deepspeed's autotuning
+ autotune_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
+
+ torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed supplied
+ torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied
@dataclass()
class Evaluation:
- batch_size: int = 64
- frequency: int = 250
- size: int = 64
-
+ batch_size: int = 64 # number of samples per batch during eval / val
+ frequency: int = 250 # do eval / val every X iterations
+ size: int = 64 # number of samples to generate during eval / val
+
steps: int = 500
- temperature: float = 1.0
+ temperature: float = 1.0 # AR temp for inferencing
+
+ load_disabled_engines: bool = True # see the other load_disabled_engines
@dataclass()
class DeepSpeed:
- zero_optimization_level: int = 0
- use_compression_training: bool = False
+ zero_optimization_level: int = 0 # doesn't seem to work
+ use_compression_training: bool = False # cope
+ compression_bits: int = 8 # cope
+ inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
+
+ amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)
- def get_ds_cfg(self, model):
- weights = [ name[0] for name in model.named_parameters() ]
- bits = 8
+ config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
- scheduler_params = {}
- for k in cfg.hyperparameters.scheduler_params:
- scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
+ @cached_property
+ def ds_cfg(self):
+ optimizer_params = cfg.hyperparameters.optimizer_params
+
+ if 'lr' not in optimizer_params:
+ optimizer_params["lr"] = cfg.hyperparameters.learning_rate,
- if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
+ scheduler_params = cfg.hyperparameters.scheduler_params
+ if 'warmup_num_steps' not in scheduler_params:
+ scheduler_params['warmup_num_steps'] = cfg.hyperparameters.warmup_steps
+
+ if 'total_num_steps' not in scheduler_params:
scheduler_params['total_num_steps'] = cfg.trainer.iterations
+ autotune_params = cfg.hyperparameters.autotune_params
+
+ if "enabled" not in autotune_params:
+ autotune_params['enabled'] = True
+
+ if "results_dir" not in autotune_params:
+ autotune_params['results_dir'] = str( cfg.rel_path / "autotune" / "results" )
+
+ if "exps_dir" not in autotune_params:
+ autotune_params['exps_dir'] = str( cfg.rel_path / "autotune" / "exps_" )
+
+ # DeepSpeed fp16 is incompatible with its AMP
+ if cfg.trainer.weight_dtype.lower() == "float16":
+ self.amp = False
+
+ # disable local AMP
+ if self.amp:
+ cfg.trainer.amp = False
+
ds_cfg = {
"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
"optimizer": {
"type": cfg.hyperparameters.optimizer,
- "params": {
- "lr": cfg.hyperparameters.learning_rate,
- }
- },
+ "params": optimizer_params,
+ } if not cfg.hyperparameters.torch_optimizer else None,
"scheduler": {
- "type": cfg.hyperparameters.scheduler_type,
+ "type": cfg.hyperparameters.scheduler,
"params": scheduler_params,
- } if cfg.hyperparameters.scheduler_type != "" else None,
+ } if not cfg.hyperparameters.torch_scheduler else None,
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
"fp16": {
- "enabled": True,
- "auto_cast": True,
- } if cfg.trainer.weight_dtype.lower() == "float16" else None,
- "bf16": {
- "enabled": cfg.trainer.weight_dtype.lower() == "bfloat16"
+ "enabled": cfg.trainer.weight_dtype.lower() == "float16",
+ "auto_cast": True, # ???
+ "loss_scale": 0.0 if cfg.trainer.scale_loss else 1.0,
},
+ "bf16": {
+ "enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
+ },
+ "amp": {
+ "enabled": self.amp,
+ },
+ "autotuning": autotune_params if cfg.hyperparameters.autotune else None,
"compression_training": {
"weight_quantization": {
"shared_parameters":{
@@ -214,7 +324,7 @@ class DeepSpeed:
"quantize_verbose": True,
"quantization_type": "symmetric",
"rounding": "nearest",
- "quantize_weight_in_forward": True,
+ "quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
"fp16_mixed_quantize":{
"enabled": False,
"quantize_change_ratio": 1
@@ -223,30 +333,38 @@ class DeepSpeed:
"different_groups": {
"wq1": {
"params": {
- "start_bits": bits,
- "target_bits": bits,
+ "start_bits": self.compression_bits,
+ "target_bits": self.compression_bits,
"quantization_period": 0
},
- "modules": weights
+ "modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
}
}
},
"activation_quantization": {
"shared_parameters":{
"enabled": True,
+ "quantizer_kernel": True,
+ "schedule_offset": 0,
+ "quantize_groups": 64,
+ "quantize_verbose": True,
"quantization_type": "symmetric",
- "range_calibration": "dynamic",
- "schedule_offset": 0
+ "rounding": "nearest",
+ "quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
+ "fp16_mixed_quantize":{
+ "enabled": False,
+ "quantize_change_ratio": 1
+ }
},
"different_groups": {
"aq1": {
"params": {
- "bits": bits
+ "bits": self.compression_bits,
},
- "modules": weights
+ "modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
}
}
- }
+ },
} if self.use_compression_training else None,
"zero_optimization": {
"stage": self.zero_optimization_level,
@@ -264,7 +382,10 @@ class DeepSpeed:
"offload_param": {
"device": "cpu",
"pin_memory": True
- }
+ },
+ "zero_quantized_weights": self.use_compression_training,
+ "zero_hpz_partition_size": world_size(),
+ "zero_quantized_gradients": self.use_compression_training,
} if self.zero_optimization_level > 0 else None,
"comms_logger": {
"enabled": False
@@ -275,113 +396,314 @@ class DeepSpeed:
for k in null_keys:
del ds_cfg[k]
- if os.path.exists("./config/ds_config.json"):
- ds_cfg.update(json.load(open("./config/ds_config.json", "r", encoding="utf-8")))
+ if os.path.exists("./data/ds_config.json"):
+ ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
+ else:
+ ds_cfg.update(self.config)
return ds_cfg
@dataclass()
class Trainer:
- iterations: int = 100_000
+ iterations: int = 1_000_000 # maximum iterations to train
- save_tag: str = "step"
- load_tag: str | None = None
+ save_tag: str = "step" # name to save checkpoints under, "step" will save as current step count
+ load_tag: str | None = None # tag to load checkpoint from; if None: will check against contents of `./ckpt/{model-name}/latest` for the checkpoint name
- save_on_oom: bool = True
- save_on_quit: bool = True
- save_frequency: int = 100
+ save_on_oom: bool = True # save if an OOM error is raised
+ save_on_quit: bool = True # save when quitting training
+
+ export_on_save: bool = False # export weights to local `fp32.pth` state_dict on saving a checkpoint
+ export_on_quit: bool = False # export weights to local `fp32.pth` state_dict on quitting training
+
+ save_frequency: int = 100 # frequency to save every X iterations
- load_state_dict: bool = False
- load_states: bool = True
- strict_loading: bool = True
- restart_step_count: bool = False
+ keep_last_checkpoints: int = 0 # number of checkpoints to keep, prunes oldest ones
- aggressive_optimizations: bool = False
- check_for_oom: bool = True
+ load_state_dict: bool = False # loads `fp32.pth` state_dict, will automatically be done if a checkpoint is not found but `fp32.pth` exists
+ load_states: bool = True #
+ strict_loading: bool = False # whether to enforce strict key matching when loading the state dict
+ load_module_only: bool = False #
+ restart_step_count: bool = False # clears the training stats when loading a checkpoint
+ resize_modules: bool = False # automatically resizes mismatched modules when loading weights
- gc_mode: str | None = None
+ activation_checkpointing: bool | None = None # deprecated; should technically checkpoint only the activations and not the entire gradients, but HF only has gradient checkpointing
+ gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
- weight_dtype: str = "float16"
+ aggressive_optimizations: bool = False # deprecated
+ check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
+ gc_mode: str | None = None # deprecated, but marks when to do GC
+ load_disabled_engines: bool = False # deprecated, but signals to load engines not used for training, for example for evaluation/validation
- backend: str = "deepspeed"
+ weight_dtype: str = "float16" # dtype to have the model under
- deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
+ amp: bool = False # automatic mixed precision
+ ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
+ #scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload)
+
+ load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
+ no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when the logger broke because of BitNet
+
+ backend: str = "local" # training backend to use. currently supports "local" | "deepspeed"
+ deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings
@cached_property
def dtype(self):
if self.weight_dtype == "float16":
return torch.float16
- if cfg.trainer.weight_dtype == "bfloat16":
+ if self.weight_dtype == "bfloat16":
return torch.bfloat16
+ if self.weight_dtype == "float8_e5m2":
+ return torch.float8_e5m2
+ if self.weight_dtype == "float8_e4m3fn":
+ return torch.float8_e4m3fn
return torch.float32
+ @cached_property
+ def scale_loss(self):
+ # currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
+ return self.dtype == torch.float16
@dataclass()
class Inference:
- use_vocos: bool = True # artifact from the VALL-E trainer
+ backend: str = "local" # backend to use when inferencing
+ weight_dtype: str = "float32" # dtype to load the model under
+ amp: bool = False # automatic mixed precision during inferencing
+
+ normalize: bool = False # do NOT enable this unless you know exactly what you're doing
+
+ @property
+ def dtype(self):
+ if self.weight_dtype == "float16":
+ return torch.float16
+ if self.weight_dtype == "bfloat16":
+ return torch.bfloat16
+ if self.weight_dtype == "int8":
+ return torch.int8
+ if self.weight_dtype == "float8_e5m2":
+ return torch.float8_e5m2
+ if self.weight_dtype == "float8_e4m3fn":
+ return torch.float8_e4m3fn
+ return torch.float32
@dataclass()
-class BitsAndBytes:
- enabled: bool = False
- injects: bool = False
+class Optimizations:
+ injects: bool = False # overwrites default torch classes (not recommended)
+ replace: bool = False # replaces modules in place with the optimized version (recommended)
+ compile: bool | str = False # runs torch.compile on the model
- linear: bool = False
- embedding: bool = False
+ linear: bool = True # inject/replace linear for BnB
+ embedding: bool = True # inject/replace embedding for BnB
+ optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
+
+ bitsandbytes: bool = False # use bitsandbytes
+ dadaptation: bool = False # use dadaptation optimizer
+ bitnet: bool = False # use bitnet
+ fp8: bool = False # use fp8
+
+ model_offloading: dict | None = None # automatically splits the model over a list of devices
+ # example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} will have the GPU capped to 6GiB, and offload the remaining layers to CPU
+ # example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
+ # | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2
+
+ tensorrt: bool = False
@dataclass()
-class Config(_Config):
- device: str = "cuda"
+class Config(BaseConfig):
+ device: str = "cuda" # target device
+ mode: str = "training" # "inferencing"
+ experimental: bool = False # Debug flag, unused now
dataset: Dataset = field(default_factory=lambda: Dataset)
- models: Models = field(default_factory=lambda: Models)
+ models: dict | list | None = field(default_factory=lambda: [])
+ loras: dict | list | None = field(default_factory=lambda: [])
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
trainer: Trainer = field(default_factory=lambda: Trainer)
inference: Inference = field(default_factory=lambda: Inference)
- bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
+ bitsandbytes: dict | list | None = None # deprecated
+ optimizations: Optimizations = field(default_factory=lambda: Optimizations)
- def get_device(self):
- return torch.cuda.current_device() if self.device == "cuda" else self.device
+ tokenizer: str | None = None # tokenizer class
+ tokenizer_path: str = "./tokenizer.json" # tokenizer path
+
+ weights_format: str = "pth" # "pth" | "sft"
+ supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
@property
- def cache_dir(self):
- return ".cache" / self.relpath
+ def model(self):
+ for i, model in enumerate(self.models):
+ if model.training:
+ return model
+
+ return self.models[0] if len(self.models) > 0 else None
+
+ # should be renamed to adapters
+ @property
+ def lora(self):
+ for i, lora in enumerate(self.loras):
+ if lora.training:
+ return lora
+
+ return self.loras[0] if len(self.loras) > 0 else None
+
+ @property
+ def distributed(self):
+ return world_size() > 1
@cached_property
def diskcache(self):
- if self.dataset.cache:
+ if self.yaml_path is not None and self.dataset.cache:
return diskcache.Cache(self.cache_dir).memoize
return lambda: lambda x: x
+ # I don't remember why this is needed
def load_yaml( self, config_path ):
tmp = Config.from_yaml( config_path )
self.__dict__.update(tmp.__dict__)
+ def load_hdf5( self, write=False ):
+ if hasattr(self, 'hdf5'):
+ self.hdf5.close()
+
+ if self.distributed:
+ self.dataset.hdf5_flag = "r"
+ try:
+ self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
+ except Exception as e:
+ _logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}")
+ self.dataset.use_hdf5 = False
+
+ # to-do: prune unused keys
+ def format( self, training=True ):
+ if isinstance(self.dataset, type):
+ self.dataset = dict()
+
+ if isinstance(self.models, type):
+ self.models = dict()
+
+ if isinstance(self.loras, type):
+ self.loras = dict()
+
+ if isinstance(self.hyperparameters, type):
+ self.hyperparameters = dict()
+
+ if isinstance(self.evaluation, type):
+ self.evaluation = dict()
+
+ if isinstance(self.trainer, type):
+ self.trainer = dict()
+
+ if isinstance(self.inference, type):
+ self.inference = dict()
+
+ if isinstance(self.optimizations, type):
+ self.optimizations = dict()
+
+ self.dataset = Dataset(**self.dataset)
+ self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
+ self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
+
+ self.models = [ Model(**model) for model in self.models ]
+ self.loras = [ LoRA(**lora) for lora in self.loras ]
+
+ if not self.models:
+ self.models = [ Model() ]
+
+ self.hyperparameters = Hyperparameters(**self.hyperparameters)
+
+ self.evaluation = Evaluation(**self.evaluation)
+
+ self.trainer = Trainer(**self.trainer)
+
+ if not isinstance(self.trainer.deepspeed, type):
+ self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
+
+ self.inference = Inference(**self.inference)
+
+ if self.bitsandbytes is not None:
+ self.optimizations = Optimizations(**self.bitsandbytes)
+ else:
+ self.optimizations = Optimizations(**self.optimizations)
+
+ if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
+ self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
+ self.hyperparameters.scheduler_type = ""
+
+ # do not combine the two
+ if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation:
+ self.hyperparameters.scheduler = ""
+
+ if self.hyperparameters.scheduler == "":
+ self.hyperparameters.torch_scheduler = True
+
+ if self.trainer.backend == "local" and self.distributed:
+ self.trainer.ddp = True
+
+ if self.trainer.activation_checkpointing is not None:
+ self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing
+
+ if not training:
+ self.dataset.use_hdf5 = False
+
+ # load our HDF5 file if requested here
+ if self.dataset.use_hdf5:
+ self.load_hdf5()
+
+ # load tokenizer
+ if cfg.tokenizer == "naive":
+ cfg.tokenizer = NaiveTokenizer()
+ else:
+ try:
+ from transformers import PreTrainedTokenizerFast
+
+ tokenizer_path = cfg.rel_path / cfg.tokenizer_path if cfg.yaml_path is not None else None
+ if tokenizer_path and not tokenizer_path.exists():
+ tokenizer_path = Path("./data/") / cfg.tokenizer_path
+
+ if tokenizer_path and tokenizer_path.exists():
+ cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
+ else:
+ cfg.tokenizer = NaiveTokenizer()
+ except Exception as e:
+ cfg.tokenizer = NaiveTokenizer()
+ _logger.warning(f"Error while parsing tokenizer: {str(e)}")
+ pass
+
+
+# Preserves the old behavior
+class NaiveTokenizer:
+ def get_vocab( self ):
+ """
+ if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
+ return json.loads( cfg.hdf5['symmap'].asstr()[()] )
+ """
+ return { " ": 0, "": 1, "": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 }
+
+ @cached_property
+ def _bos_token( self ):
+ return self.get_vocab()["<s>"]
+
+ @cached_property
+ def _eos_token( self ):
+ return self.get_vocab()["</s>"]
+
+ def encode( self, s ):
+ symmap = self.get_vocab()
+ s = s.replace("O", "0")
+ s = [f""] + [ p if p in symmap else " " for p in s ] + [f""]
+ return [*map(symmap.get, s)]
+
+_logger = logging.getLogger(__name__)
cfg = Config.from_cli()
-# OmegaConf doesn't actually coerce the dicts into the @dataclass decorated classes, for some god forsaken reason, so we coerce them ourselves
-cfg.dataset = Dataset(**cfg.dataset)
-cfg.models = Models(**cfg.models)
-cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
-cfg.evaluation = Evaluation(**cfg.evaluation)
-cfg.trainer = Trainer(**cfg.trainer)
-cfg.inference = Inference(**cfg.inference)
-cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
-
-cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
-
-# cached_property stopped working...
-if cfg.dataset.use_hdf5:
- try:
- cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
- except Exception as e:
- print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
- cfg.dataset.use_hdf5 = False
-
-if not cfg.dataset.use_hdf5:
- cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
- cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
+# some safety for remapping deprecated formats and re-coercing uninitialized properties into actual types
+try:
+ cfg.format()
+except Exception as e:
+ _logger.error(f"Error while parsing config YAML: {str(e)}")
+ raise e # throw an error because I'm tired of silent errors messing things up for me
if __name__ == "__main__":
- print(cfg)
\ No newline at end of file
+ print(cfg)
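A quick round-trip with the `NaiveTokenizer` defined above, for intuition: `encode()` maps `O` to `0`, substitutes a space for unknown symbols, and brackets the sequence with `<s>`/`</s>`:

```python
tokenizer = NaiveTokenizer()
# "W"=20, "4"=5, "K"=12, "D"=8, "8"=6, bracketed by <s>=1 and </s>=2
print(tokenizer.encode("W4KD8"))  # [1, 20, 5, 12, 8, 6, 2]
```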
diff --git a/image_classifier/data.py b/image_classifier/data.py
index 51fdf2c..8b5f011 100755
--- a/image_classifier/data.py
+++ b/image_classifier/data.py
@@ -1,16 +1,19 @@
# todo: clean this mess up
import copy
-# import h5py
+import h5py
import json
import logging
-#import numpy as np
+import numpy as np
import os
import random
import torch
import math
+import itertools
from .config import cfg
+from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
+from .utils.distributed import global_rank, local_rank, world_size
+from .utils.io import torch_save, torch_load
from collections import defaultdict
from functools import cache, cached_property
@@ -20,23 +23,57 @@ from typing import Any
from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.utils.rnn import pad_sequence
+
+from PIL import Image, ImageDraw
import torchvision.transforms as transforms
+
from tqdm.auto import tqdm
-
-from PIL import Image
-
# torch.multiprocessing.set_sharing_strategy("file_system")
_logger = logging.getLogger(__name__)
-@cache
+# to-do: clean up this symmap mess
def get_symmap():
- return { " ": 0, "": 1, "": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 }
+ return cfg.tokenizer.get_vocab()
-@cache
-def _get_symbols( content ):
- content = content.replace("O", "0")
- return [f""] + [ p for p in content ] + [f""]
+def tokenize( s ):
+ if isinstance( s, list ):
+ s = "".join( s )
+ return cfg.tokenizer.encode( s )
+
+"""
+def _replace_file_extension(path, suffix):
+ return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
+
+def _get_hdf5_path(path):
+ # to-do: better validation
+ return str(path)
+
+def _get_hdf5_paths( data_dir, type="training", validate=False ):
+ data_dir = str(data_dir)
+
+ key = f"/{type}/{_get_hdf5_path(data_dir)}"
+
+ return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items()] if key in cfg.hdf5 else []
+
+def _get_paths_of_extensions( path, validate=False ):
+ if isinstance(path, str):
+ path = Path(path)
+
+ return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
+
+def _interleaved_reorder(l, fn):
+ groups = defaultdict(list)
+ for e in l:
+ groups[fn(e)].append(e)
+ groups = {k: groups[k] for k in sorted(groups)}
+ for interleaved in zip_longest(*groups.values()):
+ for value in interleaved:
+ if value is not None:
+ yield value
+"""
class Dataset(_Dataset):
def __init__(
@@ -44,43 +81,90 @@ class Dataset(_Dataset):
paths,
width=300,
height=80,
+ stacks=0,
symmap=get_symmap(),
training=False,
):
super().__init__()
-
self._head = None
-
- self.paths = paths
+ self.sampler = None
self.width = width
self.height = height
-
+ self.stacks = stacks
+
+ self.paths = paths
+ self.image_dtype = cfg.trainer.dtype
self.symmap = symmap
+
self.training = training
+ self.dataset_type = "training" if self.training else "validation"
+ self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
self.transform = transforms.Compose([
- #transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
+ transforms.Resize((self.height, self.width)), # re-enabled; this previously broke the validation dataset for some reason, but all images *should* be normalized anyhow
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
- @cached_property
- def symbols(self):
- return sorted(set().union(*[_get_symbols(path.stem) for path in self.paths]))
+ # to-do: skip validation when there's nothing in the validation dataset
+ # this just makes it happy
+ if len(self.dataset) == 0:
+ self.dataset = cfg.dataset.training
+ # split dataset accordingly per GPU
+ if cfg.distributed and self.training:
+ self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == global_rank() ]
+
+ if len(self.paths) == 0:
+ raise ValueError(f"No valid path is found for {self.dataset_type}")
+
+ @cached_property
+ def sampler_state_dict_path(self):
+ return cfg.rel_path / f"sampler.rank{global_rank()}.pt"
+
+ def save_state_dict(self, path = None):
+ """
+ if path is None:
+ path = self.sampler_state_dict_path
+ if self.sampler is not None:
+ state_dict = self.sampler.get_state()
+ elif self.samplers is not None:
+ state_dict = {
+ "samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
+ }
+ torch_save(state_dict, path)
+ """
+ return
+
+ def load_state_dict(self, path = None):
+ """
+ if path is None:
+ path = self.sampler_state_dict_path
+
+ if not path.exists():
+ return
+
+ state_dict = torch_load(path)
+ if self.sampler is not None:
+ state_dict = self.sampler.set_state(state_dict)
+ else:
+ for name, sampler in state_dict["samplers"].items():
+ if name not in self.samplers:
+ continue
+ self.samplers[name].set_state( sampler )
+ """
+ return
def __getitem__(self, index):
path = self.paths[index]
+ tokens = tokenize( path.stem.upper() )
+ text = torch.tensor( tokens ).to(dtype=torch.uint8)
- # stupid try/except when the original VALL-E training framework was able to insert foreign symbols into the symmap, but that functionality isn't really necessary here
- try:
- text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
- except Exception as e:
- print("Invalid symbol:", _get_symbols(path.stem), [*map(self.symmap.get, _get_symbols(path.stem))], path.stem)
- raise e
+ image = Image.open(path).convert('RGB')
+ width, height = image.size
- image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype) # resnet has to be RGB
+ image = self.transform(image).to(dtype=self.image_dtype) # resnet has to be RGB
return dict(
index=index,
@@ -98,11 +182,6 @@ class Dataset(_Dataset):
def __len__(self):
return min(len(self.paths), self._head or len(self.paths))
- def pin_memory(self):
- self.text = self.text.pin_memory()
- self.image = self.image.pin_memory()
- return self
-
def collate_fn(samples: list[dict]):
batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
@@ -111,21 +190,28 @@ def collate_fn(samples: list[dict]):
def _seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
- #np.random.seed(worker_seed)
+ np.random.seed(worker_seed)
random.seed(worker_seed)
def _create_dataloader(dataset, training):
+ kwargs = dict(
+ shuffle=True,
+ batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
+ drop_last=training,
+ sampler=dataset.sampler if training else None,
+ ) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
+ batch_sampler=dataset.sampler,
+ )
+
return DataLoader(
dataset=dataset,
- batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
- shuffle=True, # training
- drop_last=training,
num_workers=cfg.dataset.workers,
collate_fn=collate_fn,
- persistent_workers=cfg.dataset.workers > 0,
- pin_memory=False, # True,
+ persistent_workers=cfg.dataset.workers > 1,
+ pin_memory=False,
worker_init_fn=_seed_worker,
+ **kwargs,
)
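The kwargs split above exists because `DataLoader` forbids combining `batch_sampler` with `batch_size`, `shuffle`, `sampler`, or `drop_last`, so a `BatchedOrderedSampler` has to be passed on its own. A quick demonstration of the failing case:

```python
from torch.utils.data import DataLoader

try:
    DataLoader(dataset=[0, 1, 2], batch_sampler=[[0, 1], [2]], batch_size=2)
except ValueError as e:
    print(e)  # batch_sampler option is mutually exclusive with batch_size, shuffle, sampler, and drop_last
```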
def _load_train_val_paths( val_ratio=0.1 ):
@@ -133,8 +219,8 @@ def _load_train_val_paths( val_ratio=0.1 ):
train_paths = []
val_paths = []
- print(cfg.dataset.training)
for data_dir in cfg.dataset.training:
+ paths.extend(data_dir.rglob("*.jpg"))
paths.extend(data_dir.rglob("*.png"))
if len(paths) > 0:
@@ -146,12 +232,13 @@ def _load_train_val_paths( val_ratio=0.1 ):
val_len = math.floor(len(train_paths) * val_ratio)
train_len = math.floor(len(train_paths) * (1 - val_ratio))
- print(val_len, train_len)
-
val_paths = train_paths[-val_len:]
train_paths = train_paths[:train_len]
else:
+ paths = []
+
for data_dir in cfg.dataset.validation:
+ paths.extend(data_dir.rglob("*.jpg"))
paths.extend(data_dir.rglob("*.png"))
if len(paths) > 0:
@@ -169,7 +256,6 @@ def _load_train_val_paths( val_ratio=0.1 ):
return train_paths, val_paths
-@cfg.diskcache()
def create_datasets():
train_paths, val_paths = _load_train_val_paths()
@@ -187,10 +273,10 @@ def create_datasets():
return train_dataset, val_dataset
-
def create_train_val_dataloader():
train_dataset, val_dataset = create_datasets()
+ # deepcopy is slow
subtrain_dataset = copy.deepcopy(train_dataset)
subtrain_dataset.head_(cfg.evaluation.size)
subtrain_dataset.training_(False)
@@ -200,8 +286,6 @@ def create_train_val_dataloader():
subtrain_dl = _create_dataloader(subtrain_dataset, training=False)
_logger.info(str(train_dataset.symmap))
-
-
_logger.info(f"#samples (train): {len(train_dataset)}.")
_logger.info(f"#samples (val): {len(val_dataset)}.")
_logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")
@@ -210,11 +294,305 @@ def create_train_val_dataloader():
return train_dl, subtrain_dl, val_dl
+# parse the dataset into metadata that is easier to sample from
"""
-if __name__ == "__main__":
- create_dataset_hdf5()
+def create_dataset_metadata( skip_existing=True ):
+ symmap = get_symmap()
+
+ root = str(cfg.data_dir)
+ metadata_root = str(cfg.metadata_dir)
- train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
- sample = train_dl.dataset[0]
- print(sample)
-"""
+ cfg.metadata_dir.mkdir(parents=True, exist_ok=True)
+
+ def add( dir, type="training", audios=True, texts=True ):
+ name = str(dir)
+ name = name.replace(root, "")
+
+ speaker_name = name
+
+ metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
+ metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
+
+ try:
+ metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
+ except Exception as e:
+ metadata = {}
+
+ if not os.path.isdir(f'{root}/{name}/'):
+ return
+
+ # tqdm.write(f'{root}/{name}')
+ files = os.listdir(f'{root}/{name}/')
+
+ # grab IDs for every file
+ ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
+
+ wrote = False
+
+ for id in tqdm(ids, desc=f"Processing {name}"):
+ try:
+ quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}')
+
+ if audios and not quant_path.exists():
+ continue
+
+ key = f'{type}/{speaker_name}/{id}'
+
+ if skip_existing and id in metadata:
+ continue
+
+ wrote = True
+
+ if id not in metadata:
+ metadata[id] = {}
+
+ utterance_metadata = {}
+ if audios:
+ # ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
+ dac = np.load(quant_path, allow_pickle=True)[()]
+ qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
+
+ if "text" in dac["metadata"]:
+ utterance_metadata["text"] = dac["metadata"]["text"]
+ if "phonemes" in dac["metadata"]:
+ utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
+ if "language" in dac["metadata"]:
+ utterance_metadata["language"] = dac["metadata"]["language"]
+ if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
+ utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
+
+ for k, v in utterance_metadata.items():
+ metadata[id][k] = v
+
+ except Exception as e:
+ tqdm.write(f'Error while processing {id}: {e}')
+
+ if wrote:
+ with open(str(metadata_path), "w", encoding="utf-8") as f:
+ f.write( json.dumps( metadata ) )
+
+ # training
+ for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
+ add( data_dir, type="training" )
+
+ # validation
+ for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
+ add( data_dir, type="validation" )
+
+ # noise
+ for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
+ add( data_dir, type="noise", texts=False )
+
+# parse yaml to create an hdf5 file
+def create_dataset_hdf5( skip_existing=True ):
+ cfg.dataset.use_hdf5 = True
+ cfg.load_hdf5(write=True)
+ hf = cfg.hdf5
+
+ symmap = get_symmap()
+
+ root = str(cfg.data_dir)
+ metadata_root = str(cfg.metadata_dir)
+
+
+ def add( dir, type="training", audios=True, texts=True ):
+ name = str(dir)
+ name = name.replace(root, "")
+
+ # yucky
+ speaker_name = name
+ if "LibriTTS-R" in speaker_name:
+ speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
+
+ metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
+ metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
+
+ metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
+
+ if not os.path.isdir(f'{root}/{name}/'):
+ return
+
+ files = os.listdir(f'{root}/{name}/')
+
+ # grab IDs for every file
+ ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
+
+ for id in tqdm(ids, desc=f"Processing {name}"):
+ try:
+ quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
+ text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
+
+ if not quant_exists:
+ continue
+
+ key = f'{type}/{speaker_name}/{id}'
+
+ if skip_existing and key in hf:
+ continue
+
+ group = hf.create_group(key) if key not in hf else hf[key]
+
+ if id not in metadata:
+ metadata[id] = {}
+
+ utterance_metadata = {}
+
+ # audio
+ if audios:
+ dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
+ qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
+
+ if "text" in dac["metadata"]:
+ utterance_metadata["text"] = dac["metadata"]["text"]
+ if "phonemes" in dac["metadata"]:
+ utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
+ if "language" in dac["metadata"]:
+ utterance_metadata["language"] = dac["metadata"]["language"]
+ if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
+ utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
+
+ if "audio" not in group:
+ group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')
+
+ # text
+ if texts:
+ if not utterance_metadata and text_exists:
+ utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
+
+ phn = "".join(utterance_metadata["phonemes"])
+ phn = cfg.tokenizer.encode(phn)
+ phn = np.array(phn).astype(np.uint8)
+
+ if "text" not in group:
+ group.create_dataset('text', data=phn, compression='lzf')
+
+ for k, v in utterance_metadata.items():
+ group.attrs[k] = v
+ metadata[id][k] = v
+
+ except Exception as e:
+ tqdm.write(f'Error while processing {id}: {e}')
+
+ with open(str(metadata_path), "w", encoding="utf-8") as f:
+ f.write( json.dumps( metadata ) )
+
+ # training
+ for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
+ add( data_dir, type="training" )
+
+ # validation
+ for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
+ add( data_dir, type="validation" )
+
+ # noise
+ for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
+ add( data_dir, type="noise", texts=False )
+
+ # write symmap
+ if "symmap" in hf:
+ del hf['symmap']
+
+ hf.create_dataset('symmap', data=json.dumps(symmap))
+ hf.close()
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser("Save trained model to path.")
+ parser.add_argument("--action", type=str)
+ parser.add_argument("--tasks", type=str)
+ args, unknown = parser.parse_known_args()
+
+ task = args.action
+
+ cfg.dataset.workers = 1
+
+ if args.action == "hdf5":
+ create_dataset_hdf5()
+ elif args.action == "list-dataset":
+ dataset = []
+ for group in os.listdir(cfg.data_dir):
+ for name in os.listdir(cfg.data_dir / group):
+ if len(os.listdir(cfg.data_dir / group / name)) == 0:
+ continue
+ dataset.append(f'{group}/{name}')
+
+ _logger.info(json.dumps(dataset))
+ elif args.action == "metadata":
+ create_dataset_metadata()
+ elif args.action == "sample":
+ train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
+
+ samples = {
+ "training": [ next(iter(train_dl)), next(iter(train_dl)) ],
+ "evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
+ #"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
+ }
+
+ Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
+
+ for k, v in samples.items():
+ for i in range(len(v)):
+ for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
+ try:
+ decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
+ except Exception as e:
+ _logger.info(f"Error while decoding prom {k}.{i}.{j}.wav: {str(e)}")
+ try:
+ decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
+ except Exception as e:
+ _logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}")
+ v[i]['proms'][j] = v[i]['proms'][j].shape
+ v[i]['resps'][j] = v[i]['resps'][j].shape
+
+ for k, v in samples.items():
+ for i in range(len(v)):
+ _logger.info(f'{k}[{i}]: {v[i]}')
+ elif args.action == "validate":
+ train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
+
+ missing = set()
+
+ for i in range(len( train_dl.dataset )):
+ batch = train_dl.dataset[i]
+
+ text = batch['text']
+ phonemes = batch['metadata']['phonemes']
+
+ decoded = [ cfg.tokenizer.decode(token) for token in text[1:-1] ]
+ for i, token in enumerate(decoded):
+ if token != "":
+ continue
+
+ phone = phonemes[i]
+
+ _logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )
+
+ missing |= set([phone])
+
+ _logger.info( f"Missing tokens: {missing}" )
+
+
+ elif args.action == "tasks":
+ index = 0
+ cfg.dataset.tasks_list = args.tasks.split(",")
+
+ train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
+ batch = next(iter(train_dl))
+
+ for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
+ if task not in cfg.dataset.tasks_list:
+ continue
+
+ _logger.info( f'{text} {task} {cfg.model.resp_levels}')
+ _logger.info( f'{proms.shape} {resps.shape}' )
+
+ tokens = 0
+ tokens += sum([ text.shape[0] for text in batch["text"] ])
+ tokens += sum([ resps.shape[0] for resps in batch["resps"] ])
+ _logger.info( f'{tokens}' )
+
+ decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
+ decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
+ break
+"""
\ No newline at end of file
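As a sanity check on the transform pipeline in `Dataset.__init__` above (a sketch assuming the default `width=300`, `height=80` from the `Model` config):

```python
from PIL import Image
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize((80, 300)),  # (height, width), as in the Dataset above
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet stats
])

image = Image.new("RGB", (123, 45))  # any input size; Resize normalizes it
print(transform(image).shape)  # torch.Size([3, 80, 300])
```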
diff --git a/image_classifier/engines/__init__.py b/image_classifier/engines/__init__.py
index f0879ec..99c10ec 100755
--- a/image_classifier/engines/__init__.py
+++ b/image_classifier/engines/__init__.py
@@ -1,6 +1,6 @@
from ..config import cfg
-from ..utils.distributed import fix_unset_envs
+from ..utils.distributed import fix_unset_envs, ddp_model
fix_unset_envs()
if cfg.trainer.backend == "deepspeed":
@@ -8,4 +8,211 @@ if cfg.trainer.backend == "deepspeed":
elif cfg.trainer.backend == "local":
from .base import Engine
-from .base import Engines, TrainFeeder, default_feeder
\ No newline at end of file
+from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
+
+from ..models import get_models, get_model
+from ..utils import wrapper as ml
+from ..utils.io import torch_save, torch_load, pick_path
+from ..models.lora import apply_lora, lora_load_state_dict
+
+import torch
+import re
+import logging
+
+_logger = logging.getLogger(__name__)
+
+deepspeed_available = False
+try:
+ import deepspeed
+ deepspeed_available = True
+except Exception as e:
+ pass
+
+from functools import cache
+
+@cache
+def load_engines(training=True, **model_kwargs):
+ models = get_models(cfg.models, training=training, **model_kwargs)
+ engines = dict()
+
+ for name, model in models.items():
+ state = None
+ stats = None
+ lora = None
+
+ inferencing = cfg.mode == "inferencing" or not model.config.training or not training
+ backend = cfg.inference.backend if inferencing else cfg.trainer.backend
+ loads_state_dict = cfg.trainer.load_state_dict # or inferencing
+
+ checkpoint_path = cfg.ckpt_dir / name / "latest"
+ # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
+ load_path = pick_path( cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
+
+ # actually use the lora-specific checkpoint if available
+ if cfg.lora is not None:
+ checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest"
+
+ # to handle the issue of training with deepspeed, but inferencing with local
+ if checkpoint_path.exists() and backend == "local":
+ tag = open(checkpoint_path).read()
+ checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
+
+ if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
+ _logger.warning(f"Checkpoint missing, but weights found: {load_path}")
+ loads_state_dict = True
+
+ # load state early
+ if loads_state_dict:
+ state = torch_load(load_path, device=cfg.device)
+
+		# check if config is defined in state, and re-initialize the model (deliberately disabled with `and False` for now)
+		if "config" in state and False:
+ _logger.warning("Model config definition in weights, re-loading...")
+ config_state = state["config"]
+ model = get_model( config=cfg.model.__class__( *config_state ), training=training )
+
+ hyper_config = model.config
+
+ optimizer = None
+ lr_scheduler = None
+
+ dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
+ amp = cfg.inference.amp if inferencing else cfg.trainer.amp
+ ddp = cfg.trainer.ddp
+
+ engine_class = LocalEngine if backend == "local" else Engine
+
+ # apply model replacers
+ if cfg.optimizations.replace and cfg.optimizations.linear:
+ model.model = ml.replace_linear( model.model )
+
+ if cfg.optimizations.replace and cfg.optimizations.embedding:
+ model.model = ml.replace_embedding( model.model )
+
+ for lora in cfg.loras:
+ model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
+
+ if inferencing:
+ model.config.training = False
+
+ if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
+ optimizer_class = None
+ scheduler_class = None
+
+ params = {
+ "lr": cfg.hyperparameters.learning_rate,
+ }
+ if cfg.hyperparameters.optimizer.lower() == "adamw":
+ params["betas"] = (0.9, 0.96)
+ params["eps"] = 1e-07
+ params["weight_decay"] = 0.01
+
+ # for dadaptation since it has Adam only
+ if ml.AdamW == ml.Adam:
+ params["decouple"] = True
+
+ optimizer_class = ml.AdamW
+ elif cfg.hyperparameters.optimizer.lower() == "sgd":
+			optimizer_class = ml.SGD
+ elif cfg.hyperparameters.optimizer.lower() == "prodigy":
+ optimizer_class = ml.Prodigy
+
+ params['d_coef'] = params['lr']
+ params['lr'] = 1.0
+ elif cfg.hyperparameters.optimizer.lower() == "adagrad":
+ optimizer_class = ml.Adagrad
+ else:
+ raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
+
+ params.update(cfg.hyperparameters.optimizer_params)
+
+ optimizer = optimizer_class(
+ [ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
+ **params,
+ )
+
+ if cfg.hyperparameters.scheduler.lower() == "schedulefree":
+ if cfg.hyperparameters.optimizer.lower() == "adamw":
+ scheduler_class = ml.schedulefree.AdamWScheduleFree
+ elif cfg.hyperparameters.optimizer.lower() == "sgd":
+ scheduler_class = ml.schedulefree.SGDScheduleFree
+ else:
+ raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
+
+ optimizer = scheduler_class(
+ [ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
+ lr = params['lr'],
+ warmup_steps = cfg.hyperparameters.warmup_steps
+ )
+
+ """
+ # set up our LR scheduler here
+ """
+
+ if inferencing:
+ optimizer = None
+ lr_scheduler = None
+
+ # load state dict if requested / required
+ if loads_state_dict:
+ # state dict is not just the module, extract the extra trainer details
+ if "stats" in state:
+ stats = state["stats"]
+
+ # do not load stats if we're training a LoRA
+ if cfg.lora is not None or cfg.trainer.restart_step_count:
+ stats = None
+
+ if "module" in state:
+ state = state["module"]
+
+ model.load_state_dict(state, strict=cfg.trainer.strict_loading)
+
+ # load lora weights if exists
+ if cfg.lora is not None:
+ lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
+ if lora_path.exists():
+ _logger.info( f"Loaded LoRA state dict: {lora_path}" )
+
+ state = torch_load(lora_path, device=cfg.device)
+ state = state['lora' if 'lora' in state else 'module']
+ lora_load_state_dict( model, state )
+
+ # wrap if DDP is requested
+ if ddp:
+ model = ddp_model(model)
+ # wrap optimization class
+ elif cfg.optimizations.compile:
+ model = ml.compile_model(model, backend=cfg.optimizations.compile)
+ # deepspeed inferencing
+ elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
+ engine_class = LocalEngine
+ model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
+
+ # use base engine if requested
+ engines[name] = engine_class(
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+
+ hyper_config=hyper_config,
+ stats=stats
+ )
+
+
+ engines = Engines(engines)
+ engines.setup()
+
+	# this might bite me in the ass later, since this doesn't handle the case where one engine loads fine but another doesn't
+ if not cfg.trainer.load_state_dict:
+ engines.load_checkpoint(training=not inferencing)
+
+	# freeze requested params; split each model over requested devices
+	for name, engine in engines.items():
+		engine.freeze(freeze_all=False)
+
+		if cfg.optimizations.model_offloading:
+			engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )
+
+ return engines
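Note that `load_engines` is wrapped in `functools.cache`, so repeated calls with the same arguments return the same `Engines` instance; callers that mutate `cfg` first need to invalidate it, which is what `inference.py` does via `load_engines.cache_clear()`. A minimal usage sketch, assuming a config YAML has already been loaded into `cfg`:

```python
# hedged sketch: fetch engines for inference and grab the first underlying module
engines = load_engines(training=False)  # cached; optimizer state is skipped
for name, engine in engines.items():
    model = engine.module  # the plain nn.Module wrapped by the engine
    break
```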
diff --git a/image_classifier/engines/base.py b/image_classifier/engines/base.py
index 24199ed..a077280 100755
--- a/image_classifier/engines/base.py
+++ b/image_classifier/engines/base.py
@@ -28,7 +28,9 @@ def default_feeder(engine, batch):
from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
-from ..utils.distributed import init_distributed, distributed_initialized
+from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size, cleanup_distributed
+from ..utils.io import torch_save, torch_load
+from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
import logging
import time
@@ -39,40 +41,65 @@ import os
from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol
+from functools import cached_property
from .base import TrainFeeder
+from ..utils import wrapper as ml
_logger = logging.getLogger(__name__)
-if not distributed_initialized() and cfg.trainer.backend == "local":
- def _nop():
- ...
- fn = _nop if cfg.device == "cpu" else torch.distributed.init_process_group
- init_distributed(fn)
+if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1:
+ init_distributed(torch.distributed.init_process_group)
# A very naive engine implementation using barebones PyTorch
-# to-do: implement lr_sheduling
class Engine():
def __init__(self, *args, **kwargs):
- self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
+		self.hyper_config = kwargs.pop("hyper_config", None)
+
+ self.module = kwargs['model'].to(cfg.device).to(torch.float32 if cfg.trainer.amp else cfg.trainer.dtype)
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
- self.global_steps = 0
- self.micro_steps = 0
- self.gradient_accumulation_steps = cfg.hyperparameters.gradient_accumulation_steps
+ self.global_steps = kwargs.pop("global_steps", 0)
+ self.micro_steps = kwargs.pop("micro_steps", 0)
+ self.global_samples = kwargs.pop("global_samples", 0)
+ self.tokens_processed = kwargs.pop("tokens_processed", 0)
- def freeze(self):
- for p in self.module.parameters():
- if p.requires_grad:
- p.requires_grad_(False)
- self._frozen_params.add(p)
+ self._frozen_params = set()
+
+ self.max_nan_losses = 8
+ self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
+
+ self.current_batch_size = 0
+ self._global_grad_norm = None
+
+ def freeze(self, freeze_all=True):
+ # set to freeze
+ if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
+ raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
+
+ # freeze non-LoRA params if requested
+ if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
+ return freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
+
+ for name, param in self.module.named_parameters():
+ if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
+ param.requires_grad_(False)
+ self._frozen_params.add(param)
def unfreeze(self):
for p in self._frozen_params:
p.requires_grad_(True)
self._frozen_params.clear()
+	@property
+	def _training(self):
+		if getattr(self, "hyper_config", None) is None:
+			return True
+		return self.hyper_config.training
+
@property
def global_step(self):
return self.global_steps
@@ -81,8 +108,17 @@ class Engine():
def micro_step(self):
return self.micro_steps
- def train_batch_size(self):
- return cfg.hyperparameters.batch_size
+ @property
+ def batch_size(self):
+ return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
+
+ @property
+ def gradient_accumulation_steps(self):
+ return cfg.hyperparameters.gradient_accumulation_steps
+
+ @property
+ def gradient_clipping(self):
+ return cfg.hyperparameters.gradient_clipping
def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs)
@@ -91,42 +127,74 @@ class Engine():
return dispatch_attribute(self.module, *args, **kwargs)
def save_checkpoint(self, save_dir, tag ):
- save_path = save_dir / tag / "state.pth"
- save_path.parent.mkdir(parents=True, exist_ok=True)
- torch.save({
- "global_step": self.global_step,
- "micro_step": self.micro_step,
- "module": self.module.state_dict(),
- "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
- "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
- }, save_path)
+ if is_global_leader():
+ module = self.module.state_dict()
- open(save_dir / "latest", 'w').write( tag )
+ # if training lora
+ # this is a separate path to override saving the weights
+ lora = None
+ if cfg.lora is not None:
+ lora, module = lora_get_state_dict( module, split = True )
+ save_dir = cfg.ckpt_dir / cfg.lora.full_name
+
+ save_path = save_dir / tag / f"state.{cfg.weights_format}"
+ save_path.parent.mkdir(parents=True, exist_ok=True)
+
+ torch_save({
+ "module": module,
+ "lora": lora,
+ "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
+ "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
+
+ "stats": {
+ "global_step": self.global_step,
+ "micro_step": self.micro_step,
+ "global_samples": self.global_samples,
+ "tokens_processed": self.tokens_processed,
+ }
+ }, save_path)
+
+ open(save_dir / "latest", 'w').write( tag )
+
+ torch.distributed.barrier()
+
+ def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
+ # override to load the lora instead
+ if cfg.lora is not None:
+ load_dir = cfg.ckpt_dir / cfg.lora.full_name
- def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
if tag is None:
tag_path = load_dir / "latest"
+
if not tag_path.exists():
return
+
tag = open(tag_path).read()
- load_path = load_dir / tag / "state.pth"
+ load_path = load_dir / tag / f"state.{cfg.weights_format}"
+
if not load_path.exists():
return
- state = torch.load(load_path)
- self.global_steps = state['global_step']
- self.micro_steps = state['micro_step']
- self.module.load_state_dict(state['module'])
+ state = torch_load(load_path, device=cfg.device)
+
+ self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
+ self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
+ self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
+ self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
+ self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading)
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
if load_optimizer_states:
- self.optimizer.load_state_dict(state['optimizer'])
+ self.optimizer.load_state_dict(state['optimizer']) #, device=cfg.device)
if load_lr_scheduler_states:
- self.lr_scheduler.load_state_dict(state['lr_scheduler'])
+ self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, device=cfg.device)
+
+ if 'lora' in state:
+ lora_load_state_dict( self.module, state['lora'] )
def eval(self):
return self.module.eval()
@@ -136,46 +204,80 @@ class Engine():
def to(self, *args, **kwargs):
self.module = self.module.to(*args, **kwargs)
- return self.module
+		# torch optimizers don't implement .to(); guard so attaching one doesn't crash this
+		if self.optimizer and hasattr(self.optimizer, 'to'):
+			self.optimizer = self.optimizer.to(*args, **kwargs)
+
+ return self
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
+ @cached_property
+ def device(self):
+ return next(self.module.parameters()).device
+
def forward(self, *args, **kwargs):
return self.module.forward(*args, **kwargs)
def backward(self, loss):
+ if self.loss_scaler is not None:
+ return self.loss_scaler.scale(loss / self.gradient_accumulation_steps).backward()
return (loss / self.gradient_accumulation_steps).backward()
+
def step(self):
with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
self.micro_steps += 1
+ self.global_samples += self.batch_size
if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
+				# unscale first so clipping and the recorded norm see true gradients
+				if self.loss_scaler is not None:
+					self.loss_scaler.unscale_(self.optimizer)
+				torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)
+				# capture the norm before zero_grad() discards the gradients
+				self._get_grad_norm()
+
 				self.global_steps += 1
-				self.optimizer.step()
+				if self.loss_scaler is not None:
+					self.loss_scaler.step(self.optimizer)
+					self.loss_scaler.update()
+				else:
+					self.optimizer.step()
 				self.optimizer.zero_grad()
+
+ def _get_grad_norm(self):
+ t = [ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]
+ self._global_grad_norm = torch.cat(t).norm().item() if len(t) else None
+
def get_lr(self):
lrs = []
for param_group in self.optimizer.param_groups:
- if 'lr' in param_group:
+			if 'd_coef' in param_group:
+				lrs.append(param_group['d_coef'])
+ elif 'lr' in param_group:
lrs.append(param_group['lr'])
return lrs
def set_lr(self, lr):
for param_group in self.optimizer.param_groups:
- if 'lr' in param_group:
+			if 'd_coef' in param_group:
+				param_group['d_coef'] = lr
+ elif 'lr' in param_group:
param_group['lr'] = lr
def get_global_grad_norm(self):
- return 0.0
+ return self._global_grad_norm
def traverse(self, *args, **kwargs):
- self.forward(*args, **kwargs)
+ with ml.autocast():
+ self.forward(*args, **kwargs)
+
losses = self.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()
+ if torch.isnan(loss).any():
+ self.max_nan_losses = self.max_nan_losses - 1
+ if self.max_nan_losses < 0:
+ raise RuntimeError("Too many NaN losses detected.")
+
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= self.gather_attribute("scalar")
@@ -194,6 +296,8 @@ class Engines(dict[str, Engine]):
def setup(self):
self._global_step = 0
self._micro_step = 0
+ self._batch_size = 0
+ self._global_samples = 0
@property
def global_step(self):
@@ -203,6 +307,14 @@ class Engines(dict[str, Engine]):
def micro_step(self):
return self._micro_step
+ @property
+ def batch_size(self):
+ return self._batch_size
+
+ @property
+ def global_samples(self):
+ return self._global_samples
+
def gather_attribute(self, *args, **kwargs):
ret = {}
for engine in self.values():
@@ -213,6 +325,50 @@ class Engines(dict[str, Engine]):
for engine in self.values():
engine.dispatch_attribute(*args, **kwargs)
+ def export(self, userdata={}, callback=None, dtype=None, format=None):
+ if not format:
+ format = cfg.weights_format
+ format = format.lower()
+
+ if dtype is None:
+ dtype = cfg.trainer.dtype
+
+ for name, engine in self.items():
+ module = engine.module.state_dict()
+ lora = None
+ save_path = cfg.ckpt_dir / name / f"fp32.{format}"
+ config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
+
+ # safety
+ for k, v in module.items():
+ module[k] = v.to(dtype)
+
+ if cfg.lora is not None:
+ lora, module = lora_get_state_dict( module, split = True )
+ save_path = cfg.ckpt_dir / cfg.lora.full_name / f"fp32.{format}"
+
+ state_dict = {
+ 'module': module,
+ 'lora': lora,
+ "stats": {
+ "global_step": engine.global_step,
+ "micro_step": engine.micro_step,
+ "global_samples": engine.global_samples,
+ "tokens_processed": engine.tokens_processed,
+ },
+ "userdata": userdata,
+ "config": config.__dict__
+ }
+
+ if lora is None:
+ del state_dict['lora']
+
+ if callback:
+ state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
+
+ torch_save(state_dict, save_path)
+ _logger.info(f"Exported {name} to {save_path}")
+
def save_checkpoint(self, tag=None):
if not tag:
tag = cfg.trainer.save_tag
@@ -222,47 +378,67 @@ class Engines(dict[str, Engine]):
cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
for name, engine in self.items():
- engine.save_checkpoint(cfg.ckpt_dir / name, tag=tag)
+ if not engine._training:
+ continue
- def load_checkpoint(self, tag=None):
+ save_dir = cfg.ckpt_dir / name
+ try:
+ engine.save_checkpoint(save_dir, tag=tag)
+ except Exception as e:
+ _logger.warning(f'Failed to save checkpoint for engine {name}: {str(e)}')
+
+			# might be better to prune before saving for safety, but [:0] returns an empty list; an alternative is [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
+ if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
+ checkpoints = [ d for d in list(save_dir.glob("*")) if d.is_dir() ]
+ checkpoints.sort(key=lambda x: x.stat().st_mtime)
+ checkpoints = checkpoints[:-cfg.trainer.keep_last_checkpoints]
+ for d in checkpoints:
+ if not d.is_dir() or not d.exists():
+ continue
+ _logger.info(f"Removing {d}")
+ for p in d.iterdir():
+ p.unlink()
+ d.rmdir()
+
+ def load_checkpoint(self, tag=None, training=True):
if not tag:
tag = cfg.trainer.load_tag
for name, engine in self.items():
load_dir = cfg.ckpt_dir / name
+
engine.load_checkpoint(
tag=tag,
load_dir=load_dir,
load_module_strict=cfg.trainer.strict_loading,
- load_optimizer_states=cfg.trainer.load_states,
- load_lr_scheduler_states=cfg.trainer.load_states,
+ load_optimizer_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
+ load_lr_scheduler_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
+ load_module_only=cfg.trainer.load_module_only,
)
if cfg.trainer.restart_step_count:
engine.global_steps = 0
+				engine.micro_steps = 0
+ engine.global_samples = 0
+ engine.tokens_processed = 0
# update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
if cfg.hyperparameters.scheduler_type == "":
self.set_lr(cfg.hyperparameters.learning_rate)
- self._update_global_step()
- self._update_micro_step()
+ self._update()
def set_lr(self, lr):
for engine in self.values():
+ if not engine._training:
+ continue
engine.set_lr(lr)
- def _update_global_step(self):
+ def _update(self):
for engine in self.values():
self._global_step = max(self._global_step, engine.global_step)
-
- def _update_micro_step(self):
- for engine in self.values():
self._micro_step = max(self._micro_step, engine.micro_step)
-
- def train_batch_size(self):
- batch_size = 0
- for engine in self.values():
- batch_size = max(batch_size, engine.train_batch_size())
+ self._batch_size = max(self._batch_size, engine.batch_size)
+ self._global_samples = max(self._global_samples, engine.global_samples)
def eval(self):
for engine in self.values():
@@ -279,7 +455,10 @@ class Engines(dict[str, Engine]):
stats.update(flatten_dict({ name.split("-")[0]: stat }))
return stats
- def step(self, batch, feeder: TrainFeeder = default_feeder, device=cfg.get_device()):
+ def quit(self):
+ cleanup_distributed()
+
+ def step(self, batch, feeder: TrainFeeder = default_feeder):
total_elapsed_time = 0
stats: Any = dict()
@@ -287,10 +466,11 @@ class Engines(dict[str, Engine]):
if cfg.trainer.gc_mode == 'step':
do_gc()
- batch = to_device(batch, device)
-
for name, engine in self.items():
- #torch.cuda.synchronize()
+ if not engine._training:
+ continue
+
+ device = engine.device
if cfg.trainer.gc_mode == 'substep':
do_gc()
@@ -298,10 +478,9 @@ class Engines(dict[str, Engine]):
start_time = time.time()
tries = 4
- n_ooms = torch.zeros([], device=cfg.device)
+ n_ooms = torch.zeros([], device=device)
- if cfg.trainer.aggressive_optimizations:
- batch = to_device(batch, device)
+ batch = to_device(batch, device)
if not cfg.trainer.check_for_oom:
res = feeder( engine=engine, batch=batch )
@@ -311,7 +490,7 @@ class Engines(dict[str, Engine]):
res = feeder( engine=engine, batch=batch )
break
except RuntimeError as e:
- print("Forward", str(e))
+ _logger.error(f"Forward: {str(e)}")
if "out of memory" not in str(e):
self.save_checkpoint()
@@ -329,7 +508,8 @@ class Engines(dict[str, Engine]):
do_gc()
continue
- all_reduce(n_ooms)
+ if world_size() > 1:
+ all_reduce(n_ooms)
if n_ooms.item() > 0:
self.save_checkpoint()
raise RuntimeError("Out of memory during forward pass!")
@@ -340,7 +520,7 @@ class Engines(dict[str, Engine]):
loss, engine_stats = res
engine_stats |= self.gather_attribute("scalar")
- n_ooms = torch.zeros([], device=cfg.device)
+ n_ooms = torch.zeros([], device=device)
if cfg.trainer.aggressive_optimizations:
batch = to_device(batch, 'cpu')
@@ -348,10 +528,11 @@ class Engines(dict[str, Engine]):
if not cfg.trainer.check_for_oom:
engine.backward(loss)
else:
+ # to-do: properly handle when one GPU throws an OOM because it just halts
try:
engine.backward(loss)
except RuntimeError as e:
- print("Backwards:", str(e))
+ _logger.error(f"Backwards: {str(e)}")
if "out of memory" not in str(e):
self.save_checkpoint()
@@ -359,10 +540,13 @@ class Engines(dict[str, Engine]):
n_ooms += 1
- all_reduce(n_ooms)
+ if world_size() > 1:
+ all_reduce(n_ooms)
+
if n_ooms.item() > 0:
self.save_checkpoint()
- raise RuntimeError("Out of memory during backwards pass!")
+
+ raise RuntimeError("Out of memory during backwards pass!")
engine.step()
@@ -370,27 +554,36 @@ class Engines(dict[str, Engine]):
elapsed_time = time.time() - start_time
total_elapsed_time += elapsed_time
+ grad_norm = engine.get_global_grad_norm()
+ loss_scale = 1
+ if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
+ loss_scale = engine.optimizer.loss_scale
+
+ if grad_norm is not None:
+ grad_norm /= loss_scale
stats.update(
flatten_dict(
{
name.split("-")[0]: dict(
- loss=loss.item(),
+ **engine_stats,
lr=engine.get_lr()[0],
- grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
+ grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
+ loss_scale=loss_scale if loss_scale != 1 else None,
elapsed_time=elapsed_time,
engine_step=engine.global_step,
- **engine_stats,
+ samples_processed=engine.global_samples,
+ tokens_processed=engine.tokens_processed,
)
}
),
)
- self._update_global_step()
- self._update_micro_step()
- stats["batch_size"] = self.train_batch_size() # len(batch["text"])
- stats["elapsed_time"] = total_elapsed_time
- stats["wall_time"] = time.time()
- stats["global_step"] = self.global_step
+ self._update()
+
+ if len(self.keys()) > 1:
+ stats["elapsed_time"] = total_elapsed_time
+
+ stats["it"] = self.global_step
return stats
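The local `Engine.backward()`/`step()` pair boils down to standard gradient accumulation with optional loss scaling; a self-contained sketch of the same pattern on a toy model (loss scaling omitted, sizes arbitrary):

```python
import torch
from torch import nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
accumulation_steps = 4

for micro_step in range(1, 9):
    x, y = torch.randn(2, 8), torch.randn(2, 1)
    loss = nn.functional.mse_loss(model(x), y)
    (loss / accumulation_steps).backward()  # scale so gradients average over the window

    if micro_step % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # same clipping Engine.step() applies
        optimizer.step()
        optimizer.zero_grad()
```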
diff --git a/image_classifier/engines/deepspeed.py b/image_classifier/engines/deepspeed.py
index 8458807..4abe44f 100755
--- a/image_classifier/engines/deepspeed.py
+++ b/image_classifier/engines/deepspeed.py
@@ -25,28 +25,71 @@ from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distr
from deepspeed.accelerator import get_accelerator
from ..utils.distributed import init_distributed, distributed_initialized
+from ..utils import wrapper as ml
+
+from ..models.lora import freeze_non_lora_weights
if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
init_distributed(init_deepspeed_dist)
class Engine(DeepSpeedEngine):
def __init__(self, *args, **kwargs):
- kwargs['config'] = cfg.trainer.deepspeed.get_ds_cfg(model=kwargs['model'])
+ self.hyper_config = None
+ if 'hyper_config' in kwargs:
+ self.hyper_config = kwargs['hyper_config']
+ kwargs.pop("hyper_config")
+
+ kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
+ stats = {
+ "global_step": 0,
+ "micro_step": 0,
+ "global_samples": 0,
+ "tokens_processed": 0,
+ }
+
+ # kwargs['stats'] = None will return None when popped
+ maybe_stats = kwargs.pop('stats', stats)
+ if maybe_stats is not None:
+ stats = maybe_stats
+
super().__init__(None, *args, **kwargs)
self._frozen_params = set()
- def freeze(self):
- for p in self.module.parameters():
- if p.requires_grad:
- p.requires_grad_(False)
- self._frozen_params.add(p)
+ self.global_steps = stats["global_step"]
+ self.micro_steps = stats["micro_step"]
+ self.global_samples = stats["global_samples"]
+ self.tokens_processed = stats["tokens_processed"]
+
+ self.max_nan_losses = 8
+ self.current_batch_size = 0
+
+ def freeze(self, freeze_all=True):
+ # freeze non-LoRA params if requested
+ if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
+ frozen_params = freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
+ for param in frozen_params:
+ self._frozen_params.add( param )
+
+ return
+
+ if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
+ raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
+
+ for name, param in self.module.named_parameters():
+ if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
+ param.requires_grad_(False)
+ self._frozen_params.add(param)
def unfreeze(self):
- for p in self._frozen_params:
- p.requires_grad_(True)
+ for param in self._frozen_params:
+ param.requires_grad_(True)
self._frozen_params.clear()
+
+ @property
+ def _training(self):
+ return self.hyper_config.training
@property
def global_step(self):
@@ -54,7 +97,11 @@ class Engine(DeepSpeedEngine):
@property
def micro_step(self):
- return self.micro_steps
+ return self.micro_steps
+
+ @property
+ def batch_size(self):
+ return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs)
@@ -66,17 +113,40 @@ class Engine(DeepSpeedEngine):
try:
if hasattr(self.optimizer, 'param_groups'):
for param_group in self.optimizer.param_groups:
- param_group['lr'] = lr
+				param_group["d_coef" if "d_coef" in param_group else "lr"] = lr
else:
self.optimizer.set_lr(lr)
except Exception as e:
- print(str(e))
+ _logger.warning(str(e))
+
+ # we'll just have to live with the LoRA weights living within our main weights
+ # they're easy to extract anyways
+ def load_checkpoint(self, load_dir, **kwargs ):
+ # override to load the lora instead
+ if cfg.lora is not None:
+ load_dir = cfg.ckpt_dir / cfg.lora.full_name
+
+ return super().load_checkpoint( load_dir, **kwargs )
+
+ def save_checkpoint(self, save_dir, **kwargs ):
+ # override to save the lora instead
+ if cfg.lora is not None:
+ save_dir = cfg.ckpt_dir / cfg.lora.full_name
+
+ return super().save_checkpoint( save_dir, **kwargs )
def traverse(self, *args, **kwargs):
- self.forward(*args, **kwargs)
+ with ml.autocast():
+ self.forward(*args, **kwargs)
+
losses = self.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()
+ if torch.isnan(loss).any():
+ self.max_nan_losses = self.max_nan_losses - 1
+ if self.max_nan_losses < 0:
+ raise RuntimeError("Too many NaN losses detected.")
+
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= self.gather_attribute("scalar")
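The `maybe_stats` indirection in the constructor above exists because `dict.pop` only falls back to its default when the key is absent, not when the stored value is `None`:

```python
kwargs = {"stats": None}
stats = kwargs.pop("stats", {"global_step": 0})
print(stats)  # None -- the default is ignored because the key was present
```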
diff --git a/image_classifier/export.py b/image_classifier/export.py
index 8ec2dfe..176c2f4 100755
--- a/image_classifier/export.py
+++ b/image_classifier/export.py
@@ -1,31 +1,67 @@
import argparse
import torch
+import torch.nn
from .data import get_symmap
-from .train import load_engines
+from .engines import load_engines
+from .config import cfg
+from .models.lora import lora_get_state_dict
+from .utils.io import torch_save, torch_load
-def load_models():
- models = {}
- engines = load_engines()
- for name in engines:
- model = engines[name].module.cpu()
- models[name] = model
+# yanks a LoRA from the training checkpoint
+def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
+ if dtype is None:
+ dtype = cfg.inference.dtype
- return models
+	format = save_path.suffix[1:] # e.g. "fp32.sft" -> "sft"
+
+ lora = state_dict["lora"] if "lora" in state_dict else None
+ # should always be included, but just in case
+ if lora is None and "module" in state_dict:
+ lora, module = lora_get_state_dict( state_dict["module"], split = True )
+ state_dict["module"] = module
+
+ if "lora" in state_dict:
+ state_dict["lora"] = None
+
+ # should raise an exception since there's nothing to extract, or at least a warning
+ if not lora:
+ return state_dict
+
+ # save lora specifically
+ # should probably export other attributes, similar to what SD LoRAs do
+ save_path = save_path.parent / f"lora.{format}"
+ torch_save( {
+ "module": lora,
+ "config": cfg.lora.__dict__ if cfg.lora is not None else None,
+ }, save_path )
+
+ return state_dict
def main():
parser = argparse.ArgumentParser("Save trained model to path.")
- parser.add_argument("path")
- args = parser.parse_args()
+ parser.add_argument("--module-only", action='store_true')
+ parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
+ parser.add_argument("--format", type=str, default="pth") # set target format to export weights under
+ args, unknown = parser.parse_known_args()
- models = load_models()
- for name in models:
- model = models[name]
+ if args.format.lower() not in ["sft", "safetensors", "pt", "pth"]:
+ raise Exception(f"Unknown requested format: {args.format}")
- outpath = f'{args.path}/{name}.pt'
- torch.save(model, outpath)
- print(f"Exported {name} to {outpath}")
+ if args.module_only:
+ cfg.trainer.load_module_only = True
+
+ if args.dtype != "auto":
+ cfg.trainer.weight_dtype = args.dtype
+
+ # necessary to ensure we are actually exporting the weights right
+ cfg.inference.backend = cfg.trainer.backend
+
+ engines = load_engines(training=False) # to ignore loading optimizer state
+
+ callback = None
+ engines.export(userdata={"symmap": get_symmap()}, callback=callback, format=args.format)
if __name__ == "__main__":
main()
\ No newline at end of file
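The export can also be driven programmatically; a hedged sketch mirroring `main()` above, assuming a config YAML has already been loaded into `cfg`:

```python
# hedged sketch: export weights without going through the CLI
from image_classifier.config import cfg
from image_classifier.engines import load_engines

cfg.inference.backend = cfg.trainer.backend  # mirror main(): export using the training backend
engines = load_engines(training=False)       # training=False skips optimizer state
engines.export(userdata={}, format="pth")
```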
diff --git a/image_classifier/inference.py b/image_classifier/inference.py
index 0e70703..dba3f93 100755
--- a/image_classifier/inference.py
+++ b/image_classifier/inference.py
@@ -1,53 +1,103 @@
import torch
+import time
+import logging
+_logger = logging.getLogger(__name__)
+
+from torch import Tensor
+from einops import rearrange
+from pathlib import Path
+
+from .utils import to_device, set_seed, wrapper as ml
+from PIL import Image, ImageDraw
import torchvision.transforms as transforms
-from .config import cfg
-from .export import load_models
-from .data import get_symmap, _get_symbols
+from .config import cfg, Config
+from .models import get_models
+from .engines import load_engines, deepspeed_available
+from .data import get_symmap, tokenize
+
+if deepspeed_available:
+ import deepspeed
class Classifier():
- def __init__( self, width=300, height=80, config=None, ckpt=None, device=cfg.get_device(), dtype="float32" ):
+ def __init__( self, config=None, device=None, amp=None, dtype=None, attention=None ):
+ self.loading = True
+
+ # yes I can just grab **kwargs and forward them here
+ self.load_config( config=config, device=device, amp=amp, dtype=dtype, attention=attention )
+ self.load_model()
+
+ self.loading = False
+
+ def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
if config:
+ _logger.info(f"Loading YAML: {config}")
cfg.load_yaml( config )
- self.loading = True
+ try:
+ cfg.format( training=False )
+ cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
+ except Exception as e:
+ raise e # throw an error because I'm tired of silent errors messing things up for me
+
+ if amp is None:
+ amp = cfg.inference.amp
+ if dtype is None or dtype == "auto":
+ dtype = cfg.inference.weight_dtype
+ if device is None:
+ device = cfg.device
+
+ cfg.device = device
+ cfg.mode = "inferencing"
+ cfg.trainer.backend = cfg.inference.backend
+ cfg.trainer.weight_dtype = dtype
+ cfg.inference.weight_dtype = dtype
+
self.device = device
-
- if ckpt:
- self.load_model_from_ckpt( ckpt )
- else:
- self.load_model_from_cfg( config )
+ self.dtype = cfg.inference.dtype
+ self.amp = amp
- self.model.eval()
+ self.model_kwargs = {}
- self.width = width
- self.height = height
+ def load_model( self ):
+ load_engines.cache_clear()
+
+ self.engines = load_engines(training=False, **self.model_kwargs)
+ for name, engine in self.engines.items():
+ if self.dtype != torch.int8:
+ engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+
+ self.engines.eval()
+ self.symmap = get_symmap()
+
+ self.width = 300
+ self.height = 80
self.transform = transforms.Compose([
transforms.Resize((self.height, self.width)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
+
+ _logger.info("Loaded model")
- self.loading = False
+ @torch.inference_mode()
+ def inference( self, image, temperature=1.0 ):
+ model = None
- def load_model_from_ckpt( self, ckpt ):
- self.ckpt = ckpt
-
- self.model = torch.load(self.ckpt).to(self.device)
-
- def load_model_from_cfg( self, config_path ):
-
- models = load_models()
- for name in models:
- model = models[name]
- self.model = model.to(self.device)
+ for name, engine in self.engines.items():
+ model = engine.module
break
- def inference( self, image, temperature=1.0 ):
- image = self.transform(image).to(self.device)
+ image = self.transform(image).to(self.device).to(self.dtype)
+
+ with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
+ answer = model( image=[image], sampling_temperature=temperature )
+
+ answer = [ "".join(answer) ]
- answer = self.model( image=[image], sampling_temperature=temperature )
 	answer = answer[0].replace("<s>", "").replace("</s>", "") # it would be better to just slice between these, but I can't be assed
return answer
\ No newline at end of file
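A hedged usage sketch for the class above; the config and image paths are placeholders:

```python
from PIL import Image
from image_classifier.inference import Classifier

classifier = Classifier(config="./data/config.yaml")
image = Image.open("./data/path-to-your-image.png").convert("RGB")
print(classifier.inference(image))
```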
diff --git a/image_classifier/models/__init__.py b/image_classifier/models/__init__.py
old mode 100755
new mode 100644
index acfca70..49f5f1b
--- a/image_classifier/models/__init__.py
+++ b/image_classifier/models/__init__.py
@@ -1,18 +1,19 @@
from .base import Model
-def get_model(cfg):
+def get_model(cfg, training=False):
name = cfg.name
model = Model(
n_tokens=cfg.tokens,
n_len=cfg.len,
d_model=cfg.dim,
+ d_resnet=cfg.resnet,
)
- model._cfg = cfg
+ model.config = cfg
print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
return model
-def get_models(models):
- return { model.full_name: get_model(model) for model in models }
+def get_models(models, training=False):
+ return { model.full_name: get_model(model, training=training) for model in models }
diff --git a/image_classifier/models/base.py b/image_classifier/models/base.py
index 99b0322..a1af8c0 100755
--- a/image_classifier/models/base.py
+++ b/image_classifier/models/base.py
@@ -12,7 +12,7 @@ from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
-from torchvision.models import resnet18
+from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
from ..data import get_symmap
@@ -20,12 +20,12 @@ class Model(nn.Module):
def __init__(
self,
n_tokens: int = 0, # number of token types
- n_len: int = 6, # how long a sequence can be
+ n_len: int = 12, # how long a sequence can be
d_model: int = 512,
+ d_resnet: int = 18,
):
super().__init__()
-
_symmap = get_symmap()
self.symmap = { f'{v}': k for k, v in _symmap.items() }
self.symmap['0'] = ""
@@ -36,8 +36,26 @@ class Model(nn.Module):
self.n_tokens = n_tokens
self.n_len = n_len + 2 # start/stop tokens
self.d_model = d_model
+ self.d_resnet = d_resnet
- self.resnet = resnet18(pretrained=False)
+		# map depth -> constructor; fall back to resnet18 for unknown depths
+		resnets = { 18: resnet18, 34: resnet34, 50: resnet50, 101: resnet101, 152: resnet152 }
+		ResNet = resnets.get( d_resnet, resnet18 )
+		print(f"Using {ResNet.__name__}")
+
+		self.resnet = ResNet(pretrained=False)
self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len )
self.accuracy_metric = MulticlassAccuracy(
@@ -61,33 +79,29 @@ class Model(nn.Module):
sampling_temperature: float = 1.0,
):
- x_list = torch.stack( image, dim=0 )
-
- x = self.resnet( x_list )
- y = x.view(x.size(0), self.n_len, self.n_tokens)
+ logits = self.resnet( torch.stack( image, dim=0 ) )
+ logits = logits.view(logits.size(0), self.n_len, self.n_tokens).permute(1, 0, 2)
- # either of these should do, but my VALL-E forward pass uses this, so might as well keep to it
- # pred = y.argmax(dim=2)
- pred = Categorical(logits=y / sampling_temperature).sample()
-
- answer = [ "".join([ self.symmap[f'{x.item()}'] for x in t ]) for t in pred ]
+ pred = logits.argmax(dim=2)
if text is not None:
- y_list = rearrange(pad_sequence(text), "t b -> b t")
-
- loss = 0
+			labels = pad_sequence(text) # already (t, b); rearranging to (b, t) and permuting back was a no-op
+ loss = []
for i in range(self.n_len):
- if i >= y_list.shape[1]:
+ if i >= labels.shape[0]:
break
- loss += F.cross_entropy( y[:, i], y_list[:, i] )
-
+ loss.append( F.cross_entropy(logits[i], labels[i]) )
+
self.loss = dict(
- nll=loss
+ nll = sum( loss ) / len( loss ),
)
self.stats = dict(
- acc = self.accuracy_metric( pred, y_list ),
- precision = self.precision_metric( pred, y_list ),
+ acc = self.accuracy_metric( pred, labels ),
+ precision = self.precision_metric( pred, labels ),
)
+
+
+		answer = [ "".join([ self.symmap[f'{x.item()}'] for x in t ]) for t in pred.t() ] # pred is (t, b) after the permute, so transpose to iterate per sample
return answer
\ No newline at end of file
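The loss above is per-position cross-entropy averaged over sequence positions; a small self-contained sketch of the same computation on dummy shapes (sizes arbitrary):

```python
import torch
import torch.nn.functional as F

n_len, batch, n_tokens = 8, 4, 40
logits = torch.randn(n_len, batch, n_tokens)         # (seq, batch, vocab), as after the permute
labels = torch.randint(0, n_tokens, (n_len, batch))  # (seq, batch)

loss = [F.cross_entropy(logits[i], labels[i]) for i in range(n_len)]
nll = sum(loss) / len(loss)  # matches the model's `nll` stat
```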
diff --git a/image_classifier/models/lora.py b/image_classifier/models/lora.py
new file mode 100644
index 0000000..87e6d7c
--- /dev/null
+++ b/image_classifier/models/lora.py
@@ -0,0 +1,214 @@
+# Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+from functools import partial
+import torch
+import torch.nn.functional as F
+import torch.nn.utils.parametrize as parametrize
+
+from transformers.pytorch_utils import Conv1D
+
+from torch import Tensor, nn
+
+import math
+from typing import Optional, List
+
+from ..utils import passes_policy
+
+# LoRA Linear for replacement
+# Pros: simple, just needs to reuse the replace_linear and copy weights
+# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
+class LoRALinear(nn.Linear):
+ def __init__(
+ self,
+
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+
+ rank: int = 4,
+ alpha: int = 1,
+
+ dropout: float = 0.1,
+ merge_weights: bool = False,
+ **kwargs,
+ ):
+ super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs)
+
+ self.rank = rank
+ self.alpha = alpha
+ self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
+ self.merge_weights = merge_weights
+ self.merged = False
+ self.enabled = True
+
+ self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
+ self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
+ self.scaling = self.alpha / self.rank
+
+ self.weight.requires_grad = False
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ super().reset_parameters()
+ # super silly but necessary because nn.Linear's constructor calls this
+ if hasattr(self, 'lora_A'):
+ nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
+ nn.init.zeros_( self.lora_B )
+
+ def train(self, mode: bool = True):
+ super().train(mode)
+
+ # training, separate lora from base weights
+ if mode and self.merge_weights and self.merged:
+ self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
+ self.merged = False
+
+ # not training, merge lora to base weights
+ if not mode and self.merge_weights and not self.merged:
+ self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
+ self.merged = True
+
+ def forward(self, x: torch.Tensor):
+ if not self.merged and self.enabled:
+ result = F.linear(x, self.weight, bias=self.bias)
+ result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
+ return result
+
+ return F.linear(x, self.weight, bias=self.bias)
+
+ @classmethod
+ def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
+ if device is None:
+ device = layer.weight.device
+ if dtype is None:
+ dtype = layer.weight.dtype
+ return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
+
+# Uses parametrization to inject LoRA weights
+# Pros: should work with any Linears
+# Cons: TBD
+class ParameterizedLoRA(nn.Module):
+ def __init__(
+ self,
+
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+
+ rank: int = 4,
+ alpha: int = 1,
+
+ dropout: float = 0.1,
+
+ device = None,
+ dtype = None
+ ):
+ super().__init__()
+ self.rank = rank
+ self.alpha = alpha
+ self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
+
+		self.lora_B = nn.Parameter( torch.zeros( (out_features, rank), device=device, dtype=dtype ) )
+		self.lora_A = nn.Parameter( torch.zeros( (rank, in_features), device=device, dtype=dtype ) )
+ self.scaling = self.alpha / self.rank
+ self.enabled = True
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
+ nn.init.zeros_( self.lora_B )
+
+ def forward(self, x: torch.Tensor):
+ if self.enabled:
+ return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
+ return x
+
+ @classmethod
+ def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
+ if device is None:
+ device = layer.weight.device
+ if dtype is None:
+ dtype = layer.weight.dtype
+ # swap because we're feeding the output as our input
+ # M$'s LoRA class arranges things to where this isn't necessary
+ return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
+
+ @classmethod
+ def from_conv1d( cls, layer, device = None, dtype = None, **kwargs ):
+ if device is None:
+ device = layer.weight.device
+ if dtype is None:
+ dtype = layer.weight.dtype
+
+ in_channels, out_channels = layer.weight.shape
+ # swap because we're feeding the output as our input
+ # M$'s LoRA class arranges things to where this isn't necessary
+ return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
+
+def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
+ device = next(model.parameters()).device
+ dtype = next(model.parameters()).dtype
+
+ modules = [ k.split('.') for k, m in model.named_modules() if passes_policy( policy, k ) ]
+
+ for *parent, k in modules:
+ name = '.'.join(parent)
+ layer = getattr( model.get_submodule(name), k )
+
+ if isinstance( layer, nn.Linear ):
+ target = nn.Linear
+ klass = ParameterizedLoRA if use_parametrize else LoRALinear
+ replacer = klass.from_linear
+		elif isinstance( layer, nn.Conv1d ):
+			if not use_parametrize:
+				continue # no LoRAConv1d replacement class exists; only the parametrized path supports convs
+			target = nn.Conv1d
+			klass = ParameterizedLoRA
+			replacer = klass.from_conv1d
+		elif isinstance( layer, Conv1D ):
+			if not use_parametrize:
+				continue
+			target = Conv1D
+			klass = ParameterizedLoRA
+			replacer = klass.from_conv1d
+ else:
+ continue
+
+ replacement = replacer( layer, device=device, dtype=dtype, **kwargs )
+
+ if use_parametrize:
+ parametrize.register_parametrization( layer, "weight", replacement )
+ else:
+ setattr( model.get_submodule(name), k, replacement )
+
+ return enable_lora( model )
+
+def enable_lora( model, mode = True ):
+ for name, module in model.named_modules():
+ if not isinstance( module, ParameterizedLoRA ) and not isinstance( module, LoRALinear ):
+ continue
+ module.enabled = mode
+ return model
+
+def disable_lora( model ):
+ return enable_lora( model, False )
+
+def freeze_non_lora_weights( model, embeddings = False ):
+ frozen_params = []
+
+ for name, param in model.named_parameters():
+ should = 'lora_' in name or (embeddings and "_emb" in name)
+
+ param.requires_grad_(should)
+
+ if not should:
+ frozen_params.append( param )
+
+ return frozen_params
+
+def lora_get_state_dict( state_dict, split = True ):
+ lora = { name: param for name, param in state_dict.items() if "lora_" in name }
+ if not split:
+ return lora
+
+ return lora, { name: param for name, param in state_dict.items() if "lora_" not in name }
+
+def lora_load_state_dict( model, state_dict ):
+ return model.load_state_dict( state_dict, strict = False )
\ No newline at end of file
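A hedged sketch of the non-parametrized path: wrap a single `Linear` with `LoRALinear` and split the adapter weights back out (the toy layer and shapes are arbitrary):

```python
import torch
from torch import nn
from image_classifier.models.lora import LoRALinear, lora_get_state_dict

base = nn.Linear(16, 32)
lora = LoRALinear.from_linear(base, rank=4, alpha=8)
lora.weight.data.copy_(base.weight.data)  # from_linear copies the shape, not the values

x = torch.randn(2, 16)
y = lora(x)  # base projection plus the (initially zero) low-rank update

adapters, base_weights = lora_get_state_dict(lora.state_dict(), split=True)
print(sorted(adapters))  # ['lora_A', 'lora_B']
```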
diff --git a/image_classifier/plot.py b/image_classifier/plot.py
new file mode 100644
index 0000000..4eef13f
--- /dev/null
+++ b/image_classifier/plot.py
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+
+import argparse
+import json
+import re
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pandas as pd
+
+from .config import cfg
+
+def plot(paths, args):
+ dfs = []
+
+ for path in paths:
+ with open(path, "r") as f:
+ text = f.read()
+
+ rows = []
+
+ pattern = r"(\{.+?\})\.\n"
+
+ for row in re.findall(pattern, text, re.DOTALL):
+ try:
+ row = json.loads(row)
+ except Exception as e:
+ continue
+
+ for model in args.models:
+ if f'{model.name}.{args.xs}' not in row:
+ continue
+ rows.append(row)
+ break
+
+ df = pd.DataFrame(rows)
+
+ if "name" in df:
+ df["name"] = df["name"].fillna("train")
+ else:
+ df["name"] = "train"
+
+ df["group"] = str(path.parents[args.group_level])
+ df["group"] = df["group"] + "/" + df["name"]
+
+ dfs.append(df)
+
+ df = pd.concat(dfs)
+
+ if args.min_x is not None:
+ for model in args.models:
+ df = df[args.min_x < df[f'{model.name}.{args.xs}']]
+
+ if args.max_x is not None:
+ for model in args.models:
+ df = df[df[f'{model.name}.{args.xs}'] < args.max_x]
+
+ for gtag, gdf in sorted(
+ df.groupby("group"),
+ key=lambda p: (p[0].split("/")[-1], p[0]),
+ ):
+ for model in args.models:
+ x = f'{model.name}.{args.xs}'
+ for ys in args.ys:
+ y = f'{model.name}.{ys}'
+
+ if gdf[y].isna().all():
+ continue
+
+ if args.min_y is not None:
+ gdf = gdf[args.min_y < gdf[y]]
+ if args.max_y is not None:
+ gdf = gdf[gdf[y] < args.max_y]
+
+ gdf[y] = gdf[y].ewm(10).mean()
+
+ gdf.plot(
+ x=x,
+ y=y,
+ label=f"{y}",
+ ax=plt.gca(),
+ marker="x" if len(gdf) < 100 else None,
+ alpha=0.7,
+ )
+
+ plt.gca().legend(
+ loc="center left",
+ fancybox=True,
+ shadow=True,
+ bbox_to_anchor=(1.04, 0.5),
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--xs", default="engine_step")
+ parser.add_argument("--ys", nargs="+", default="")
+ parser.add_argument("--model", nargs="+", default="*")
+
+ parser.add_argument("--min-x", type=float, default=-float("inf"))
+ parser.add_argument("--min-y", type=float, default=-float("inf"))
+ parser.add_argument("--max-x", type=float, default=float("inf"))
+ parser.add_argument("--max-y", type=float, default=float("inf"))
+
+ parser.add_argument("--filename", default="log.txt")
+	parser.add_argument("--group-level", type=int, default=1)
+ args, unknown = parser.parse_known_args()
+
+ path = cfg.rel_path / "logs"
+ paths = path.rglob(f"./*/{args.filename}")
+
+ args.models = [ model for model in cfg.model.get() if model.training and (args.model == "*" or model.name in args.model) ]
+
+ if args.ys == "":
+ args.ys = ["loss"]
+
+ plot(paths, args)
+
+ out_path = cfg.rel_path / "metrics.png"
+ plt.savefig(out_path, bbox_inches="tight")
\ No newline at end of file
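The regex in `plot()` scrapes JSON stat lines out of `log.txt`; a tiny self-contained check of how it matches (the log line format here is illustrative):

```python
import json, re

text = 'step {"it": 1, "module.loss": 0.5}.\nstep {"it": 2, "module.loss": 0.4}.\n'
rows = [json.loads(m) for m in re.findall(r"(\{.+?\})\.\n", text, re.DOTALL)]
print(rows)  # [{'it': 1, 'module.loss': 0.5}, {'it': 2, 'module.loss': 0.4}]
```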
diff --git a/image_classifier/samplers.py b/image_classifier/samplers.py
new file mode 100644
index 0000000..74ef9f0
--- /dev/null
+++ b/image_classifier/samplers.py
@@ -0,0 +1,204 @@
+import math
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+from torch import Tensor, einsum, nn
+
+# Simple filter to modify a token's probability if it shows up in the past
+# `one_time` will only apply the penalty once
+# `decay` is a factor that will exponentially apply to how far away it is
+def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ):
+ if factor == 1.0 or previous is None:
+ return logits
+
+ unique = set()
+ priors = reversed(previous)
+ for distance, token in enumerate(priors):
+ # skip if we're only applying the decay once
+ if one_time and token in unique:
+ continue
+
+ distance += 1
+ logits[:, token] /= factor * (distance ** decay)
+
+ # add to set if we care about it
+ if one_time:
+ unique.add(token)
+
+ return logits
+
+# Simple "filter" that modifies the logit for the stop token, based on the sequence length
+# `length` is the length of the sequence currently
+# `factor` is the power the length is raised to, so values > 0 will yield longer sequences, values < 0 will yield shorter sequences
+# `token` is the stop token.
+def length_penalize( logits, length, factor=0.0, token=-1 ):
+ if factor == 0.0:
+ return logits
+
+ logits[:, token] /= (length ** factor)
+ return logits
+
+# Simple way to ban tokens
+def ban_tokens( logits, tokens ):
+ for token in tokens:
+		# skip tokens that aren't in these logits
+		if token >= logits.shape[-1]:
+			continue
+ logits[:, token] = -float("inf")
+ return logits
+
+# Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+ Args:
+ logits: logits distribution shape (batch size, vocabulary size)
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+ Make sure we keep at least min_tokens per batch example in the output
+ """
+ if top_k > 0:
+ top_k = min(max(top_k, min_tokens), logits.size(-1)) # Safety check
+ # Remove all tokens with a probability less than the last token of the top-k
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits[indices_to_remove] = filter_value
+
+ if top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+ sorted_indices_to_remove = cumulative_probs > top_p
+ if min_tokens > 1:
+ # Keep at least min_tokens (set to min_tokens-1 because we add the first one below)
+ sorted_indices_to_remove[..., :min_tokens] = 0
+ # Shift the indices to the right to keep also the first token above the threshold
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+ sorted_indices_to_remove[..., 0] = 0
+
+ # scatter sorted tensors to original indexing
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+ logits[indices_to_remove] = filter_value
+
+ return logits
+
+# credit to https://github.com/LostRuins/koboldcpp/pull/464 // https://github.com/kalomaze/koboldcpp/tree/dynamic-temp
+def dynamic_temperature( logits, temperature=1.0, min_temperature = 0.0, k = 10, sigmoidCenterPoint = 0.5 ):
+ # loop over logits[:], as the NAR will have logits.shape[0] > 1
+ for i in range(logits.shape[0]):
+ sum_exp = 0.0
+ maximum = torch.max( logits[i] )
+ for logit in logits[i]:
+ sum_exp += math.exp( logit - maximum )
+
+ prob_max_token_before_temp = 1.0 / sum_exp
+ dynamic_temperature = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (prob_max_token_before_temp - sigmoidCenterPoint)))
+
+ logits[i] /= dynamic_temperature
+
+ return logits
+
+
+
+# picks the top K tokens amongst a batch of logits
+# logits: [Tensor] list of logits
+# candidates: [(batch, token)] list, where batch indicates the index of the logits the given token is from
+def top_k_logits_list( logits_list, k ):
+ # ( batch, tokens ) => ( batch x tokens )
+ logits = torch.cat( logits_list )
+ candidates = list(torch.topk(logits.flatten(), k).indices.tolist()) # perform top-k across all logits
+ for i, index in enumerate(candidates):
+ t = []
+ N = np.prod(logits.size())
+ for n in logits.size():
+ N //= n
+ t.append(index // N)
+ index %= N
+ candidates[i] = tuple(t)
+ return candidates
+
+
+# Credit to: https://github.com/basusourya/mirostat/
+# performs mirostat-based sampling
+# logits: Tensor of logit probabilities
+# state: the mirostat state
+def mirostat_sample( logits, state = None ):
+ def compute_k(prob, n, tau):
+ num = 0
+ den = 0
+ for i in range(100):
+ b = prob[i]/prob[i+1]
+ t = (i+2)/(i+1)
+ num += math.log(b)*math.log(t)
+ den += math.log(t)**2
+
+ s = num/den
+ eps = s-1
+ k = ((eps*(2**(tau)))/(1-n**(-eps)))**(1/s)
+ k = round(k)
+ return k
+
+ if "max_surprise" not in state:
+ state["max_surprise"] = state["tau"] * 2
+
+ if "error_surprise" not in state:
+ state["error_surprise"] = 0
+
+ if "running_total_surprise" not in state:
+ state["running_total_surprise"] = 0
+
+ sorted_logits, sorted_indices = torch.sort( logits[-1, :], descending=True )
+ prob_original = torch.softmax( sorted_logits, dim=-1 ).tolist()
+
+ k = compute_k(prob_original, state["n"], state["max_surprise"]) + 1
+
+ sorted_logits = sorted_logits[0:k]
+ sorted_indices = sorted_indices[0:k]
+ prob_topk = torch.softmax(sorted_logits, dim = 0)
+ prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)
+
+ state["index_surprise"] = math.log2(1/prob_original[prev_i])
+ state["running_total_surprise"] += state["index_surprise"]
+ state["error_surprise"] = state["index_surprise"] - state["tau"]
+ state["max_surprise"] -= state["eta"] * state["error_surprise"]
+ state["token"] = sorted_indices[prev_i]
+
+ return state
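+# illustrative usage; the seed values here are assumptions mirroring the reference repo:
+#   state = { "n": logits.shape[-1], "tau": 3.0, "eta": 0.1 }
+#   state = mirostat_sample( logits, state )
+#   token = state["token"]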
+
+# Credits to: https://github.com/oobabooga/text-generation-webui/pull/5677
+# performs DRY sampling
+# * (honestly it looks close to rep pen anyways but what do I know)
+# `logits` are the scores used to sample against
+# `previous` are the prior tokens to penalize with
+# `factor` is the scalar multiplier
+# `base` is the base number to raise to the (length - allowed_length)th power
+# `allowed_length` limits the range to apply DRY to
+def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2 ):
+	if factor == 0.0 or previous is None:
+		return logits
+
+	lengths = {}
+	for i, token in enumerate( previous ):
+		length = 1
+		while length < max(allowed_length, 50):
+			j = i - length
+
+			# Start of input reached.
+			if j < 0:
+				break
+
+			# Start of match reached.
+			if previous[j] != previous[-length-1]:
+				break
+
+			length += 1
+
+		lengths[token] = max(length, lengths[token]) if token in lengths else length
+
+	for token, length in lengths.items():
+		# skip rather than break: dict iteration isn't sorted by length, so bailing early would miss tokens
+		if length < allowed_length:
+			continue
+		logits[:, token] -= factor * base ** (length - allowed_length)
+
+	return logits
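+# e.g. a sketch: with base=1.75 and allowed_length=2, a token that would extend a
+# 4-token repeat of the tail of `previous` gets penalized by factor * 1.75 ** 2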
\ No newline at end of file
diff --git a/image_classifier/train.py b/image_classifier/train.py
index 70de575..710e02b 100755
--- a/image_classifier/train.py
+++ b/image_classifier/train.py
@@ -4,7 +4,7 @@ from .config import cfg
from .data import create_train_val_dataloader
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
-from .utils.trainer import load_engines
+from .utils.distributed import is_global_leader
import json
import logging
@@ -12,53 +12,43 @@ import random
import torch
import torch.nn.functional as F
import traceback
+import shutil
from collections import defaultdict
-from PIL import Image
+
from tqdm import tqdm
+import argparse
+from PIL import Image, ImageDraw
_logger = logging.getLogger(__name__)
+
def train_feeder(engine, batch):
- engine( image=batch["image"], text=batch["text"] )
+ with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+ batch_size = len(batch["text"])
+ engine.current_batch_size = batch_size
- losses = engine.gather_attribute("loss")
- stat = engine.gather_attribute("stats")
+ engine( image=batch["image"], text=batch["text"] )
- loss = torch.stack([*losses.values()]).sum()
+ losses = engine.gather_attribute("loss")
+ stat = engine.gather_attribute("stats")
+
+ loss = torch.stack([*losses.values()]).sum()
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= {k: v.item() for k, v in stat.items()}
+ engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
+
return loss, stats
@torch.inference_mode()
def run_eval(engines, eval_name, dl):
- engines_stats = {
- 'eval': eval_name
- }
-
- model = None
- names = []
- for name, engine in engines.items():
- names.append(name)
- model = engine
- break
-
-
stats = defaultdict(list)
stats['loss'] = []
- def process( name, batch, resps_list ):
- for path, ref, hyp in zip(batch["path"], batch["text"], hyp):
- continue
-
- for batch in tqdm(dl):
- batch: dict = to_device(batch, cfg.device)
-
- res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
-
+ def process( name, batch, res, loss ):
for path, ref, hyp in zip(batch["path"], batch["text"], res):
hyp = hyp.replace('', "").replace("", "")
hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / hyp).with_suffix(".png")
@@ -67,36 +57,74 @@ def run_eval(engines, eval_name, dl):
image = Image.open(path).convert('RGB')
image.save(hyp_path)
- losses = engine.gather_attribute("loss")
- loss = torch.stack([*losses.values()]).sum().item()
-
stats['loss'].append(loss)
+
+	processed = 0
+	iterator = iter(dl) # a fresh iter() each pass would re-yield the same first batch on ordered loaders
+	while processed < cfg.evaluation.size:
+		batch = to_device(next(iterator, None), cfg.device)
+		if batch is None: # dataloader exhausted before reaching cfg.evaluation.size
+			break
+
+ # limit to eval batch size in the event we somehow have a weird dataloader
+ for key in batch.keys():
+ batch[key] = batch[key][:cfg.evaluation.batch_size]
+
+ processed += len(batch["text"])
+
+ for name in engines:
+ engine = engines[name]
+
+ res = engine( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
+ losses = engine.gather_attribute("loss")
+ loss = torch.stack([*losses.values()]).sum().item()
+
+ process( name, batch, res, loss )
+
stats = {k: sum(v) / len(v) for k, v in stats.items()}
- engines_stats.update(flatten_dict({ name: stats }))
-
- iteration = engines.global_step
- engines_stats['it'] = iteration
- engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
+ engines_stats = {
+ f'{name}.{eval_name}': stats,
+ "it": engines.global_step,
+ }
+ #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
-def main():
+def train():
+ parser = argparse.ArgumentParser("ResNet Image Classifier")
+ parser.add_argument("--eval", action="store_true", default=None)
+ args, unknown = parser.parse_known_args()
+
+ # create log folder
setup_logging(cfg.log_dir)
+ # copy config yaml to backup
+ if cfg.yaml_path is not None and is_global_leader():
+ shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
def eval_fn(engines):
+ do_gc()
+ engines.eval()
+ # wrapped in a try block because it's sometimes prone to breaking
try:
run_eval(engines, "subtrain", subtrain_dl)
run_eval(engines, "val", val_dl)
except Exception as e:
- print("Error occurred while performing eval:", str(e))
- print(traceback.format_exc())
+ _logger.warning(f"Error occurred while performing eval: {str(e)}")
+ _logger.warning(traceback.format_exc())
+ engines.train()
do_gc()
+ if args.eval:
+ return eval_fn(engines=trainer.load_engines())
+
+ """
+ if cfg.trainer.load_webui:
+ from .webui import start
+ start(lock=False)
+ """
+
trainer.train(
train_dl=train_dl,
train_feeder=train_feeder,
@@ -104,4 +132,5 @@ def main():
)
if __name__ == "__main__":
- main()
\ No newline at end of file
+	# to-do: for DDP, spawn multiprocess instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m image_classifier.train --yaml="./data/config.yaml"`
+	train()
diff --git a/image_classifier/utils/__init__.py b/image_classifier/utils/__init__.py
index 96929f3..d79e335 100755
--- a/image_classifier/utils/__init__.py
+++ b/image_classifier/utils/__init__.py
@@ -7,4 +7,7 @@ from .utils import (
to_device,
tree_map,
do_gc,
+ set_seed,
+ passes_policy,
+ get_devices
)
\ No newline at end of file
diff --git a/image_classifier/utils/distributed.py b/image_classifier/utils/distributed.py
index e80b0dd..2400fdd 100755
--- a/image_classifier/utils/distributed.py
+++ b/image_classifier/utils/distributed.py
@@ -8,6 +8,10 @@ import socket
from functools import cache, wraps
from typing import Callable
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
def get_free_port():
sock = socket.socket()
sock.bind(("", 0))
@@ -15,13 +19,18 @@ def get_free_port():
_distributed_initialized = False
-def init_distributed( fn ):
- fn()
+def init_distributed( fn, *args, **kwargs ):
+	global _distributed_initialized # without this, the assignment below only sets a local
+	torch.cuda.set_device(local_rank())
+	fn(*args, **kwargs)
_distributed_initialized = True
def distributed_initialized():
return _distributed_initialized
+def cleanup_distributed():
+ dist.barrier()
+ dist.destroy_process_group()
+
@cache
def fix_unset_envs():
envs = dict(
@@ -44,10 +53,12 @@ def fix_unset_envs():
def local_rank():
return int(os.getenv("LOCAL_RANK", 0))
-
def global_rank():
return int(os.getenv("RANK", 0))
+def world_size():
+ return int(os.getenv("WORLD_SIZE", 1))
+
def is_local_leader():
return local_rank() == 0
@@ -86,4 +97,7 @@ def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable:
if fn is None:
return wrapper
- return wrapper(fn)
\ No newline at end of file
+ return wrapper(fn)
+
+def ddp_model(model):
+ return DDP(model.to(device='cuda'), [local_rank()], find_unused_parameters=True)
\ No newline at end of file
diff --git a/image_classifier/utils/io.py b/image_classifier/utils/io.py
new file mode 100644
index 0000000..afc2033
--- /dev/null
+++ b/image_classifier/utils/io.py
@@ -0,0 +1,88 @@
+import torch
+import json
+
+from pathlib import Path
+from safetensors import safe_open as sft_load
+from safetensors.torch import save_file as sft_save
+
+def coerce_path( path ):
+ return path if isinstance( path, Path ) else Path(path)
+
+def pick_path( path, *suffixes ):
+ suffixes = [*suffixes]
+
+ for suffix in suffixes:
+ p = path.with_suffix( suffix )
+ if p.exists():
+ return p
+
+ return path
+
+def is_dict_of( d, t ):
+	if not isinstance( d, dict ):
+		return False
+
+	return all([ isinstance(v, t) for v in d.values() ])
+
+# handles converting the usual pth state_dict into just the dict with the tensors + a dict of JSON strings, for safetensors
+def state_dict_to_tensor_metadata( data: dict, module_key=None ):
+ metadata = None
+
+ # is a state_dict, no need to coerce
+ if is_dict_of( data, torch.Tensor ):
+ return data, metadata
+
+ # is maybe a dict with a state dict + metadata, coerce it
+ metadata = {}
+ target = module_key
+ if not target:
+ for k, v in data.items():
+ # is a dict of tensors, our target
+ if is_dict_of( v, torch.Tensor ):
+ target = k
+ continue # continue to iterate to grab other metadata
+
+ # not a dict of tensors, put it as metadata
+ try:
+ metadata[k] = json.dumps(v)
+ except Exception as e:
+ pass
+
+	if not target:
+		raise Exception(f'Requesting to save safetensors of a state dict, but the dict contains no dict of torch.Tensor (keys: {list(data.keys())})')
+
+ return data[target], metadata
+
+def torch_save( data, path, module_key=None ):
+	path = coerce_path(path)
+	ext = path.suffix
+
+	if ext in [".safetensor", ".safetensors", ".sft"]:
+ data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key )
+
+ return sft_save( data, path, metadata )
+
+ return torch.save( data, path )
+
+def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=True, module_key="module" ):
+	path = coerce_path(path)
+	ext = path.suffix
+
+	if ext in [".safetensor", ".safetensors", ".sft"]:
+		state_dict = {}
+		with sft_load(path, framework=framework, device=device) as f:
+			for k in f.keys():
+				state_dict[k] = f.get_tensor(k)
+
+			if load_metadata:
+				metadata = f.metadata() or {} # metadata is None when none was stored
+				for k, v in metadata.items():
+					try:
+						metadata[k] = json.loads( v )
+					except Exception as e:
+						pass
+				state_dict = { module_key: state_dict } | metadata
+
+ return state_dict
+
+ return torch.load( path, map_location=torch.device(device), weights_only=not unsafe )
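+# illustrative round-trip under these helpers' assumptions (the "symmap" userdata is
+# just an example, borrowed from what the trainer exports):
+#   torch_save( { "module": model.state_dict(), "symmap": symmap }, "ckpt.sft" )
+#   state = torch_load( "ckpt.sft" ) # => { "module": { ...tensors... }, "symmap": { ... } }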
\ No newline at end of file
diff --git a/image_classifier/utils/sampler.py b/image_classifier/utils/sampler.py
old mode 100755
new mode 100644
index 5db9606..d5b6b76
--- a/image_classifier/utils/sampler.py
+++ b/image_classifier/utils/sampler.py
@@ -1,48 +1,164 @@
-"""
-A sampler that balances data by key_fns.
-
-MIT License
-
-Copyright (c) 2023 Zhe Niu
-
-niuzhe.nz@outlook.com
-"""
-
-import random
-
-
-class Sampler:
- def __init__(self, l, key_fns):
- self.tree = self._build(l, key_fns)
-
- def _build(self, l, key_fns) -> dict[dict, list]:
- if not key_fns:
- return l
-
- tree = {}
-
- key_fn, *key_fns = key_fns
-
- for x in l:
- k = key_fn(x)
-
- if k in tree:
- tree[k].append(x)
- else:
- tree[k] = [x]
-
- for k in tree:
- tree[k] = self._build(tree[k], key_fns)
-
- return tree
-
- def _sample(self, tree: dict | list):
- if isinstance(tree, list):
- ret = random.choice(tree)
- else:
- key = random.choice([*tree.keys()])
- ret = self._sample(tree[key])
- return ret
-
- def sample(self):
- return self._sample(self.tree)
\ No newline at end of file
+from dataclasses import dataclass
+from typing import Any
+import random
+
+import torch
+from torch.utils.data import Sampler
+
+from .distributed import global_rank, local_rank, world_size
+
+# Randomly picks an index from an array of indices
+class PoolSampler():
+ def __init__( self, pool = [], keep_all = False, shuffle = False ):
+ self.length = len(pool)
+ self.shuffle = shuffle
+ self.global_pool = pool if keep_all else None
+ self.global_indices = [ i for i in range(self.length) ]
+ self.reset()
+
+ def reset(self):
+ self.current_pool = [ i for i in self.global_indices ]
+ if self.shuffle:
+ random.shuffle(self.current_pool)
+
+ def sample(self, pool = None):
+ if pool is None:
+ pool = self.global_pool
+ # check if we need to reset
+ index = random.choice( self.current_pool )
+ # remove from pool
+ self.current_pool.remove(index)
+ # reset if needed
+ if len(self.current_pool) == 0:
+ self.reset()
+ # map indices to our real values
+ return pool[index] if pool is not None else index
+
+ def __len__(self):
+ return self.length # len(self.current_pool)
+
+ def __iter__(self):
+ while len(self.current_pool) > 0:
+ yield self.sample()
+
+ def __call__(self, *args, **kwargs):
+ return self.sample(*args, **kwargs)
+
+ def get_state(self):
+ return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
+
+ def set_state(self, state):
+ self.length = state["length"]
+ self.global_pool = state["global_pool"]
+ self.global_indices = state["global_indices"]
+ self.current_pool = state["current_pool"]
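+# e.g. (sketch): sampler = PoolSampler( paths, keep_all=True, shuffle=True ); path = sampler()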
+
+# "Samples" through a fixed sequence from 0 to length
+# Necessary for our "shuffle+sort by duration+interleave" sampling method
+# Allows saving and loading state
+class OrderedSampler(Sampler):
+ def __init__( self, length ):
+ self.position = 0
+ self.length = length
+
+ def __len__(self):
+ return self.length
+
+ def __iter__(self):
+ if self.position >= self.length:
+ self.position = 0
+
+ while self.position < self.length:
+ yield self.position
+ self.position += 1
+
+ def get_state(self):
+ return { "position": self.position, "length": self.length }
+
+ def set_state(self, state):
+ self.position = state["position"]
+ self.length = state["length"]
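+# e.g. (sketch): DataLoader( dataset, batch_size=..., sampler=OrderedSampler( len(dataset) ) )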
+
+# Like the above, but will batch based on token count
+class BatchedOrderedSampler(Sampler):
+	def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
+		self.position = 0
+		self.batches = []
+		self.shuffle = shuffle
+
+		assert max_duration != 0 or max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
+
+ current_batch = []
+ current_size = 0
+ current_index = 0
+ for key, bucket in buckets.items():
+ for path, duration in bucket:
+ # flush
+ should_flush = False
+ if max_duration > 0 and current_size + duration > max_duration:
+ should_flush = True
+ elif max_batch_size > 0 and len(current_batch) >= max_batch_size:
+ should_flush = True
+
+ if should_flush and len(current_batch) > 0:
+ self.batches.append( current_batch )
+ current_batch = []
+ current_size = 0
+
+ current_batch.append( current_index )
+ current_index += 1
+ current_size += duration
+
+ if self.shuffle:
+ random.shuffle(self.batches)
+
+ def __len__(self):
+ return len(self.batches)
+
+ def __iter__(self):
+ if self.position >= len(self.batches):
+ self.position = 0
+ if self.shuffle:
+ random.shuffle(self.batches)
+
+ while self.position < len(self.batches):
+ yield self.batches[self.position]
+ self.position += 1
+
+ def get_state(self):
+ return { "position": self.position, "batches": self.batches }
+
+ def set_state(self, state):
+ self.position = state["position"]
+ self.batches = state["batches"]
+
+# Randomly samples indices from a given sequence from 0 to length
+# Allows saving and loading state
+class RandomSampler(Sampler):
+ def __init__( self, length ):
+ self.position = 0
+ self.length = length
+
+ self.generator = torch.Generator()
+ self.perm = torch.randperm(self.length, generator=self.generator)
+
+ def __len__(self):
+ return self.length
+
+ def __iter__(self):
+ if self.position >= self.length:
+ self.position = 0
+ self.perm = torch.randperm(self.length, generator=self.generator)
+
+ while self.position < self.length:
+ yield self.perm[self.position]
+ self.position += 1
+
+ def get_state(self):
+ return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }
+
+ def set_state(self, state):
+ self.position = state["position"]
+ self.length = state["length"]
+ self.perm = state["perm"]
+ self.generator.set_state(state["generator"])
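+# all four samplers expose get_state()/set_state(), so the dataloader ordering can be
+# checkpointed and restored alongside the model weights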
\ No newline at end of file
diff --git a/image_classifier/utils/trainer.py b/image_classifier/utils/trainer.py
index 88f6358..7c64470 100755
--- a/image_classifier/utils/trainer.py
+++ b/image_classifier/utils/trainer.py
@@ -4,12 +4,13 @@
import humanize
import json
-import os
import logging
+import numpy as np
import random
import selectors
import sys
import torch
+import os
from functools import cache
from torch.distributed import broadcast_object_list
@@ -18,9 +19,10 @@ from tqdm import tqdm
from typing import Protocol
from ..config import cfg
-from .distributed import init_distributed, distributed_initialized
from .distributed import (
- fix_unset_envs,
+ init_distributed,
+ distributed_initialized,
+ world_size,
global_leader_only,
global_rank,
is_global_leader,
@@ -28,73 +30,15 @@ from .distributed import (
local_leader_only,
)
-from ..engines import Engine, Engines, TrainFeeder, default_feeder
-from ..models import get_models
+from ..engines import Engine, Engines, TrainFeeder, default_feeder, load_engines
-from .utils import to_device, do_gc
+from .utils import to_device, do_gc, truncate_json
from ..utils import wrapper as ml
+from ..data import get_symmap # should decouple from this trainer script
_logger = logging.getLogger(__name__)
-_engines: Engines
_command: str
-def get_global_step():
- try:
- return _engines.global_step
- except:
- return None
-
-def get_micro_step():
- try:
- return _engines.micro_step
- except:
- return None
-
-def get_cmd():
- try:
- return _command
- except:
- raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
-
-
-get_iteration = get_global_step
-
-def load_engines():
- models = get_models(cfg.models.get())
- engines = dict()
-
- for name in models:
- model = models[name]
-
- optimizer = None
- lr_scheduler = None
-
- if cfg.hyperparameters.optimizer.lower() == "adamw":
- optimizer = ml.AdamW(
- model.parameters(),
- lr=cfg.hyperparameters.learning_rate,
- betas=(0.9, 0.96),
- eps=1e-07,
- weight_decay=0.01,
- )
-
- if cfg.trainer.load_state_dict:
- load_path = cfg.ckpt_dir / name / "fp32.pth"
- model.load_state_dict(torch.load(load_path))
-
- engines[name] = Engine(
- model=model,
- optimizer=optimizer,
- lr_scheduler=lr_scheduler,
- )
-
- engines = Engines(engines)
- engines.setup()
-
- if not cfg.trainer.load_state_dict:
- engines.load_checkpoint()
-
- return engines
class EvalFn(Protocol):
def __call__(self, *, engines: Engines):
@@ -151,17 +95,16 @@ def _non_blocking_input():
l[0] = s
- if distributed_initialized():
+ if world_size() > 1:
broadcast_object_list(l, src=0)
_command = l[0]
return _command
-
def _make_infinite_epochs(dl):
while True:
- _logger.info("New epoch starts.")
- yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
+ #_logger.info("New epoch starts.")
+ yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
@local_leader_only(default=None)
@@ -172,30 +115,32 @@ def logger(data):
def seed(seed):
# Set up random seeds, after fork()
random.seed(seed + global_rank())
- #np.random.seed(seed + global_rank())
+ np.random.seed(seed + global_rank())
torch.manual_seed(seed + global_rank())
-
def train(
train_dl: DataLoader,
train_feeder: TrainFeeder = default_feeder,
eval_fn: EvalFn = lambda x: ...,
logger: Logger = logger,
):
- fix_unset_envs()
-
engines = load_engines()
+	# validate that there's at least one model set to train
+	found = False
+	for name, engine in engines.items():
+		if engine.training:
+			found = True
+			break
+	if not found:
+		raise Exception('Training, but no loaded model is set to train...')
+
"""
if is_local_leader():
cfg.dump()
_logger.info(cfg)
"""
- # Setup global engines
- global _engines
- _engines = engines
-
events = []
eval_fn = global_leader_only(eval_fn)
@@ -203,15 +148,20 @@ def train(
# Pre-loop command
command = _non_blocking_input()
if command in ["eval", "eval_quit"]:
- engines.eval()
eval_fn(engines=engines)
- engines.train()
+
if command in ["quit", "eval_quit"]:
+ engines.quit()
return
last_save_step = engines.global_step
last_eval_step = 0
+ """
+ if cfg.distributed:
+ train_dl.sampler.set_epoch(int(engines.global_samples / len(train_dl.dataset.paths)))
+ """
+
# Training loop
for batch in _make_infinite_epochs(train_dl):
if engines.global_step >= cfg.trainer.iterations:
@@ -219,17 +169,15 @@ def train(
#batch = to_device(batch, torch.cuda.current_device())
stats = engines.step(batch=batch, feeder=train_feeder)
-
- iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
- stats['it'] = iteration
- stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
-
- del stats['batch_size']
- del stats['wall_time']
- del stats['global_step']
+ stats['epoch'] = engines.global_samples / (len(train_dl.dataset.paths) * world_size())
elapsed_time = stats.get("elapsed_time", 0)
- _logger.info(f"Training Metrics: {json.dumps(stats)}.")
+ try:
+ metrics = json.dumps(stats)
+ except Exception as e:
+ metrics = str(stats)
+
+ _logger.info(f"Training Metrics: {truncate_json(metrics)}.")
command = _non_blocking_input()
@@ -267,29 +215,48 @@ def train(
if "lr" in command:
rate = float(command.split(" ")[-1])
- engines.set_lr(rate)
- print("Updating LR to:", rate)
+ try:
+ engines.set_lr(rate)
+ _logger.info(f"Updating LR to: {rate}")
+ except Exception as e:
+ _logger.warning(f"Failed to set LR rate to: {rate}, {str(e)}")
+
+ if "export" in command:
+ train_dl.dataset.save_state_dict()
+ engines.save_checkpoint()
+ last_save_step = engines.global_step
+
+ if is_global_leader():
+ engines.export(userdata={"symmap": get_symmap()})
save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
saving_commands = ["save"]
+ export_commands = ["export"]
if cfg.trainer.save_on_quit:
saving_commands.append("quit")
+ if cfg.trainer.export_on_quit:
+ export_commands.append("quit")
+
+ if cfg.trainer.export_on_save:
+ export_commands.append("save")
+
if engines.global_step != last_save_step:
if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
+ train_dl.dataset.save_state_dict()
engines.save_checkpoint()
last_save_step = engines.global_step
+
+ if command in export_commands and is_global_leader():
+ engines.export(userdata={"symmap": get_symmap()})
if engines.global_step != last_eval_step:
if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
- do_gc()
-
- engines.eval()
- eval_fn(engines=engines)
- engines.train()
last_eval_step = engines.global_step
+ eval_fn(engines=engines)
if command in ["quit"]:
- return
\ No newline at end of file
+ engines.quit()
+ return
diff --git a/image_classifier/utils/utils.py b/image_classifier/utils/utils.py
index 988f595..a253848 100755
--- a/image_classifier/utils/utils.py
+++ b/image_classifier/utils/utils.py
@@ -7,8 +7,16 @@ from .distributed import global_rank, local_rank, global_leader_only
import gc
import logging
import pandas as pd
+import numpy as np
import re
import torch
+import random
+import time
+import psutil
+import math
+
+_logger = logging.getLogger(__name__)
from coloredlogs import ColoredFormatter
from logging import StreamHandler
@@ -16,9 +24,16 @@ from pathlib import Path
from torch import Tensor, nn
from tqdm.auto import tqdm
from typing import Callable, TypeVar, overload
-
+from contextlib import contextmanager
T = TypeVar("T")
+def truncate_json( s ): # renamed from `str` to avoid shadowing the builtin
+	def fun( match ):
+		return "{:.4f}".format(float(match.group()))
+
+	return re.sub(r"\d+\.\d{8,}", fun, s)
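+# e.g. truncate_json('{"loss": 0.123456789012}') => '{"loss": 0.1235}'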
+
def do_gc():
gc.collect()
torch.cuda.empty_cache()
@@ -28,6 +43,14 @@ def flatten_dict(d):
return records[0] if records else {}
+def set_seed(seed=None):
+	if seed is None: # `not seed` would also clobber a legitimate seed of 0
+		seed = int(time.time())
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
def _get_named_modules(module, attrname):
for name, module in module.named_modules():
if hasattr(module, attrname):
@@ -155,5 +178,363 @@ def tree_map(fn: Callable, x):
return x
-def to_device(x: T, device) -> T:
- return tree_map(lambda t: t.to(device), x)
+def to_device(x: T | None, *args, **kwargs) -> T:
+ if x is None:
+ return
+
+ return tree_map(lambda t: t.to(*args, **kwargs), x)
+
+def coalesce( *args, return_last=True ): # returns the last (or first) non-None argument
+	return [ x for x in args if x is not None ][-1 if return_last else 0]
+
+# checks if a module name is within a given whitelist/blacklist policy dict
+def passes_policy( policy, name ):
+ if policy is None:
+ return True
+
+ if "exclude" in policy:
+ for term in policy["exclude"]:
+ if term in name:
+ return False
+
+ if "include" in policy:
+ for term in policy["include"]:
+ if term in name:
+ return True
+
+ return False
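+# e.g. passes_policy( {"include": ["model"], "exclude": ["embedding"]}, name ) keeps
+# names containing "model" unless they also contain "embedding" (exclusions win)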
+
+# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
+@contextmanager
+def autocast(input, from_dtype, to_dtype):
+ if input.dtype == from_dtype:
+ input = input.to(to_dtype)
+ yield input
+ input = input.to(from_dtype)
+ else:
+ yield input
+
+@contextmanager
+def autocasts(input, from_dtype, to_dtype):
+ if input.dtype in from_dtype:
+ from_dtype = input.dtype
+ input = input.to(to_dtype)
+ yield input
+ input = input.to(from_dtype)
+ else:
+ yield input
+
+# handles temporarily upcasting 'index tensors' so torch will stop bitching
+def autocast_forward( func ):
+ def wrapper( self, input, *args, **kwargs ):
+ with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
+ return func( self, k, *args, **kwargs )
+ return wrapper
+
+# handles migrating an input tensor to a given device
+def auto_align_inputs_forward( module, device=None, name = None ):
+ func = module.forward
+
+ if device is None:
+ if hasattr( module, 'device' ):
+ device = module.device
+ else:
+ try:
+ device = next(module.parameters() if [*module.parameters()] else module.buffers()).device
+ except Exception as e:
+ return func
+
+
+ def wrapper( *args, **kwargs ):
+ args = [*args]
+ # search through args and kwargs for any Tensor arguments
+ for i, arg in enumerate(args):
+ if not isinstance( arg, torch.Tensor ):
+ continue
+ args[i] = arg.to( device=device )
+
+ for k, v in kwargs.items():
+ if not isinstance( v, torch.Tensor ):
+ continue
+ kwargs[k] = v.to( device=device )
+
+ # disgusting patch
+ if "position_embeddings" in kwargs:
+ kwargs["position_embeddings"] = tuple([ t.to(device=device) for t in kwargs["position_embeddings"] ])
+
+ return func( *args, **kwargs )
+ return wrapper
+
+# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
+# generalizing this would be super sugoi but there's no catch-all for arguments
+def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ):
+ bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
+
+ device = next(model.parameters()).device
+ dtype = next(model.parameters()).dtype
+ modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+
+ for *parent, k in modules:
+ name = '.'.join(parent)
+
+ m = getattr( model.get_submodule(name), k )
+
+ if isinstance(m, klass):
+ continue
+
+ kwargs = dict(
+ in_features = m.in_features,
+ out_features = m.out_features,
+ bias = m.bias is not None,
+ ) if not bnb else dict(
+ input_features=m.in_features,
+ output_features=m.out_features,
+ bias=m.bias is not None,
+ )
+
+ # overwrite
+ setattr(
+ model.get_submodule(name), k,
+ klass( **kwargs ).to(device=device, dtype=dtype)
+ )
+
+ if verbose:
+			_logger.info(f"Replacing {name}.{k} with: {klass}")
+
+ return model
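+# e.g. (sketch): replace_linear( model, bnb.nn.Linear8bitLt ) swaps every torch.nn.Linear
+# in-place, preserving each layer's shape, bias, device, and dtype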
+
+def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ):
+ device = next(model.parameters()).device
+ dtype = next(model.parameters()).dtype
+ modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+
+ for *parent, k in modules:
+ name = '.'.join(parent)
+
+ m = getattr( model.get_submodule(name), k )
+
+ if isinstance(m, klass):
+ continue
+
+ kwargs = dict(
+ num_embeddings=m.num_embeddings,
+ embedding_dim=m.embedding_dim,
+ padding_idx=m.padding_idx,
+ max_norm=m.max_norm,
+ norm_type=m.norm_type,
+ scale_grad_by_freq=m.scale_grad_by_freq,
+ sparse=m.sparse,
+ )
+
+ # overwrite
+ setattr(
+ model.get_submodule(name), k,
+ klass( **kwargs ).to(device=device, dtype=dtype)
+ )
+
+ if verbose:
+			_logger.info(f"Replacing {name}.{k} with: {klass}")
+
+ return model
+
+# cannot feasibly do default arguments here sad
+def replace_attention( model, klass, target, mode="math", verbose=False ):
+ device = next(model.parameters()).device
+ dtype = next(model.parameters()).dtype
+ modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+
+ for *parent, k in modules:
+ name = '.'.join(parent)
+
+ m = getattr( model.get_submodule(name), k )
+
+ if isinstance(m, klass):
+ continue
+
+ kwargs = dict(
+ config = m.config,
+ layer_idx = m.layer_idx,
+ mode = mode,
+ )
+ # overwrite
+ setattr(
+ model.get_submodule(name), k,
+ klass( **kwargs ).to(device=device, dtype=dtype)
+ )
+
+ if verbose:
+			_logger.info(f"Replacing {name}.{k} with: {klass}")
+
+ return model
+
+# trim/expand a tensor (for example, in a state dict)
+# note: despite accepting `dim`, the logic below only operates along dim 0
+def resize_weight( weight, target, dim=0, random=True ):
+ # trim
+ if target < weight.shape[dim]:
+ return weight[:target]
+ # expand
+ if target > weight.shape[dim]:
+ fn = torch.rand if random else torch.zeros
+ return torch.stack(
+ [ x for x in weight ] +
+ [ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ]
+ )
+
+ return weight
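+# e.g. resize_weight( embedding.weight, new_vocab_size ) trims or grows along dim 0,
+# padding any new rows with random values (or zeros, with random=False)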
+
+def get_devices():
+	return [f'cuda:{i}' for i in range(torch.cuda.device_count())] + ['cpu']
+
+# grabs the memory properties of a given device
+def get_device_properties( device ):
+ if 'cuda' in device:
+ props = torch.cuda.get_device_properties(device)
+ free, total = torch.cuda.mem_get_info(device)
+ else:
+ props = psutil.virtual_memory()
+ free, total = props.available, props.total
+
+ return {"name": device, "total": total, "free": free, "props": props}
+
+# gets the rough size for a given module's parameters
+def get_module_size( module ):
+ param_size = sum([p.nelement() * p.element_size() for p in module.parameters()])
+ buffer_size = sum([b.nelement() * b.element_size() for b in module.buffers()])
+ return param_size + buffer_size
+
+# to-do: rewrite all this shit, I don't know what I was thinking when implementing it this way
+# it'd be better to just attach to layers itself rather than every single module
+
+# assigns modules to requested devices for a given policy
+def get_model_offload_policy(module, policy=None):
+ # handle any other weird values this is set to
+ if not isinstance(policy, dict):
+ policy = {}
+
+ # default to only include the core model, and not the other modules (embeddings) in the splitting policy
+ if "include" not in policy:
+ policy["include"] = ["model"]
+
+ if "limits" not in policy:
+ policy["limits"] = []
+
+ if "assign" not in policy:
+ policy["assign"] = []
+
+ if "devices" not in policy:
+ policy["devices"] = get_devices() # + cpu to spill the remainder on CPU if overbudget
+
+ # create initial device info
+ devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ]
+ modules = [ (name, get_module_size(module)) for name, module in module.named_modules() if not [*module.named_children()] and passes_policy( policy, name ) ]
+ # filter
+ modules = [ (name, size) for name, size in modules if name and size ]
+
+ total_size = sum([size for name, size in modules])
+
+ # set caps if requested in the policy
+ for i, cap in enumerate(policy["limits"]):
+ # no limit, skip
+ if cap <= 0:
+ continue
+ # is fractional, scale to total size
+ if cap < 1:
+ cap = math.floor(total_size * cap)
+ # available space is below cap, don't set
+ if devices[i]["free"] < cap:
+ continue
+ # cap to requested size
+ devices[i]["free"] = cap
+
+ # assign if specific parts of the model are requested for assignment
+ if policy["assign"]:
+ discarded = []
+ # yuck, there has to be a better way
+ for device_index, includes in enumerate( policy["assign"] ):
+ device = devices[device_index]
+
+ buffered_modules = []
+ buffered_size = device["free"]
+
+ # iterate through list of modules to compare against includes
+ for name, size in modules:
+ # doesn't pass policy
+ if not passes_policy( {"include": includes}, name ):
+ continue
+ # check if within budget
+ if buffered_size - size >= 0:
+ # add to buffer
+ buffered_modules.append( (name, size) )
+ buffered_size -= size
+ # budget exceeded, flush buffer
+ else:
+ discarded += buffered_modules
+ buffered_modules = []
+ buffered_size = 0
+ break
+
+		if buffered_modules: # an exact budget fit (free == 0) should still assign
+ device["modules"] += [ name for name, size in buffered_modules ]
+ device["free"] = buffered_size
+
+ modules = discarded
+
+ device_index = 0
+ module_index = 0
+ # assign modules to each device
+ while module_index < len(modules) and device_index < len(devices):
+ device = devices[device_index]
+ name, size = modules[module_index]
+
+ # fits within budget
+ if device["free"] - size >= 0:
+ device["modules"].append( name )
+ device["free"] -= size
+ module_index += 1
+ # does not fit in budget, increase device index
+ else:
+ device_index += 1
+ _logger.info(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")
+
+	# check that all modules were exhausted
+	assert module_index >= len(modules), "modules remain unassigned; the policy ran out of devices/budget"
+
+ # only return devices with modules assigned
+ return [ device for device in devices if device["modules"] ]
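+# e.g. a hypothetical policy: {"include": ["model"], "devices": ["cuda:0", "cpu"], "limits": [0.5, 0]}
+# caps cuda:0 at half the model's total size and spills the remainder onto the CPU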
+
+# handles naively splitting a model's layers across multiple devices
+# this apparently works for training too? the test trainer seemed fine with it split between GPU and CPU
+def offload_model( model, policy=None ):
+ policy = get_model_offload_policy(model, policy=policy)
+
+ # move modules to respective devices
+ for i, device in enumerate( policy ):
+ # nothing assigned, skip
+ if not device["modules"]:
+ continue
+
+ for name in device["modules"]:
+ module = model.get_submodule(name)
+ module = module.to( device["name"] )
+ module.device = device['name']
+
+ # wrap modules with forward to ensure all inputs are matched to its device
+ for name, module in model.named_modules():
+ if not hasattr( module, 'forward' ):
+ continue
+
+ module.forward = auto_align_inputs_forward(module)
+
+	"""
+	# Validate that the layers are all in the right spot
+	for name, module in model.named_modules():
+		if [*module.named_children()]: # only check leaf modules
+			continue
+		try:
+			_logger.info(f"{name}: {next(module.parameters()).device}")
+		except Exception as e:
+			_logger.info(f"{name}: ?")
+			pass
+ """
+
+ return model
\ No newline at end of file
diff --git a/image_classifier/utils/wrapper.py b/image_classifier/utils/wrapper.py
index 040762d..2cbd14b 100755
--- a/image_classifier/utils/wrapper.py
+++ b/image_classifier/utils/wrapper.py
@@ -1,20 +1,39 @@
from contextlib import contextmanager
+import math
import torch
import torch.nn.functional as F
+import logging
+
from ..config import cfg
+_logger = logging.getLogger(__name__)
+
Embedding = torch.nn.Embedding
Linear = torch.nn.Linear
-if cfg.bitsandbytes.enabled:
- import bitsandbytes as bnb
-
- if cfg.bitsandbytes.linear:
- Linear = bnb.nn.Linear8bitLt
+Adam = torch.optim.Adam
+AdamW = torch.optim.AdamW
+SGD = torch.optim.SGD
+Adagrad = torch.optim.Adagrad
- if cfg.bitsandbytes.embedding:
- Embedding = bnb.nn.StableEmbedding
+# https://github.com/kyegomez/BitNet
+if cfg.optimizations.bitnet:
+ from bitnet import BitLinear
+
+if cfg.optimizations.bitsandbytes:
+ import bitsandbytes as bnb
+
+ if cfg.optimizations.linear:
+
+ if cfg.optimizations.bitnet:
+ Linear = BitLinear
+ else:
+ Linear = bnb.nn.Linear8bitLt
+
+ if cfg.optimizations.embedding:
+ Embedding = bnb.nn.modules.Embedding
+ """
Embedding.forward = lambda self, input: ( self.norm(F.embedding(
input,
self.weight,
@@ -24,52 +43,101 @@ if cfg.bitsandbytes.enabled:
self.scale_grad_by_freq,
self.sparse,
)).to(self.weight.dtype) )
-
-Adam = torch.optim.Adam
-AdamW = torch.optim.AdamW
-
-if cfg.bitsandbytes.enabled:
- import bitsandbytes as bnb
-
- Adam = bnb.optim.Adam
- AdamW = bnb.optim.AdamW
-
-# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
-@contextmanager
-def autocast(input, from_dtype, to_dtype):
- if input.dtype == from_dtype:
- input = input.to(to_dtype)
- yield input
- input = input.to(from_dtype)
- else:
- yield input
-
-@contextmanager
-def autocasts(input, from_dtype, to_dtype):
- if input.dtype in from_dtype:
- from_dtype = input.dtype
- input = input.to(to_dtype)
- yield input
- input = input.to(from_dtype)
- else:
- yield input
-
-# handles temporarily upcasting 'index tensors' so torch will stop bitching
-def autocast_forward( func ):
- def wrapper( self, input, *args, **kwargs ):
- with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k:
- return func( self, k, *args, **kwargs )
"""
- if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8:
- return func( self, input.to(torch.int32), *args, **kwargs )
- return func( self, input, *args, **kwargs )
- """
- return wrapper
-Embedding.forward = autocast_forward(Embedding.forward)
-if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
- torch.nn.Linear = Linear
- torch.nn.Embedding = Embedding
+ if cfg.optimizations.optimizers:
+ Adam = bnb.optim.Adam8bit
+ AdamW = bnb.optim.AdamW8bit
+ SGD = bnb.optim.SGD8bit
+ Adagrad = bnb.optim.Adagrad8bit
- torch.optim.Adam = Adam
- torch.optim.AdamW = AdamW
\ No newline at end of file
+elif cfg.optimizations.dadaptation:
+	import dadaptation
+
+	if cfg.optimizations.optimizers:
+		Adam = dadaptation.DAdaptAdam
+		AdamW = dadaptation.DAdaptAdam
+		SGD = dadaptation.DAdaptSGD
+		Adagrad = dadaptation.DAdaptAdaGrad # was `AdaGrad`, which created a new name instead of overriding the alias above
+
+if cfg.optimizations.fp8:
+	import transformer_engine.pytorch as te
+
+	Linear = te.Linear
+
+	@contextmanager
+	def autocast():
+		# enter the inner context manager rather than just yielding it, so `with autocast():` works
+		with te.fp8_autocast(enabled=True):
+			yield
+else:
+	@contextmanager
+	def autocast():
+		with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+			yield
+
+if cfg.optimizations.injects:
+ if cfg.optimizations.linear:
+ torch.nn.Linear = Linear
+
+ if cfg.optimizations.embedding:
+ torch.nn.Embedding = Embedding
+
+ if cfg.optimizations.optimizers:
+ torch.optim.Adam = Adam
+ torch.optim.AdamW = AdamW
+ torch.optim.SGD = SGD
+
+AVAILABLE_COMPILE_BACKENDS = []
+
+try:
+ AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
+except Exception as e:
+ pass
+
+
+if cfg.optimizations.tensorrt:
+ try:
+ import torch_tensorrt
+ AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
+ except Exception as e:
+ _logger.warning(f'Error while importing TensorRT: {str(e)}')
+ pass
+
+def compile_model(model, backend="auto"):
+	if not backend or backend == "auto":
+		# guard against an empty backend list so this doesn't IndexError
+		backend = AVAILABLE_COMPILE_BACKENDS[0] if AVAILABLE_COMPILE_BACKENDS else None
+
+	if not backend or backend not in AVAILABLE_COMPILE_BACKENDS:
+		return torch.compile(model)
+
+ return torch.compile(model, backend=backend)
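+# e.g. model = compile_model( model, backend="tensorrt" ); unknown or unavailable backends
+# fall back to torch.compile's default (inductor)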
+
+# https://github.com/konstmish/prodigy
+try:
+ from prodigyopt import Prodigy
+except Exception as e:
+ _logger.warning(f'Error while importing Prodigyopt: {str(e)}')
+ pass
+
+# https://github.com/facebookresearch/schedule_free/
+try:
+ import schedulefree
+except Exception as e:
+ _logger.warning(f'Error while importing Schedule_Free: {str(e)}')
+ pass
+
+# backwards compat
+from .utils import (
+ autocast_forward,
+ replace_linear as replace_linear_old,
+ replace_embedding as replace_embedding_old,
+ replace_attention,
+ resize_weight,
+ offload_model,
+)
+
+# wrapped here so we can maintain default args
+def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
+ return replace_linear_old( model, klass, target, verbose )
+def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
+ return replace_embedding_old( model, klass, target, verbose )
+
+Embedding.forward = autocast_forward(Embedding.forward)
\ No newline at end of file
diff --git a/image_classifier/webui.py b/image_classifier/webui.py
new file mode 100644
index 0000000..4e0ce15
--- /dev/null
+++ b/image_classifier/webui.py
@@ -0,0 +1,220 @@
+import os
+import re
+import argparse
+import random
+import tempfile
+import functools
+from datetime import datetime
+
+import gradio as gr
+
+from time import perf_counter
+from pathlib import Path
+from PIL import Image
+
+from .inference import Classifier, cfg
+from .train import train
+from .utils import get_devices
+
+classifier = None
+
+layout = {}
+layout["inference"] = {}
+layout["training"] = {}
+layout["settings"] = {}
+
+for k in layout.keys():
+ layout[k]["inputs"] = { "progress": None }
+ layout[k]["outputs"] = {}
+ layout[k]["buttons"] = {}
+
+# there's got to be a better way to go about this
+def gradio_wrapper(inputs):
+ def decorated(fun):
+ @functools.wraps(fun)
+ def wrapped_function(*args, **kwargs):
+ for i, key in enumerate(inputs):
+ kwargs[key] = args[i]
+ try:
+ return fun(**kwargs)
+ except Exception as e:
+ raise gr.Error(str(e))
+ return wrapped_function
+ return decorated
+
+class timer:
+ def __init__(self, msg="Elapsed time:"):
+ self.msg = msg
+
+ def __enter__(self):
+ self.start = perf_counter()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ msg = f'{self.msg} {(perf_counter() - self.start):.3f}s'
+
+ gr.Info(msg)
+ print(f'[{datetime.now().isoformat()}] {msg}')
+
+# returns a list of models, assuming the models are placed under ./training/ or ./models/
+def get_model_paths( paths=[Path("./data/"), Path("./training/"), Path("./models/")] ):
+ yamls = []
+
+ for path in paths:
+ if not path.exists():
+ continue
+
+ for yaml in path.glob("**/*.yaml"):
+ if "/logs/" in str(yaml):
+ continue
+
+ yamls.append( yaml )
+
+ return yamls
+
+def get_dtypes():
+ return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
+
+
+#@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
+def load_model( yaml, device, dtype ):
+	gr.Info(f"Loading: {yaml}")
+	try:
+		init_classifier( yaml=Path(yaml), restart=True, device=device, dtype=dtype )
+	except Exception as e:
+		raise gr.Error(str(e))
+	gr.Info("Loaded model")
+
+def init_classifier(yaml=None, restart=False, device="cuda", dtype="auto"):
+ global classifier
+
+ if classifier is not None:
+ if not restart:
+ return classifier
+
+ del classifier
+ classifier = None
+
+ parser = argparse.ArgumentParser(allow_abbrev=False)
+ parser.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
+ parser.add_argument("--device", type=str, default=device)
+ parser.add_argument("--amp", action="store_true")
+ parser.add_argument("--dtype", type=str, default=dtype)
+ args, unknown = parser.parse_known_args()
+
+ classifier = Classifier( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
+ return classifier
+
+@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
+def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
+ if not cfg.yaml_path:
+ raise Exception("No YAML loaded.")
+
+ parser = argparse.ArgumentParser(allow_abbrev=False)
+ # I'm very sure I can procedurally generate this list
+ parser.add_argument("--image", type=str, default=kwargs["image"])
+ parser.add_argument("--temp", type=float, default=kwargs["temp"])
+ args, unknown = parser.parse_known_args()
+
+ classifier = init_classifier()
+
+ args.image = Image.open(args.image).convert('RGB')
+
+ gr.Info("Inferencing...")
+ with timer("Inferenced in") as t:
+ answer = classifier.inference(
+ image=args.image,
+ temperature=args.temp,
+ )
+
+ return answer
+
+# setup args
+parser = argparse.ArgumentParser(allow_abbrev=False)
+parser.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
+parser.add_argument("--listen", default=None, help="host:port/path for Gradio to listen on")
+parser.add_argument("--share", action="store_true")
+parser.add_argument("--render_markdown", action="store_true", default="CLASSIFIER_YAML" in os.environ)
+args, unknown = parser.parse_known_args()
+
+args.listen_host = None
+args.listen_port = None
+args.listen_path = None
+if args.listen:
+ try:
+ match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]
+
+ args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
+ args.listen_port = match[1] if match[1] != "" else None
+ args.listen_path = match[2] if match[2] != "" else "/"
+ except Exception as e:
+ pass
+
+if args.listen_port is not None:
+ args.listen_port = int(args.listen_port)
+ if args.listen_port == 0:
+ args.listen_port = None
+
+# setup gradio
+ui = gr.Blocks()
+with ui:
+ with gr.Tab("Inference"):
+ with gr.Row():
+ with gr.Column(scale=4):
+ layout["inference"]["inputs"]["image"] = gr.Image(label="Input Image", sources=["upload"], type="filepath")
+ layout["inference"]["outputs"]["output"] = gr.Textbox(label="Output")
+ with gr.Column(scale=4):
+ with gr.Row():
+ layout["inference"]["inputs"]["temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature", info="Modifies the randomness from the samples. (0 to greedy sample)")
+ layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
+
+ layout["inference"]["buttons"]["inference"].click(
+ fn=do_inference,
+ inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
+ outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
+ )
+
+ """
+ with gr.Tab("Training"):
+ with gr.Row():
+ with gr.Column(scale=1):
+ layout["training"]["outputs"]["console"] = gr.Textbox(lines=8, label="Console Log")
+ with gr.Row():
+ with gr.Column(scale=1):
+ layout["training"]["buttons"]["train"] = gr.Button(value="Train")
+
+ layout["training"]["buttons"]["train"].click(
+ fn=do_training,
+ outputs=[ x for x in layout["training"]["outputs"].values() if x is not None],
+ )
+ """
+
+ with gr.Tab("Settings"):
+ with gr.Row():
+ with gr.Column(scale=7):
+ with gr.Row():
+ layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
+ layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
+ layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
+ with gr.Column(scale=1):
+ layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
+
+ layout["settings"]["buttons"]["load"].click(
+ fn=load_model,
+ inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
+ outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
+ )
+
+ if os.path.exists("README.md") and args.render_markdown:
+ md = open("README.md", "r", encoding="utf-8").read()
+ # remove HF's metadata
+ if md.startswith("---\n"):
+ md = "".join(md.split("---")[2:])
+ gr.Markdown(md)
+
+def start( lock=True ):
+ ui.queue(max_size=8)
+ ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)
+
+if __name__ == "__main__":
+ start()
\ No newline at end of file
diff --git a/scripts/run.sh b/scripts/run.sh
old mode 100755
new mode 100644
diff --git a/setup.py b/setup.py
index c719379..4c55a3b 100755
--- a/setup.py
+++ b/setup.py
@@ -1,5 +1,5 @@
import subprocess
-
+import sys
from pathlib import Path
from datetime import datetime
from setuptools import setup, find_packages
@@ -8,7 +8,6 @@ def shell(*args):
out = subprocess.check_output(args)
return out.decode("ascii").strip()
-
def write_version(version_core, pre_release=True):
if pre_release:
time = shell("git", "log", "-1", "--format=%cd", "--date=iso")
@@ -23,8 +22,7 @@ def write_version(version_core, pre_release=True):
return version
-
-with open("README.md", "r", encoding="utf-8") as f:
+with open("README.md", "r") as f:
long_description = f.read()
setup(
@@ -37,17 +35,37 @@ setup(
long_description=long_description,
long_description_content_type="text/markdown",
packages=find_packages(),
- install_requires=[
+ install_requires=(
+ # training backends
+ ["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else [])
+ + [
+ # logging niceties
"coloredlogs>=15.0.1",
+ "humanize>=4.4.0",
+ "matplotlib>=3.6.0",
+ "pandas>=1.5.0",
+
+ # boiler plate niceties
"diskcache>=5.4.0",
"einops>=0.6.0",
- "omegaconf==2.0.6",
- "tqdm>=4.64.1",
- "humanize>=4.4.0",
+ "tqdm",
- "pandas>=1.5.0",
+ # HF bloat
+ "tokenizers",
+ "transformers",
+ "safetensors",
+
+ # training bloat
+ "h5py",
+ "prodigyopt @ git+https://github.com/konstmish/prodigy",
+
+ # practically the reason to use python
+ "numpy",
"torch>=1.13.0",
- "torchmetrics",
+ "torchmetrics",
+
+ "simple_http_server",
+ "pillow"
],
url="https://git.ecker.tech/mrq/resnet-classifier",
)