diff --git a/.gitignore b/.gitignore old mode 100755 new mode 100644 diff --git a/LICENSE b/LICENSE old mode 100755 new mode 100644 diff --git a/README.md b/README.md old mode 100755 new mode 100644 index 347e629..a744721 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ # Tentative Title For A ResNet-Based Image Classifier -This is a simple ResNet based image classifier for """specific images""", using a similar training framework I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/). +This is a simple ResNet-based image classifier for images, using a similar training framework I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/). ## Premise -This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born. +This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I faced. This is by no means state of the art, as it just leverages an existing ResNet arch provided by `torchvision`. @@ -16,44 +16,14 @@ This is by no means state of the art, as it just leverages an existing ResNet arc 3. Install using `pip3 install -e ./image_classifier/`. -4. Train using `python3 -m image_classifier.train yaml='./data/config.yaml'`. +4. Train using `python3 -m image_classifier.train --yaml='./data/config.yaml'`. 5. Wait. ## Inferencing -Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0` +Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" --yaml="./data/config.yaml"` ### Continuous Usage -If you're looking to continuously classifier trained images, use `python3 -m image_classifier --listen --port=7860 yaml="./data/config.yaml" --temp=1.0` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label. - -## Known Issues - -* Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed. -* Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed. -* Web server doesn't emit `content-type: application/json`, nor accepts JSON `POST`s at the moment. - -## Strawmen - ->\> UGH... Why *another* training framework!!! Just subjugate [DLAS](https://git.ecker.tech/mrq/DL-Art-School) even more!!! - -I want my own code to own. The original VALL-E implementation had a rather nice and clean setup that *mostly* just made sense. DLAS was a nightmare to comb through for the gorillion amounts of models it attests. - ->\> OK. But how do I use it for `[thing that isn't the specific usecase only I know/care about]` - -Simply provide your own symmapping under `./image_classifier/data.py`, and, be sure to set the delimiter (where exactly is an exercise left to the reader). - -Because this is for a ***very specific*** use-case. 
I don't really care right now to make this a *little* more generalized, despite most of the bits and bobs for it to generalize being there. - ->\> ur `[a slur]` for using a ResNet... why not use `[CRNN / some other meme arch]`?? - -I don't care, I'd rather keep the copypasting from other people's code to a minimum. Lazily adapting my phoneme tokenizer from my VALL-E implementation into something practically fixed length by introducing start/stop tokens should be grounds for me to use a CRNN, or anything recurrent at the very least, but again, I don't care, it just works for my use case at the moment. - ->\> UGH!!! What are you talking about """specific images"""??? - -[ひみつ](https://files.catbox.moe/csuh49.webm) - ->\> NOOOO!!!! WHY AREN'T YOU USING `[cuck license]`??? - -:) \ No newline at end of file +If you're looking to continuously classify images, use `python3 -m image_classifier --listen --port=7860 --yaml="./data/config.yaml"` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label. \ No newline at end of file diff --git a/data/config.yaml b/data/config.yaml old mode 100755 new mode 100644 index 5e51278..3bebbe0 --- a/data/config.yaml +++ b/data/config.yaml @@ -1,85 +1,84 @@ -dataset: - training: [ - "./data/images/" - ] - - validation: [] - - use_hdf5: False - - workers: 0 - cache: True +weights_format: sft models: - _models: - - name: "classifier" - tokens: 0 - len: 6 +- name: "classifier" + tokens: 0 + len: 6 + dim: 512 + resnet: 34 +#loras: +#- name : "lora" +# rank: 128 +# alpha: 128 +# training: True +# rvq_levels: [] hyperparameters: batch_size: 256 - gradient_accumulation_steps: 64 - gradient_clipping: 100 + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + warmup_steps: 10 + + optimizer: Prodigy + learning_rate: 1.0 + torch_optimizer: True - optimizer: Adamw - learning_rate: 1.0e-3 - - scheduler_type: "" - #scheduler_type: OneCycle - #scheduler_params: - # cycle_first_step_size: 10_000 - # cycle_first_stair_count: 10_000 - - # cycle_second_step_size: 15_000 - # cycle_second_stair_count: 15_000 - - # decay_step_size: 5_000 - - # cycle_min_lr: 2.5e-4 # 1.0e-5 - # cycle_max_lr: 2.5e-4 # 1.0e-4 - # decay_lr_rate: 0.0 - - # cycle_min_mom: 0.90 - # cycle_max_mom: 0.99 - # decay_mom_rate: 0.0 + scheduler: "" # ScheduleFree + torch_scheduler: True evaluation: - batch_size: 32 - frequency: 250 - size: 32 + batch_size: 64 + frequency: 100 + size: 64 - steps: 300 - temperature: 1.0 + steps: 450 + temperature: 0.0 trainer: - iterations: 100_000 - - save_tag: step - save_on_oom: True - save_on_quit: True + iterations: 1_000_000 save_frequency: 100 - - aggressive_optimizations: False - - check_for_oom: False - - #load_tag: "9500" - #load_state_dict: True - #load_states: False - #strict_loading: False - #restart_step_count: True + keep_last_checkpoints: 32 - gc_mode: None # "global_step" + check_for_oom: False + gradient_checkpointing: True - weight_dtype: float32 + weight_dtype: bfloat16 + amp: True - backend: local + backend: deepspeed deepspeed: - zero_optimization_level: 0 - use_compression_training: True + inferencing: False + amp: False inference: - use_vocos: True + backend: local -bitsandbytes: - enabled: false \ No newline at end of file + weight_dtype: bfloat16 + amp: True + +optimizations: + injects: False + replace: True + + linear: False + embedding: False + optimizers: True + + bitsandbytes: False + dadaptation: 
False + bitnet: False + fp8: False + +dataset: + use_hdf5: True + hdf5_flag: r + + workers: 1 + cache: True + + training: [ + "./data/images/" + ] + validation: [ + "./data/validation/" + ] \ No newline at end of file diff --git a/image_classifier/__main__.py b/image_classifier/__main__.py index d9907fe..3916455 100755 --- a/image_classifier/__main__.py +++ b/image_classifier/__main__.py @@ -12,35 +12,55 @@ def main(): parser = argparse.ArgumentParser(allow_abbrev=False) parser.add_argument("--listen", action='store_true') parser.add_argument("--port", type=int, default=9090) + parser.add_argument("--yaml", type=Path, default=None) - parser.add_argument("--ckpt", type=Path, default=None) - parser.add_argument("--temp", type=float, default=1.0) - parser.add_argument("--device", default="cuda") + parser.add_argument("--device", type=str, default=None) + parser.add_argument("--amp", action="store_true") + parser.add_argument("--dtype", type=str, default=None) + + parser.add_argument("--temp", type=float, default=0.0) + args, unknown = parser.parse_known_args() - classifier = Classifier( config=args.yaml, ckpt=args.ckpt, device=args.device ) + classifier = Classifier( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp ) if args.listen: @route("/") - def inference( b64, temperature=1.0 ): + def inference( b64, temperature=args.temp ): image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB") - return { "answer": classifier.inference( image=image, temperature=args.temp ) } + return { "answer": classifier.inference( image=image, temperature=temperature ) } server.start(port=args.port) else: parser = argparse.ArgumentParser(allow_abbrev=False) parser.add_argument("--path", type=Path) parser.add_argument("--base64", type=str) + parser.add_argument("--write", type=Path) parser.add_argument("--temp", type=float, default=1.0) - args, unknown = parser.parse_known_args() + args, unknown = parser.parse_known_args() + images = [] if args.path: - image = Image.open(args.path).convert('RGB') + if args.path.is_dir(): + for p in args.path.rglob("*.jpg"): + image = Image.open(p).convert('RGB') + images.append(image) + for p in args.path.rglob("*.png"): + image = Image.open(p).convert('RGB') + images.append(image) + else: + image = Image.open(args.path).convert('RGB') + images.append(image) elif args.base64: image = Image.open(BytesIO(base64.b64decode(args.base64))).convert("RGB") + images.append(image) else: raise ValueError("Specify a --path or --base64.") 
- answer = classifier.inference( image=image, temperature=args.temp ) - print("Answer:", answer) + for image in images: + answer = classifier.inference( image=image, temperature=args.temp ) + print("Answer:", answer) + if args.write: + args.write.mkdir(parents=True, exist_ok=True) + image.save( args.write / f"{answer}.jpg") if __name__ == "__main__": main() diff --git a/image_classifier/config.py b/image_classifier/config.py index e39a22e..8adfef8 100755 --- a/image_classifier/config.py +++ b/image_classifier/config.py @@ -6,31 +6,61 @@ import os import subprocess import sys import time - -from dataclasses import asdict, dataclass -from dataclasses import dataclass, field - -from functools import cached_property, cache -from pathlib import Path -from omegaconf import OmegaConf +import argparse +import yaml +import random +import logging import torch +import numpy as np + +from dataclasses import asdict, dataclass, field + +from functools import cached_property +from pathlib import Path + +from .utils.distributed import world_size + + +def set_seed(seed=None): + if not seed: + seed = int(time.time()) # np.random.seed / torch.manual_seed require an integer seed + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) @dataclass() -class _Config: - cfg_path: str | None = None +class BaseConfig: + yaml_path: str | None = None # path passed in through --yaml @property - def relpath(self): + def cfg_path(self): + return Path(self.yaml_path.parent) if self.yaml_path is not None else None + + @property + def rel_path(self): return Path(self.cfg_path) + @property + def cache_dir(self): + return self.rel_path / ".cache" + + @property + def data_dir(self): + return self.rel_path / "data" + + @property + def metadata_dir(self): + return self.rel_path / "metadata" + @property def ckpt_dir(self): - return self.relpath / "ckpt" + return self.rel_path / "ckpt" @property def log_dir(self): - return self.relpath / "logs" / str(self.start_time) + return self.rel_path / "logs" / str(self.start_time) @cached_property def start_time(self): @@ -64,39 +94,28 @@ class _Config: with open(path, "w") as f: f.write(self.dumps()) - @staticmethod - def _is_cfg_argv(s): - return "=" in s and "--" not in s @classmethod def from_yaml( cls, yaml_path ): - return cls.from_cli( [f'yaml="{yaml_path}"'] ) + state = {} + state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8")) + state.setdefault("yaml_path", yaml_path) + return cls(**state) @classmethod def from_cli(cls, args=sys.argv): - cli_cfg = OmegaConf.from_cli([s for s in args if cls._is_cfg_argv(s)]) + # legacy support for yaml=`` format + for i, arg in enumerate(args): + if arg.startswith("yaml"): + args[i] = f'--{arg}' - # Replace argv to ensure there are no omegaconf options, for compatibility with argparse. 
- sys.argv = [s for s in sys.argv if not cls._is_cfg_argv(s)] + parser = argparse.ArgumentParser(allow_abbrev=False) + parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too + args, unknown = parser.parse_known_args(args=args) - if cli_cfg.get("help"): - print(f"Configurable hyperparameters with their default values:") - print(json.dumps(asdict(cls()), indent=2, default=str)) - exit() + if args.yaml: + return cls.from_yaml( args.yaml ) - if "yaml" in cli_cfg: - yaml_cfg = OmegaConf.load(cli_cfg.yaml) - yaml_path = Path(cli_cfg.yaml).absolute() - cfg_path = Path(*yaml_path.relative_to(Path.cwd()).parts[:-1]) - cfg_path = cfg_path.with_suffix("") - cfg_path = f'./{cfg_path}' - - yaml_cfg.setdefault("cfg_path", cfg_path) - cli_cfg.pop("yaml") - else: - yaml_cfg = {} - merged = OmegaConf.merge(yaml_cfg, cli_cfg) - return cls(**dict(merged)) + return cls(**{}) def __repr__(self): return str(self) @@ -106,104 +125,195 @@ class _Config: @dataclass() class Dataset: - training: list[Path] = field(default_factory=lambda: []) - validation: list[Path] = field(default_factory=lambda: []) - - temp: list[Path] = field(default_factory=lambda: []) - - # de-implemented, because the data isn't that large to facilitate HDF5 - hdf5_name: str = "data.h5" - use_hdf5: bool = False + training: list[Path] = field(default_factory=lambda: []) # paths to load into the training dataset + validation: list[Path] = field(default_factory=lambda: []) # paths to load into the validation dataset - workers: int = 8 - cache: bool = True + hdf5_name: str = "data.h5" # file name to load the HDF5 dataset + use_hdf5: bool = False # whether to load from an HDF5 dataset + hdf5_flag: str = "a" # flag to load the HDF5 file, automatically adjusted anyways + + validate: bool = True # validate each utterance on whether it can be included based on duration range caps + workers: int = 8 # number of dataloader workers to spawn + cache: bool = True # use diskcache to cache the dataset +# I really need to clean this up @dataclass() class Model: - name: str = "" + name: str = "classifier" tokens: int = 0 # number of token types len: int = 1 # how long a sequence can be dim: int = 512 + resnet: int = 18 + + width: int = 300 + height: int = 80 + + version: int = 1 + training: bool = True + frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training @property def full_name(self): return self.name -@dataclass() -class Models: - _models: list[Model] = field(default_factory=lambda: [ - Model(name="captcha"), - ]) - def get(self, name=None): - if not name: - return [ Model(**model) for model in self._models ] + return [ self ] if not name or self.name == name else [] + + def loss_factor(self, k): + return getattr(self, "loss_factors", {}).get(k, 1.0) # no loss_factors field is defined on Model, so default to 1.0 - for model in self._models: - if model.name == name: - return model + @property + # required for fp8 as the lengths need to be divisible by 8 + def input_alignment(self): + return 8 if cfg.optimizations.fp8 else 0 - raise ValueError + @property + def activation_checkpointing(self): + return cfg.trainer.activation_checkpointing + + @property + def gradient_checkpointing(self): + return cfg.trainer.gradient_checkpointing + @property + def lora_policy(self): + include = ["model"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever) + exclude = [] + + if self.arch_type == "llama": + include = 
["self_attn", "mlp"] # target only the attention + mlp + exclude = ["self_attn.k_proj"] # common literature says to ignore it + if self.arch_type == "retnet": + include = ["layers."] # target the core layers of the RetNet and ignore the auxiliary stuff + exclude = ["retention.k_proj"] # attention-based transformers ignore the K, so might as well ignore it for the retnet + + return dict(include=include, exclude=exclude) + +# should be renamed to Adapters +@dataclass() +class LoRA: + name: str = "lora" # vanity name + # to-do: find sane default values + rank: int = 128 # rank for the LoRA + alpha: int = 128 # rank for the LoRA + training: bool = True # + embeddings: bool = False # train the embedding too + parametrize: bool = False # whether to use the parameterized pathway for LoRAs or not + rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA + + @property + def full_name(self): + name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ] + return "-".join(name) + + # actually not needed anymore + def active_level( self, level ): + if not self.rvq_levels: + return True + return level in self.rvq_levels + @dataclass() class Hyperparameters: - batch_size: int = 8 - gradient_accumulation_steps: int = 32 - gradient_clipping: int = 100 # to be implemented in the local backend + batch_size: int = 8 # number of samples per training batch + gradient_accumulation_steps: int = 32 # number of steps to accumulate gradients before updating + gradient_clipping: int | float = 10 # largest size a gradient norm can be - optimizer: str = "Adamw" - learning_rate: float = 3.25e-4 + optimizer: str = "Adamw" # optimizer to use, should be 'Prodigyopt" now + optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config + + learning_rate: float = 3.25e-4 # should be 1.0 for ProdigyOpt + warmup_steps: int = 0 # number of steps to warm up the optimizer before performing updates, I think, this is just passed to deepspeed - scheduler_type: str = "" # to be implemented in the local backend - scheduler_params: dict = field(default_factory=lambda: {}) + scheduler: str = "" # scheduler to use, currently don't ever use one so this doesn't really matter + scheduler_type: str = "" # deprecated + scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config + + autotune: bool = False # to do deepspeed's autotuning + autotune_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config + + torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed supplied + torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied @dataclass() class Evaluation: - batch_size: int = 64 - frequency: int = 250 - size: int = 64 - + batch_size: int = 64 # number of samples per batch during eval / val + frequency: int = 250 # do eval / val every X iterations + size: int = 64 # number of samples to generate during eval / val + steps: int = 500 - temperature: float = 1.0 + temperature: float = 1.0 # AR temp for inferencing + + load_disabled_engines: bool = True # see the other load_disabled_engines @dataclass() class DeepSpeed: - zero_optimization_level: int = 0 - use_compression_training: bool = False + zero_optimization_level: int = 0 # doesn't seem to work + use_compression_training: bool = False # cope + compression_bits: int = 8 # cope + inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead + + amp: bool 
= False # use DeepSpeed's AMP (requires some other package installed apparently) - def get_ds_cfg(self, model): - weights = [ name[0] for name in model.named_parameters() ] - bits = 8 + config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config - scheduler_params = {} - for k in cfg.hyperparameters.scheduler_params: - scheduler_params[k] = cfg.hyperparameters.scheduler_params[k] + @cached_property + def ds_cfg(self): + optimizer_params = cfg.hyperparameters.optimizer_params + + if 'lr' not in optimizer_params: + optimizer_params["lr"] = cfg.hyperparameters.learning_rate - if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params: + scheduler_params = cfg.hyperparameters.scheduler_params + if 'warmup_num_steps' not in scheduler_params: + scheduler_params['warmup_num_steps'] = cfg.hyperparameters.warmup_steps + + if 'total_num_steps' not in scheduler_params: scheduler_params['total_num_steps'] = cfg.trainer.iterations + autotune_params = cfg.hyperparameters.autotune_params + + if "enabled" not in autotune_params: + autotune_params['enabled'] = True + + if "results_dir" not in autotune_params: + autotune_params['results_dir'] = str( cfg.rel_path / "autotune" / "results" ) + + if "exps_dir" not in autotune_params: + autotune_params['exps_dir'] = str( cfg.rel_path / "autotune" / "exps_" ) + + # DeepSpeed fp16 is incompatible with its AMP + if cfg.trainer.weight_dtype.lower() == "float16": + self.amp = False + + # disable local AMP + if self.amp: + cfg.trainer.amp = False + ds_cfg = { "train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size, "gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps, "optimizer": { "type": cfg.hyperparameters.optimizer, - "params": { - "lr": cfg.hyperparameters.learning_rate, - } - }, + "params": optimizer_params, + } if not cfg.hyperparameters.torch_optimizer else None, "scheduler": { - "type": cfg.hyperparameters.scheduler_type, + "type": cfg.hyperparameters.scheduler, "params": scheduler_params, - } if cfg.hyperparameters.scheduler_type != "" else None, + } if not cfg.hyperparameters.torch_scheduler else None, "gradient_clipping": cfg.hyperparameters.gradient_clipping, "fp16": { - "enabled": True, - "auto_cast": True, - } if cfg.trainer.weight_dtype.lower() == "float16" else None, - "bf16": { - "enabled": cfg.trainer.weight_dtype.lower() == "bfloat16" + "enabled": cfg.trainer.weight_dtype.lower() == "float16", + "auto_cast": True, # ??? 
+ "loss_scale": 0.0 if cfg.trainer.scale_loss else 1.0, }, + "bf16": { + "enabled": cfg.trainer.weight_dtype.lower() == "bfloat16", + }, + "amp": { + "enabled": self.amp, + }, + "autotuning": autotune_params if cfg.hyperparameters.autotune else None, "compression_training": { "weight_quantization": { "shared_parameters":{ @@ -214,7 +324,7 @@ class DeepSpeed: "quantize_verbose": True, "quantization_type": "symmetric", "rounding": "nearest", - "quantize_weight_in_forward": True, + "quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16 "fp16_mixed_quantize":{ "enabled": False, "quantize_change_ratio": 1 @@ -223,30 +333,38 @@ class DeepSpeed: "different_groups": { "wq1": { "params": { - "start_bits": bits, - "target_bits": bits, + "start_bits": self.compression_bits, + "target_bits": self.compression_bits, "quantization_period": 0 }, - "modules": weights + "modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches } } }, "activation_quantization": { "shared_parameters":{ "enabled": True, + "quantizer_kernel": True, + "schedule_offset": 0, + "quantize_groups": 64, + "quantize_verbose": True, "quantization_type": "symmetric", - "range_calibration": "dynamic", - "schedule_offset": 0 + "rounding": "nearest", + "quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16 + "fp16_mixed_quantize":{ + "enabled": False, + "quantize_change_ratio": 1 + } }, "different_groups": { "aq1": { "params": { - "bits": bits + "bits": self.compression_bits, }, - "modules": weights + "modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches } } - } + }, } if self.use_compression_training else None, "zero_optimization": { "stage": self.zero_optimization_level, @@ -264,7 +382,10 @@ class DeepSpeed: "offload_param": { "device": "cpu", "pin_memory": True - } + }, + "zero_quantized_weights": self.use_compression_training, + "zero_hpz_partition_size": world_size(), + "zero_quantized_gradients": self.use_compression_training, } if self.zero_optimization_level > 0 else None, "comms_logger": { "enabled": False @@ -275,113 +396,314 @@ class DeepSpeed: for k in null_keys: del ds_cfg[k] - if os.path.exists("./config/ds_config.json"): - ds_cfg.update(json.load(open("./config/ds_config.json", "r", encoding="utf-8"))) + if os.path.exists("./data/ds_config.json"): + ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8"))) + else: + ds_cfg.update(self.config) return ds_cfg @dataclass() class Trainer: - iterations: int = 100_000 + iterations: int = 1_000_000 # maximum iterations to train - save_tag: str = "step" - load_tag: str | None = None + save_tag: str = "step" # name to save checkpoints under, "step" will save as current step count + load_tag: str | None = None # tag to load checkpoint from; if None: will check against contents of `./ckpt/{model-name}/latest` for the checkpoint name - save_on_oom: bool = True - save_on_quit: bool = True - save_frequency: int = 100 + save_on_oom: bool = True # save if an OOM error is raised + save_on_quit: bool = True # save when quitting training + + export_on_save: bool = False # export weights to local `fp32.pth` state_dict on saving a checkpoint + export_on_quit: bool = False # export weights to local `fp32.pth` state_dict on quitting training + + save_frequency: int = 100 # frequency to save every X iterations - load_state_dict: 
bool = False - load_states: bool = True - strict_loading: bool = True - restart_step_count: bool = False + keep_last_checkpoints: int = 0 # number of checkpoints to keep, prunes oldest ones - aggressive_optimizations: bool = False - check_for_oom: bool = True + load_state_dict: bool = False # loads `fp32.pth` state_dict, will automatically be done if a checkpoint is not found but `fp32.pth` exists + load_states: bool = True + strict_loading: bool = False # raise an error on mismatched keys when loading the state dict + load_module_only: bool = False + restart_step_count: bool = False # clears the training stats when loading a checkpoint + resize_modules: bool = False # automatically resizes mismatched modules when loading - gc_mode: str | None = None + activation_checkpointing: bool | None = None # deprecated; should technically apply only to activations and not the entire gradients, but HF only has gradient checkpointing + gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training - weight_dtype: str = "float16" + aggressive_optimizations: bool = False # deprecated + check_for_oom: bool = True # checks for OOMs thrown during forward/backwards + gc_mode: str | None = None # deprecated, but marks when to do GC + load_disabled_engines: bool = False # deprecated, but signals to load engines not used for training (for example, for evaluation/validation) - backend: str = "deepspeed" + weight_dtype: str = "float16" # dtype to have the model under - deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) + amp: bool = False # automatic mixed precision + ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested + #scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload) + + load_webui: bool = False # not working, but loads the web UI to allow inferencing during training + no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when logger broke because of BitNet + + backend: str = "local" # training backend to use. 
currently supports "local" | "deepspeed" + deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings @cached_property def dtype(self): if self.weight_dtype == "float16": return torch.float16 - if cfg.trainer.weight_dtype == "bfloat16": + if self.weight_dtype == "bfloat16": return torch.bfloat16 + if self.weight_dtype == "float8_e5m2": + return torch.float8_e5m2 + if self.weight_dtype == "float8_e4m3fn": + return torch.float8_e4m3fn return torch.float32 + @cached_property + def scale_loss(self): + # currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways) + return self.dtype == torch.float16 @dataclass() class Inference: - use_vocos: bool = True # artifact from the VALL-E trainer + backend: str = "local" # backend to use when inferencing + weight_dtype: str = "float32" # dtype to load the model under + amp: bool = False # automatic mixed precision during inferencing + + normalize: bool = False # do NOT enable this unless you know exactly what you're doing + + @property + def dtype(self): + if self.weight_dtype == "float16": + return torch.float16 + if self.weight_dtype == "bfloat16": + return torch.bfloat16 + if self.weight_dtype == "int8": + return torch.int8 + if self.weight_dtype == "float8_e5m2": + return torch.float8_e5m2 + if self.weight_dtype == "float8_e4m3fn": + return torch.float8_e4m3fn + return torch.float32 @dataclass() -class BitsAndBytes: - enabled: bool = False - injects: bool = False +class Optimizations: + injects: bool = False # overwrites default torch classes (not recommended) + replace: bool = False # replaces modules in place with the optimized version (recommended) + compile: bool | str = False # runs torch.compile on the model - linear: bool = False - embedding: bool = False + linear: bool = True # inject/replace linear for BnB + embedding: bool = True # inject/replace embedding for BnB + optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation) + + bitsandbytes: bool = False # use bitsandbytes + dadaptation: bool = False # use dadaptation optimizer + bitnet: bool = False # use bitnet + fp8: bool = False # use fp8 + + model_offloading: dict | None = None # automatically splits the model over a list of devices + # example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} will have the GPU capped to 6GiB, and offload the remaining layers to CPU + # example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50% + # | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' 
for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2 + + tensorrt: bool = False @dataclass() -class Config(_Config): - device: str = "cuda" +class Config(BaseConfig): + device: str = "cuda" # target device + mode: str = "training" # "inferencing" + experimental: bool = False # Debug flag, unused now dataset: Dataset = field(default_factory=lambda: Dataset) - models: Models = field(default_factory=lambda: Models) + models: dict | list | None = field(default_factory=lambda: []) + loras: dict | list | None = field(default_factory=lambda: []) hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters) evaluation: Evaluation = field(default_factory=lambda: Evaluation) trainer: Trainer = field(default_factory=lambda: Trainer) inference: Inference = field(default_factory=lambda: Inference) - bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes) + bitsandbytes: dict | list | None = None # deprecated + optimizations: Optimizations = field(default_factory=lambda: Optimizations) - def get_device(self): - return torch.cuda.current_device() if self.device == "cuda" else self.device + tokenizer: str | None = None # tokenizer class + tokenizer_path: str = "./tokenizer.json" # tokenizer path + + weights_format: str = "pth" # "pth" | "sft" + supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"]) @property - def cache_dir(self): - return ".cache" / self.relpath + def model(self): + for i, model in enumerate(self.models): + if model.training: + return model + + return self.models[0] if len(self.models) > 0 else None + + # should be renamed to adapters + @property + def lora(self): + for i, lora in enumerate(self.loras): + if lora.training: + return lora + + return self.loras[0] if len(self.loras) > 0 else None + + @property + def distributed(self): + return world_size() > 1 @cached_property def diskcache(self): - if self.dataset.cache: + if self.yaml_path is not None and self.dataset.cache: return diskcache.Cache(self.cache_dir).memoize return lambda: lambda x: x + # I don't remember why this is needed def load_yaml( self, config_path ): tmp = Config.from_yaml( config_path ) self.__dict__.update(tmp.__dict__) + def load_hdf5( self, write=False ): + if hasattr(self, 'hdf5'): + self.hdf5.close() + + if self.distributed: + self.dataset.hdf5_flag = "r" + try: + self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset + except Exception as e: + _logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}") + self.dataset.use_hdf5 = False + + # to-do: prune unused keys + def format( self, training=True ): + if isinstance(self.dataset, type): + self.dataset = dict() + + if isinstance(self.models, type): + self.models = dict() + + if isinstance(self.loras, type): + self.loras = dict() + + if isinstance(self.hyperparameters, type): + self.hyperparameters = dict() + + if isinstance(self.evaluation, type): + self.evaluation = dict() + + if isinstance(self.trainer, type): + self.trainer = dict() + + if isinstance(self.inference, type): + self.inference = dict() + + if isinstance(self.optimizations, type): + self.optimizations = dict() + + self.dataset = Dataset(**self.dataset) + self.dataset.training = [ Path(dir) for dir in self.dataset.training ] + self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ] + + 
self.models = [ Model(**model) for model in self.models ] + self.loras = [ LoRA(**lora) for lora in self.loras ] + + if not self.models: + self.models = [ Model() ] + + self.hyperparameters = Hyperparameters(**self.hyperparameters) + + self.evaluation = Evaluation(**self.evaluation) + + self.trainer = Trainer(**self.trainer) + + if not isinstance(self.trainer.deepspeed, type): + self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed) + + self.inference = Inference(**self.inference) + + if self.bitsandbytes is not None: + self.optimizations = Optimizations(**self.bitsandbytes) + else: + self.optimizations = Optimizations(**self.optimizations) + + if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler: + self.hyperparameters.scheduler = self.hyperparameters.scheduler_type + self.hyperparameters.scheduler_type = "" + + # do not combine the two + if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation: + self.hyperparameters.scheduler = "" + + if self.hyperparameters.scheduler == "": + self.hyperparameters.torch_scheduler = True + + if self.trainer.backend == "local" and self.distributed: + self.trainer.ddp = True + + if self.trainer.activation_checkpointing is not None: + self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing + + if not training: + self.dataset.use_hdf5 = False + + # load our HDF5 file if requested here + if self.dataset.use_hdf5: + self.load_hdf5() + + # load tokenizer + if cfg.tokenizer == "naive": + cfg.tokenizer = NaiveTokenizer() + else: + try: + from transformers import PreTrainedTokenizerFast + + tokenizer_path = cfg.rel_path / cfg.tokenizer_path if cfg.yaml_path is not None else None + if tokenizer_path and not tokenizer_path.exists(): + tokenizer_path = Path("./data/") / cfg.tokenizer_path + + if tokenizer_path and tokenizer_path.exists(): + cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path)) + else: + cfg.tokenizer = NaiveTokenizer() + except Exception as e: + cfg.tokenizer = NaiveTokenizer() + _logger.warning(f"Error while parsing tokenizer: {str(e)}") + pass + + +# Preserves the old behavior +class NaiveTokenizer: + def get_vocab( self ): + """ + if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5: + return json.loads( cfg.hdf5['symmap'].asstr()[()] ) + """ + return { " ": 0, "<s>": 1, "</s>": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 } + + @cached_property + def _bos_token( self ): + return self.get_vocab()["<s>"] + + @cached_property + def _eos_token( self ): + return self.get_vocab()["</s>"] + + def encode( self, s ): + symmap = self.get_vocab() + s = s.replace("O", "0") + s = [f"<s>"] + [ p if p in symmap else " " for p in s ] + [f"</s>"] + return [*map(symmap.get, s)] + +_logger = logging.getLogger(__name__) cfg = Config.from_cli() -# OmegaConf doesn't actually coerce the dicts into the @dataclass decorated classes, for some god forsaken reason, so we coerce them ourselves -cfg.dataset = Dataset(**cfg.dataset) -cfg.models = Models(**cfg.models) -cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters) -cfg.evaluation = Evaluation(**cfg.evaluation) -cfg.trainer = Trainer(**cfg.trainer) -cfg.inference = Inference(**cfg.inference) -cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes) - -cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed) - -# cached_property stopped working... 
-if cfg.dataset.use_hdf5: - try: - cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a') - except Exception as e: - print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e)) - cfg.dataset.use_hdf5 = False - -if not cfg.dataset.use_hdf5: - cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ] - cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ] +# some safety for remapping deprecated formats and re-coercing uninitialized properties into actual types +try: + cfg.format() +except Exception as e: + _logger.error(f"Error while parsing config YAML: {str(e)}") + raise e # throw an error because I'm tired of silent errors messing things up for me if __name__ == "__main__": - print(cfg) \ No newline at end of file + print(cfg) diff --git a/image_classifier/data.py b/image_classifier/data.py index 51fdf2c..8b5f011 100755 --- a/image_classifier/data.py +++ b/image_classifier/data.py @@ -1,16 +1,19 @@ # todo: clean this mess up import copy -# import h5py +import h5py import json import logging -#import numpy as np +import numpy as np import os import random import torch -import math +import itertools from .config import cfg +from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler +from .utils.distributed import global_rank, local_rank, world_size +from .utils.io import torch_save, torch_load from collections import defaultdict from functools import cache, cached_property @@ -20,23 +23,57 @@ from typing import Any from torch import Tensor from torch.utils.data import DataLoader, Dataset as _Dataset +from torch.utils.data.distributed import DistributedSampler +from torch.nn.utils.rnn import pad_sequence + +from PIL import Image, ImageDraw import torchvision.transforms as transforms + from tqdm.auto import tqdm - -from PIL import Image - # torch.multiprocessing.set_sharing_strategy("file_system") _logger = logging.getLogger(__name__) -@cache +# to-do: clean up this symmap mess def get_symmap(): - return { " ": 0, "<s>": 1, "</s>": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 } + return cfg.tokenizer.get_vocab() -@cache -def _get_symbols( content ): - content = content.replace("O", "0") - return [f"<s>"] + [ p for p in content ] + [f"</s>"] +def tokenize( s ): + if isinstance( s, list ): + s = "".join( s ) + return cfg.tokenizer.encode( s ) + +""" +def _replace_file_extension(path, suffix): + return (path.parent / path.name.split(".")[0]).with_suffix(suffix) + +def _get_hdf5_path(path): + # to-do: better validation + return str(path) + +def _get_hdf5_paths( data_dir, type="training", validate=False ): + data_dir = str(data_dir) + + key = f"/{type}/{_get_hdf5_path(data_dir)}" + + return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items()] if key in cfg.hdf5 else [] + +def _get_paths_of_extensions( path, validate=False ): + if isinstance(path, str): + path = Path(path) + + return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else [] + +def _interleaved_reorder(l, fn): + groups = defaultdict(list) + for e in l: + groups[fn(e)].append(e) + groups = {k: groups[k] for k in sorted(groups)} + for interleaved in zip_longest(*groups.values()): + for value in interleaved: + if value is not None: + yield value +""" class Dataset(_Dataset): def __init__( @@ -44,43 +81,90 @@ ( self, paths, width=300, height=80, + stacks=0, 
symmap=get_symmap(), training=False, ): super().__init__() - self._head = None - - self.paths = paths + self.sampler = None self.width = width self.height = height - + self.stacks = stacks + + self.paths = paths + self.image_dtype = cfg.trainer.dtype self.symmap = symmap + self.training = training + self.dataset_type = "training" if self.training else "validation" + self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation self.transform = transforms.Compose([ - #transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow + transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) - @cached_property - def symbols(self): - return sorted(set().union(*[_get_symbols(path.stem) for path in self.paths])) + # to-do: do not do validation if there's nothing in the validation + # this just makes it be happy + if len(self.dataset) == 0: + self.dataset = cfg.dataset.training + # split dataset accordingly per GPU + if cfg.distributed and self.training: + self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ] + + if len(self.paths) == 0: + raise ValueError(f"No valid path is found for {self.dataset_type}") + + @cached_property + def sampler_state_dict_path(self): + return cfg.rel_path / f"sampler.rank{global_rank()}.pt" + + def save_state_dict(self, path = None): + """ + if path is None: + path = self.sampler_state_dict_path + if self.sampler is not None: + state_dict = self.sampler.get_state() + elif self.samplers is not None: + state_dict = { + "samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() }, + } + torch_save(state_dict, path) + """ + return + + def load_state_dict(self, path = None): + """ + if path is None: + path = self.sampler_state_dict_path + + if not path.exists(): + return + + state_dict = torch_load(path) + if self.sampler is not None: + state_dict = self.sampler.set_state(state_dict) + else: + for name, sampler in state_dict["samplers"].items(): + if name not in self.samplers: + continue + self.samplers[name].set_state( sampler ) + """ + return def __getitem__(self, index): path = self.paths[index] + tokens = tokenize( path.stem.upper() ) + text = torch.tensor( tokens ).to(dtype=torch.uint8) - # stupid try/except when the original VALL-E training framework was able to insert foreign symbols into the symmap, but that functionality isn't really necessary here - try: - text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8) - except Exception as e: - print("Invalid symbol:", _get_symbols(path.stem), [*map(self.symmap.get, _get_symbols(path.stem))], path.stem) - raise e + image = Image.open(path).convert('RGB') + width, height = image.size - image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype) # resnet has to be RGB + image = self.transform(image).to(dtype=self.image_dtype) # resnet has to be RGB return dict( index=index, @@ -98,11 +182,6 @@ class Dataset(_Dataset): def __len__(self): return min(len(self.paths), self._head or len(self.paths)) - def pin_memory(self): - self.text = self.text.pin_memory() - self.image = self.image.pin_memory() - return self - def collate_fn(samples: list[dict]): batch: dict[str, Any] = {k: [s[k] for s in samples] for k in 
samples[0]} @@ -111,21 +190,28 @@ def collate_fn(samples: list[dict]): def _seed_worker(worker_id): worker_seed = torch.initial_seed() % 2**32 - #np.random.seed(worker_seed) + np.random.seed(worker_seed) random.seed(worker_seed) def _create_dataloader(dataset, training): + kwargs = dict( + shuffle=True, + batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size, + drop_last=training, + sampler=dataset.sampler if training else None, + ) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict( + batch_sampler=dataset.sampler, + ) + return DataLoader( dataset=dataset, - batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size, - shuffle=True, # training - drop_last=training, num_workers=cfg.dataset.workers, collate_fn=collate_fn, - persistent_workers=cfg.dataset.workers > 0, - pin_memory=False, # True, + persistent_workers=cfg.dataset.workers > 1, + pin_memory=False, worker_init_fn=_seed_worker, + **kwargs, ) def _load_train_val_paths( val_ratio=0.1 ): @@ -133,8 +219,8 @@ def _load_train_val_paths( val_ratio=0.1 ): train_paths = [] val_paths = [] - print(cfg.dataset.training) for data_dir in cfg.dataset.training: + paths.extend(data_dir.rglob("*.jpg")) paths.extend(data_dir.rglob("*.png")) if len(paths) > 0: @@ -146,12 +232,13 @@ def _load_train_val_paths( val_ratio=0.1 ): val_len = math.floor(len(train_paths) * val_ratio) train_len = math.floor(len(train_paths) * (1 - val_ratio)) - print(val_len, train_len) - val_paths = train_paths[:-val_len] train_paths = train_paths[:train_len] else: + paths = [] + for data_dir in cfg.dataset.validation: + paths.extend(data_dir.rglob("*.jpg")) paths.extend(data_dir.rglob("*.png")) if len(paths) > 0: @@ -169,7 +256,6 @@ def _load_train_val_paths( val_ratio=0.1 ): return train_paths, val_paths -@cfg.diskcache() def create_datasets(): train_paths, val_paths = _load_train_val_paths() @@ -187,10 +273,10 @@ def create_datasets(): return train_dataset, val_dataset - def create_train_val_dataloader(): train_dataset, val_dataset = create_datasets() + # deepcopy is slow subtrain_dataset = copy.deepcopy(train_dataset) subtrain_dataset.head_(cfg.evaluation.size) subtrain_dataset.training_(False) @@ -200,8 +286,6 @@ def create_train_val_dataloader(): subtrain_dl = _create_dataloader(subtrain_dataset, training=False) _logger.info(str(train_dataset.symmap)) - - _logger.info(f"#samples (train): {len(train_dataset)}.") _logger.info(f"#samples (val): {len(val_dataset)}.") _logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.") @@ -210,11 +294,305 @@ def create_train_val_dataloader(): return train_dl, subtrain_dl, val_dl +# parse dataset into better to sample metadata """ -if __name__ == "__main__": - create_dataset_hdf5() +def create_dataset_metadata( skip_existing=True ): + symmap = get_symmap() + + root = str(cfg.data_dir) + metadata_root = str(cfg.metadata_dir) - train_dl, subtrain_dl, val_dl = create_train_val_dataloader() - sample = train_dl.dataset[0] - print(sample) -""" + cfg.metadata_dir.mkdir(parents=True, exist_ok=True) + + def add( dir, type="training", audios=True, texts=True ): + name = str(dir) + name = name.replace(root, "") + + speaker_name = name + + metadata_path = Path(f"{metadata_root}/{speaker_name}.json") + metadata_path.parents[0].mkdir(parents=True, exist_ok=True) + + try: + metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read()) + except Exception as e: + metadata = {} + + if not os.path.isdir(f'{root}/{name}/'): + 
return + + # tqdm.write(f'{root}/{name}') + files = os.listdir(f'{root}/{name}/') + + # grab IDs for every file + ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files } + + wrote = False + + for id in tqdm(ids, desc=f"Processing {name}"): + try: + quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}') + + if audios and not quant_path.exists(): + continue + + key = f'{type}/{speaker_name}/{id}' + + if skip_existing and id in metadata: + continue + + wrote = True + + if id not in metadata: + metadata[id] = {} + + utterance_metadata = {} + if audios: + # ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt + dac = np.load(quant_path, allow_pickle=True)[()] + qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16) + + if "text" in dac["metadata"]: + utterance_metadata["text"] = dac["metadata"]["text"] + if "phonemes" in dac["metadata"]: + utterance_metadata["phonemes"] = dac["metadata"]["phonemes"] + if "language" in dac["metadata"]: + utterance_metadata["language"] = dac["metadata"]["language"] + if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]: + utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"] + + for k, v in utterance_metadata.items(): + metadata[id][k] = v + + except Exception as e: + tqdm.write(f'Error while processing {id}: {e}') + + if wrote: + with open(str(metadata_path), "w", encoding="utf-8") as f: + f.write( json.dumps( metadata ) ) + + # training + for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"): + add( data_dir, type="training" ) + + # validation + for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'): + add( data_dir, type="validation" ) + + # noise + for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'): + add( data_dir, type="noise", texts=False ) + +# parse yaml to create an hdf5 file +def create_dataset_hdf5( skip_existing=True ): + cfg.dataset.use_hdf5 = True + cfg.load_hdf5(write=True) + hf = cfg.hdf5 + + symmap = get_symmap() + + root = str(cfg.data_dir) + metadata_root = str(cfg.metadata_dir) + + + def add( dir, type="training", audios=True, texts=True ): + name = str(dir) + name = name.replace(root, "") + + # yucky + speaker_name = name + if "LibriTTS-R" in speaker_name: + speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox") + + metadata_path = Path(f"{metadata_root}/{speaker_name}.json") + metadata_path.parents[0].mkdir(parents=True, exist_ok=True) + + metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read()) + + if not os.path.isdir(f'{root}/{name}/'): + return + + files = os.listdir(f'{root}/{name}/') + + # grab IDs for every file + ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files } + + for id in tqdm(ids, desc=f"Processing {name}"): + try: + quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True + text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True + + if not quant_exists: + continue + + key = f'{type}/{speaker_name}/{id}' + + if skip_existing and key in hf: + continue + + group = hf.create_group(key) if key not in hf else hf[key] + + if id not in metadata: + metadata[id] = {} + + utterance_metadata = {} + + # audio + if audios: + dac = 
np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()] + qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16) + + if "text" in dac["metadata"]: + utterance_metadata["text"] = dac["metadata"]["text"] + if "phonemes" in dac["metadata"]: + utterance_metadata["phonemes"] = dac["metadata"]["phonemes"] + if "language" in dac["metadata"]: + utterance_metadata["language"] = dac["metadata"]["language"] + if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]: + utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"] + + if "audio" not in group: + group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf') + + # text + if texts: + if not utterance_metadata and text_exists: + utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read()) + + phn = "".join(utterance_metadata["phonemes"]) + phn = cfg.tokenizer.encode(phn) + phn = np.array(phn).astype(np.uint8) + + if "text" not in group: + group.create_dataset('text', data=phn, compression='lzf') + + for k, v in utterance_metadata.items(): + group.attrs[k] = v + metadata[id][k] = v + + except Exception as e: + tqdm.write(f'Error while processing {id}: {e}') + + with open(str(metadata_path), "w", encoding="utf-8") as f: + f.write( json.dumps( metadata ) ) + + # training + for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"): + add( data_dir, type="training" ) + + # validation + for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'): + add( data_dir, type="validation" ) + + # noise + for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'): + add( data_dir, type="noise", texts=False ) + + # write symmap + if "symmap" in hf: + del hf['symmap'] + + hf.create_dataset('symmap', data=json.dumps(symmap)) + hf.close() + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser("Save trained model to path.") + parser.add_argument("--action", type=str) + parser.add_argument("--tasks", type=str) + args, unknown = parser.parse_known_args() + + task = args.action + + cfg.dataset.workers = 1 + + if args.action == "hdf5": + create_dataset_hdf5() + elif args.action == "list-dataset": + dataset = [] + for group in os.listdir(cfg.data_dir): + for name in os.listdir(cfg.data_dir / group): + if len(os.listdir(cfg.data_dir / group / name)) == 0: + continue + dataset.append(f'{group}/{name}') + + _logger.info(json.dumps(dataset)) + elif args.action == "metadata": + create_dataset_metadata() + elif args.action == "sample": + train_dl, subtrain_dl, val_dl = create_train_val_dataloader() + + samples = { + "training": [ next(iter(train_dl)), next(iter(train_dl)) ], + "evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ], + #"validation": [ next(iter(val_dl)), next(iter(val_dl)) ], + } + + Path("./data/sample-test/").mkdir(parents=True, exist_ok=True) + + for k, v in samples.items(): + for i in range(len(v)): + for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."): + try: + decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" ) + except Exception as e: + _logger.info(f"Error while decoding prom {k}.{i}.{j}.wav: {str(e)}") + try: + decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" ) + except Exception as e: + _logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}") + v[i]['proms'][j] = v[i]['proms'][j].shape + 
v[i]['resps'][j] = v[i]['resps'][j].shape + + for k, v in samples.items(): + for i in range(len(v)): + _logger.info(f'{k}[{i}]: {v[i]}') + elif args.action == "validate": + train_dl, subtrain_dl, val_dl = create_train_val_dataloader() + + missing = set() + + for i in range(len( train_dl.dataset )): + batch = train_dl.dataset[i] + + text = batch['text'] + phonemes = batch['metadata']['phonemes'] + + decoded = [ cfg.tokenizer.decode(token) for token in text[1:-1] ] + for i, token in enumerate(decoded): + if token != "<unk>": + continue + + phone = phonemes[i] + + _logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" ) + + missing |= set([phone]) + + _logger.info( f"Missing tokens: {missing}" ) + + + elif args.action == "tasks": + index = 0 + cfg.dataset.tasks_list = args.tasks.split(",") + + train_dl, subtrain_dl, val_dl = create_train_val_dataloader() + batch = next(iter(train_dl)) + + for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]): + if task not in cfg.dataset.tasks_list: + continue + + _logger.info( f'{text} {task} {cfg.model.resp_levels}') + _logger.info( f'{proms.shape} {resps.shape}' ) + + tokens = 0 + tokens += sum([ text.shape[0] for text in batch["text"] ]) + tokens += sum([ resps.shape[0] for resps in batch["resps"] ]) + _logger.info( f'{tokens}' ) + + decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" ) + decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" ) + break +""" \ No newline at end of file diff --git a/image_classifier/engines/__init__.py b/image_classifier/engines/__init__.py index f0879ec..99c10ec 100755 --- a/image_classifier/engines/__init__.py +++ b/image_classifier/engines/__init__.py @@ -1,6 +1,6 @@ from ..config import cfg -from ..utils.distributed import fix_unset_envs +from ..utils.distributed import fix_unset_envs, ddp_model fix_unset_envs() if cfg.trainer.backend == "deepspeed": @@ -8,4 +8,211 @@ elif cfg.trainer.backend == "local": from .base import Engine -from .base import Engines, TrainFeeder, default_feeder \ No newline at end of file +from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine + +from ..models import get_models, get_model +from ..utils import wrapper as ml +from ..utils.io import torch_save, torch_load, pick_path +from ..models.lora import apply_lora, lora_load_state_dict + +import torch +import re +import logging + +_logger = logging.getLogger(__name__) + +deepspeed_available = False +try: + import deepspeed + deepspeed_available = True +except Exception as e: + pass + +from functools import cache + +@cache +def load_engines(training=True, **model_kwargs): + models = get_models(cfg.models, training=training, **model_kwargs) + engines = dict() + + for name, model in models.items(): + state = None + stats = None + lora = None + + inferencing = cfg.mode == "inferencing" or not model.config.training or not training + backend = cfg.inference.backend if inferencing else cfg.trainer.backend + loads_state_dict = cfg.trainer.load_state_dict # or inferencing + + checkpoint_path = cfg.ckpt_dir / name / "latest" + # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present + load_path = pick_path( cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) + + # actually use the lora-specific checkpoint if available + if cfg.lora is not None: + checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest" + + # to handle 
the issue of training with deepspeed, but inferencing with local + if checkpoint_path.exists() and backend == "local": + tag = open(checkpoint_path).read() + checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) + + if not loads_state_dict and not checkpoint_path.exists() and load_path.exists(): + _logger.warning(f"Checkpoint missing, but weights found: {load_path}") + loads_state_dict = True + + # load state early + if loads_state_dict: + state = torch_load(load_path, device=cfg.device) + + # check if config is defined in state, and re-initialize the model + if "config" in state and False: + _logger.warning("Model config definition in weights, re-loading...") + config_state = state["config"] + model = get_model( config=cfg.model.__class__( *config_state ), training=training ) + + hyper_config = model.config + + optimizer = None + lr_scheduler = None + + dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype + amp = cfg.inference.amp if inferencing else cfg.trainer.amp + ddp = cfg.trainer.ddp + + engine_class = LocalEngine if backend == "local" else Engine + + # apply model replacers + if cfg.optimizations.replace and cfg.optimizations.linear: + model.model = ml.replace_linear( model.model ) + + if cfg.optimizations.replace and cfg.optimizations.embedding: + model.model = ml.replace_embedding( model.model ) + + for lora in cfg.loras: + model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize ) + + if inferencing: + model.config.training = False + + if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)): + optimizer_class = None + scheduler_class = None + + params = { + "lr": cfg.hyperparameters.learning_rate, + } + if cfg.hyperparameters.optimizer.lower() == "adamw": + params["betas"] = (0.9, 0.96) + params["eps"] = 1e-07 + params["weight_decay"] = 0.01 + + # for dadaptation since it has Adam only + if ml.AdamW == ml.Adam: + params["decouple"] = True + + optimizer_class = ml.AdamW + elif cfg.hyperparameters.optimizer.lower() == "sgd": + optimizer = ml.SGD + elif cfg.hyperparameters.optimizer.lower() == "prodigy": + optimizer_class = ml.Prodigy + + params['d_coef'] = params['lr'] + params['lr'] = 1.0 + elif cfg.hyperparameters.optimizer.lower() == "adagrad": + optimizer_class = ml.Adagrad + else: + raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}') + + params.update(cfg.hyperparameters.optimizer_params) + + optimizer = optimizer_class( + [ param for name, param in model.named_parameters() if name not in model.config.frozen_params ], + **params, + ) + + if cfg.hyperparameters.scheduler.lower() == "schedulefree": + if cfg.hyperparameters.optimizer.lower() == "adamw": + scheduler_class = ml.schedulefree.AdamWScheduleFree + elif cfg.hyperparameters.optimizer.lower() == "sgd": + scheduler_class = ml.schedulefree.SGDScheduleFree + else: + raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}') + + optimizer = scheduler_class( + [ param for name, param in model.named_parameters() if name not in model.config.frozen_params ], + lr = params['lr'], + warmup_steps = cfg.hyperparameters.warmup_steps + ) + + """ + # set up our LR scheduler here + """ + + if inferencing: + optimizer = None + lr_scheduler = None + + # load state dict if requested / required + if 
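# --- illustrative sketch, not part of the patch: the per-optimizer hyperparameter remapping
# above. `optimizer_name` stands in for cfg.hyperparameters.optimizer.lower(). With Prodigy,
# the configured learning_rate becomes the d_coef multiplier and the real lr is pinned to 1.0,
# so a config can simply set learning_rate to the desired d coefficient.
optimizer_name = "prodigy"
params = {"lr": 1.0}  # stand-in for cfg.hyperparameters.learning_rate
if optimizer_name == "adamw":
    params.update(betas=(0.9, 0.96), eps=1e-07, weight_decay=0.01)
elif optimizer_name == "prodigy":
    params["d_coef"] = params["lr"]  # Prodigy scales its adaptive step by d_coef
    params["lr"] = 1.0               # and runs with a unit base learning rate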
loads_state_dict: + # state dict is not just the module, extract the extra trainer details + if "stats" in state: + stats = state["stats"] + + # do not load stats if we're training a LoRA + if cfg.lora is not None or cfg.trainer.restart_step_count: + stats = None + + if "module" in state: + state = state["module"] + + model.load_state_dict(state, strict=cfg.trainer.strict_loading) + + # load lora weights if exists + if cfg.lora is not None: + lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) + if lora_path.exists(): + _logger.info( f"Loaded LoRA state dict: {lora_path}" ) + + state = torch_load(lora_path, device=cfg.device) + state = state['lora' if 'lora' in state else 'module'] + lora_load_state_dict( model, state ) + + # wrap if DDP is requested + if ddp: + model = ddp_model(model) + # wrap optimization class + elif cfg.optimizations.compile: + model = ml.compile_model(model, backend=cfg.optimizations.compile) + # deepspeed inferencing + elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"): + engine_class = LocalEngine + model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module + + # use base engine if requested + engines[name] = engine_class( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + + hyper_config=hyper_config, + stats=stats + ) + + + engines = Engines(engines) + engines.setup() + + # this might bite me in the ass since technically this doesn't handle one engine loading fine but another engine not + if not cfg.trainer.load_state_dict: + engines.load_checkpoint(training=not inferencing) + + # freeze requested params + for name, engine in engines.items(): + engine.freeze(freeze_all=False) + + # split models over requested devices + if cfg.optimizations.model_offloading: + engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading ) + + return engines diff --git a/image_classifier/engines/base.py b/image_classifier/engines/base.py index 24199ed..a077280 100755 --- a/image_classifier/engines/base.py +++ b/image_classifier/engines/base.py @@ -28,7 +28,9 @@ def default_feeder(engine, batch): from ..config import cfg from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device -from ..utils.distributed import init_distributed, distributed_initialized +from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size, cleanup_distributed +from ..utils.io import torch_save, torch_load +from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict import logging import time @@ -39,40 +41,65 @@ import os from torch import Tensor from torch.distributed import all_reduce from typing import Any, Protocol +from functools import cached_property from .base import TrainFeeder +from ..utils import wrapper as ml _logger = logging.getLogger(__name__) -if not distributed_initialized() and cfg.trainer.backend == "local": - def _nop(): - ... 
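# --- illustrative sketch, not part of the patch: the typical inference-side entry into
# load_engines() above, mirroring what image_classifier/inference.py does further down.
from image_classifier.engines import load_engines

engines = load_engines(training=False)  # optimizer/scheduler construction is skipped
engines.eval()
for name, engine in engines.items():
    model = engine.module  # the wrapped nn.Module
    break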
- fn = _nop if cfg.device == "cpu" else torch.distributed.init_process_group - init_distributed(fn) +if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1: + init_distributed(torch.distributed.init_process_group) # A very naive engine implementation using barebones PyTorch -# to-do: implement lr_sheduling class Engine(): def __init__(self, *args, **kwargs): - self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype) + if 'hyper_config' in kwargs: + self.hyper_config = kwargs['hyper_config'] + kwargs.pop("hyper_config") + + self.module = kwargs['model'].to(cfg.device).to(torch.float32 if cfg.trainer.amp else cfg.trainer.dtype) self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None - self.global_steps = 0 - self.micro_steps = 0 - self.gradient_accumulation_steps = cfg.hyperparameters.gradient_accumulation_steps + self.global_steps = kwargs.pop("global_steps", 0) + self.micro_steps = kwargs.pop("micro_steps", 0) + self.global_samples = kwargs.pop("global_samples", 0) + self.tokens_processed = kwargs.pop("tokens_processed", 0) - def freeze(self): - for p in self.module.parameters(): - if p.requires_grad: - p.requires_grad_(False) - self._frozen_params.add(p) + self._frozen_params = set() + + self.max_nan_losses = 8 + self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None + + self.current_batch_size = 0 + self._global_grad_norm = None + + def freeze(self, freeze_all=True): + # set to freeze + if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"): + raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None") + + # freeze non-LoRA params if requested + if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None: + return freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings ) + + for name, param in self.module.named_parameters(): + if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params): + param.requires_grad_(False) + self._frozen_params.add(param) def unfreeze(self): for p in self._frozen_params: p.requires_grad_(True) self._frozen_params.clear() + @property + def _training(self): + if not hasattr(self, "hyper_config"): + return True + return self.hyper_config.training + @property def global_step(self): return self.global_steps @@ -81,8 +108,17 @@ class Engine(): def micro_step(self): return self.micro_steps - def train_batch_size(self): - return cfg.hyperparameters.batch_size + @property + def batch_size(self): + return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size + + @property + def gradient_accumulation_steps(self): + return cfg.hyperparameters.gradient_accumulation_steps + + @property + def gradient_clipping(self): + return cfg.hyperparameters.gradient_clipping def gather_attribute(self, *args, **kwargs): return gather_attribute(self.module, *args, **kwargs) @@ -91,42 +127,74 @@ class Engine(): return dispatch_attribute(self.module, *args, **kwargs) def save_checkpoint(self, save_dir, tag ): - save_path = save_dir / tag / "state.pth" - save_path.parent.mkdir(parents=True, exist_ok=True) - torch.save({ - "global_step": self.global_step, - "micro_step": self.micro_step, - "module": self.module.state_dict(), - "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None, - "lr_scheduler": self.lr_scheduler.state_dict() if 
self.lr_scheduler is not None else None, - }, save_path) + if is_global_leader(): + module = self.module.state_dict() - open(save_dir / "latest", 'w').write( tag ) + # if training lora + # this is a separate path to override saving the weights + lora = None + if cfg.lora is not None: + lora, module = lora_get_state_dict( module, split = True ) + save_dir = cfg.ckpt_dir / cfg.lora.full_name + + save_path = save_dir / tag / f"state.{cfg.weights_format}" + save_path.parent.mkdir(parents=True, exist_ok=True) + + torch_save({ + "module": module, + "lora": lora, + "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None, + "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, + + "stats": { + "global_step": self.global_step, + "micro_step": self.micro_step, + "global_samples": self.global_samples, + "tokens_processed": self.tokens_processed, + } + }, save_path) + + open(save_dir / "latest", 'w').write( tag ) + + torch.distributed.barrier() + + def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False): + # override to load the lora instead + if cfg.lora is not None: + load_dir = cfg.ckpt_dir / cfg.lora.full_name - def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True): if tag is None: tag_path = load_dir / "latest" + if not tag_path.exists(): return + tag = open(tag_path).read() - load_path = load_dir / tag / "state.pth" + load_path = load_dir / tag / f"state.{cfg.weights_format}" + if not load_path.exists(): return - state = torch.load(load_path) - self.global_steps = state['global_step'] - self.micro_steps = state['micro_step'] - self.module.load_state_dict(state['module']) + state = torch_load(load_path, device=cfg.device) + + self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step'] + self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step'] + self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples'] + self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed'] + self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading) load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state if load_optimizer_states: - self.optimizer.load_state_dict(state['optimizer']) + self.optimizer.load_state_dict(state['optimizer']) #, device=cfg.device) if load_lr_scheduler_states: - self.lr_scheduler.load_state_dict(state['lr_scheduler']) + self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, device=cfg.device) + + if 'lora' in state: + lora_load_state_dict( self.module, state['lora'] ) def eval(self): return self.module.eval() @@ -136,46 +204,80 @@ class Engine(): def to(self, *args, **kwargs): self.module = self.module.to(*args, **kwargs) - return self.module + if self.optimizer: + self.optimizer = self.optimizer.to(*args, **kwargs) + + return self def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) + @cached_property + def device(self): + return next(self.module.parameters()).device + def forward(self, *args, **kwargs): return self.module.forward(*args, **kwargs) def backward(self, loss): + if self.loss_scaler is 
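# --- illustrative sketch, not part of the patch: the backward path above. The loss is divided
# by gradient_accumulation_steps before backprop, and routed through a GradScaler when
# cfg.trainer.scale_loss is set so fp16 gradients do not underflow.
import torch

accum = 4  # stand-in for cfg.hyperparameters.gradient_accumulation_steps
device = "cuda" if torch.cuda.is_available() else "cpu"
scaler = torch.cuda.amp.GradScaler(enabled=device == "cuda")
loss = torch.ones([], device=device, requires_grad=True)
scaler.scale(loss / accum).backward()  # when disabled, scale() is a passthrough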
not None: + return self.loss_scaler.scale(loss / self.gradient_accumulation_steps).backward() return (loss / self.gradient_accumulation_steps).backward() + def step(self): with torch.set_grad_enabled(self.gradient_accumulation_steps > 1): self.micro_steps += 1 + self.global_samples += self.batch_size if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0: + torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping) + self.global_steps += 1 - self.optimizer.step() + if self.loss_scaler is not None: + self.loss_scaler.step(self.optimizer) + self.loss_scaler.update() + else: + self.optimizer.step() self.optimizer.zero_grad() + self._get_grad_norm() + + def _get_grad_norm(self): + t = [ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ] + self._global_grad_norm = torch.cat(t).norm().item() if len(t) else None + def get_lr(self): lrs = [] for param_group in self.optimizer.param_groups: - if 'lr' in param_group: + if 'd_coeff' in param_group: + lrs.append(param_group['d_coeff']) + elif 'lr' in param_group: lrs.append(param_group['lr']) return lrs def set_lr(self, lr): for param_group in self.optimizer.param_groups: - if 'lr' in param_group: + if 'd_coeff' in param_group: + param_group['d_coeff'] = lr + elif 'lr' in param_group: param_group['lr'] = lr def get_global_grad_norm(self): - return 0.0 + return self._global_grad_norm def traverse(self, *args, **kwargs): - self.forward(*args, **kwargs) + with ml.autocast(): + self.forward(*args, **kwargs) + losses = self.gather_attribute("loss") loss = torch.stack([*losses.values()]).sum() + if torch.isnan(loss).any(): + self.max_nan_losses = self.max_nan_losses - 1 + if self.max_nan_losses < 0: + raise RuntimeError("Too many NaN losses detected.") + stats = {} stats |= {k: v.item() for k, v in losses.items()} stats |= self.gather_attribute("scalar") @@ -194,6 +296,8 @@ class Engines(dict[str, Engine]): def setup(self): self._global_step = 0 self._micro_step = 0 + self._batch_size = 0 + self._global_samples = 0 @property def global_step(self): @@ -203,6 +307,14 @@ class Engines(dict[str, Engine]): def micro_step(self): return self._micro_step + @property + def batch_size(self): + return self._batch_size + + @property + def global_samples(self): + return self._global_samples + def gather_attribute(self, *args, **kwargs): ret = {} for engine in self.values(): @@ -213,6 +325,50 @@ class Engines(dict[str, Engine]): for engine in self.values(): engine.dispatch_attribute(*args, **kwargs) + def export(self, userdata={}, callback=None, dtype=None, format=None): + if not format: + format = cfg.weights_format + format = format.lower() + + if dtype is None: + dtype = cfg.trainer.dtype + + for name, engine in self.items(): + module = engine.module.state_dict() + lora = None + save_path = cfg.ckpt_dir / name / f"fp32.{format}" + config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config + + # safety + for k, v in module.items(): + module[k] = v.to(dtype) + + if cfg.lora is not None: + lora, module = lora_get_state_dict( module, split = True ) + save_path = cfg.ckpt_dir / cfg.lora.full_name / f"fp32.{format}" + + state_dict = { + 'module': module, + 'lora': lora, + "stats": { + "global_step": engine.global_step, + "micro_step": engine.micro_step, + "global_samples": engine.global_samples, + "tokens_processed": engine.tokens_processed, + }, + "userdata": userdata, + "config": config.__dict__ + } + + if lora is None: + del state_dict['lora'] + + if 
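# --- illustrative sketch, not part of the patch: the step() accounting above. The optimizer
# only steps (and the global step only advances) once every gradient_accumulation_steps
# micro-steps, with gradients clipped right before that boundary.
accum = 4
micro_steps, global_steps = 0, 0
for _ in range(12):
    micro_steps += 1
    if (micro_steps + 1) % max(1, accum) == 0:
        global_steps += 1  # clip_grad_norm_, optimizer.step() and zero_grad() fire here
print(micro_steps, global_steps)  # 12 micro-steps -> 3 optimizer steps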
callback: + state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path ) + + torch_save(state_dict, save_path) + _logger.info(f"Exported {name} to {save_path}") + def save_checkpoint(self, tag=None): if not tag: tag = cfg.trainer.save_tag @@ -222,47 +378,67 @@ class Engines(dict[str, Engine]): cfg.ckpt_dir.mkdir(parents=True, exist_ok=True) for name, engine in self.items(): - engine.save_checkpoint(cfg.ckpt_dir / name, tag=tag) + if not engine._training: + continue - def load_checkpoint(self, tag=None): + save_dir = cfg.ckpt_dir / name + try: + engine.save_checkpoint(save_dir, tag=tag) + except Exception as e: + _logger.warning(f'Failed to save checkpoint for engine {name}: {str(e)}') + + # might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None] + if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader(): + checkpoints = [ d for d in list(save_dir.glob("*")) if d.is_dir() ] + checkpoints.sort(key=lambda x: x.stat().st_mtime) + checkpoints = checkpoints[:-cfg.trainer.keep_last_checkpoints] + for d in checkpoints: + if not d.is_dir() or not d.exists(): + continue + _logger.info(f"Removing {d}") + for p in d.iterdir(): + p.unlink() + d.rmdir() + + def load_checkpoint(self, tag=None, training=True): if not tag: tag = cfg.trainer.load_tag for name, engine in self.items(): load_dir = cfg.ckpt_dir / name + engine.load_checkpoint( tag=tag, load_dir=load_dir, load_module_strict=cfg.trainer.strict_loading, - load_optimizer_states=cfg.trainer.load_states, - load_lr_scheduler_states=cfg.trainer.load_states, + load_optimizer_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states, + load_lr_scheduler_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states, + load_module_only=cfg.trainer.load_module_only, ) if cfg.trainer.restart_step_count: engine.global_steps = 0 + engine.mocro_step = 0 + engine.global_samples = 0 + engine.tokens_processed = 0 # update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler if cfg.hyperparameters.scheduler_type == "": self.set_lr(cfg.hyperparameters.learning_rate) - self._update_global_step() - self._update_micro_step() + self._update() def set_lr(self, lr): for engine in self.values(): + if not engine._training: + continue engine.set_lr(lr) - def _update_global_step(self): + def _update(self): for engine in self.values(): self._global_step = max(self._global_step, engine.global_step) - - def _update_micro_step(self): - for engine in self.values(): self._micro_step = max(self._micro_step, engine.micro_step) - - def train_batch_size(self): - batch_size = 0 - for engine in self.values(): - batch_size = max(batch_size, engine.train_batch_size()) + self._batch_size = max(self._batch_size, engine.batch_size) + self._global_samples = max(self._global_samples, engine.global_samples) def eval(self): for engine in self.values(): @@ -279,7 +455,10 @@ class Engines(dict[str, Engine]): stats.update(flatten_dict({ name.split("-")[0]: stat })) return stats - def step(self, batch, feeder: TrainFeeder = default_feeder, device=cfg.get_device()): + def quit(self): + cleanup_distributed() + + def step(self, batch, feeder: TrainFeeder = default_feeder): total_elapsed_time = 0 stats: Any = dict() @@ -287,10 +466,11 @@ class Engines(dict[str, Engine]): if cfg.trainer.gc_mode == 
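# --- illustrative sketch, not part of the patch: the keep_last_checkpoints pruning above. Tag
# directories are sorted by modification time and everything but the newest N is deleted;
# the caller only prunes when keep_last_checkpoints > 0.
from pathlib import Path

def prune_checkpoints(save_dir: Path, keep_last: int):
    assert keep_last > 0
    checkpoints = sorted(
        (d for d in save_dir.glob("*") if d.is_dir()),
        key=lambda d: d.stat().st_mtime,
    )
    for d in checkpoints[:-keep_last]:  # oldest first; the newest `keep_last` survive
        for p in d.iterdir():
            p.unlink()
        d.rmdir()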
'step': do_gc() - batch = to_device(batch, device) - for name, engine in self.items(): - #torch.cuda.synchronize() + if not engine._training: + continue + + device = engine.device if cfg.trainer.gc_mode == 'substep': do_gc() @@ -298,10 +478,9 @@ class Engines(dict[str, Engine]): start_time = time.time() tries = 4 - n_ooms = torch.zeros([], device=cfg.device) + n_ooms = torch.zeros([], device=device) - if cfg.trainer.aggressive_optimizations: - batch = to_device(batch, device) + batch = to_device(batch, device) if not cfg.trainer.check_for_oom: res = feeder( engine=engine, batch=batch ) @@ -311,7 +490,7 @@ class Engines(dict[str, Engine]): res = feeder( engine=engine, batch=batch ) break except RuntimeError as e: - print("Forward", str(e)) + _logger.error(f"Forward: {str(e)}") if "out of memory" not in str(e): self.save_checkpoint() @@ -329,7 +508,8 @@ class Engines(dict[str, Engine]): do_gc() continue - all_reduce(n_ooms) + if world_size() > 1: + all_reduce(n_ooms) if n_ooms.item() > 0: self.save_checkpoint() raise RuntimeError("Out of memory during forward pass!") @@ -340,7 +520,7 @@ class Engines(dict[str, Engine]): loss, engine_stats = res engine_stats |= self.gather_attribute("scalar") - n_ooms = torch.zeros([], device=cfg.device) + n_ooms = torch.zeros([], device=device) if cfg.trainer.aggressive_optimizations: batch = to_device(batch, 'cpu') @@ -348,10 +528,11 @@ class Engines(dict[str, Engine]): if not cfg.trainer.check_for_oom: engine.backward(loss) else: + # to-do: properly handle when one GPU throws an OOM because it just halts try: engine.backward(loss) except RuntimeError as e: - print("Backwards:", str(e)) + _logger.error(f"Backwards: {str(e)}") if "out of memory" not in str(e): self.save_checkpoint() @@ -359,10 +540,13 @@ class Engines(dict[str, Engine]): n_ooms += 1 - all_reduce(n_ooms) + if world_size() > 1: + all_reduce(n_ooms) + if n_ooms.item() > 0: self.save_checkpoint() - raise RuntimeError("Out of memory during backwards pass!") + + raise RuntimeError("Out of memory during backwards pass!") engine.step() @@ -370,27 +554,36 @@ class Engines(dict[str, Engine]): elapsed_time = time.time() - start_time total_elapsed_time += elapsed_time + grad_norm = engine.get_global_grad_norm() + loss_scale = 1 + if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None: + loss_scale = engine.optimizer.loss_scale + + if grad_norm is not None: + grad_norm /= loss_scale stats.update( flatten_dict( { name.split("-")[0]: dict( - loss=loss.item(), + **engine_stats, lr=engine.get_lr()[0], - grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation + grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm, + loss_scale=loss_scale if loss_scale != 1 else None, elapsed_time=elapsed_time, engine_step=engine.global_step, - **engine_stats, + samples_processed=engine.global_samples, + tokens_processed=engine.tokens_processed, ) } ), ) - self._update_global_step() - self._update_micro_step() - stats["batch_size"] = self.train_batch_size() # len(batch["text"]) - stats["elapsed_time"] = total_elapsed_time - stats["wall_time"] = time.time() - stats["global_step"] = self.global_step + self._update() + + if len(self.keys()) > 1: + stats["elapsed_time"] = total_elapsed_time + + stats["it"] = self.global_step return stats diff --git a/image_classifier/engines/deepspeed.py b/image_classifier/engines/deepspeed.py index 8458807..4abe44f 100755 --- a/image_classifier/engines/deepspeed.py +++ 
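# --- illustrative sketch, not part of the patch: the cross-rank OOM consensus used in step()
# above. Every rank contributes to a counter that is summed across the world; if any rank hit
# an OOM, all ranks checkpoint and abort together instead of deadlocking in a collective.
import torch
import torch.distributed as dist

def oom_consensus(had_oom: bool, device, world_size: int) -> bool:
    n_ooms = torch.zeros([], device=device)
    if had_oom:
        n_ooms += 1
    if world_size > 1:
        dist.all_reduce(n_ooms)  # defaults to SUM
    return n_ooms.item() > 0     # caller should save_checkpoint() and raise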
b/image_classifier/engines/deepspeed.py @@ -25,28 +25,71 @@ from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distr from deepspeed.accelerator import get_accelerator from ..utils.distributed import init_distributed, distributed_initialized +from ..utils import wrapper as ml + +from ..models.lora import freeze_non_lora_weights if not distributed_initialized() and cfg.trainer.backend == "deepspeed": init_distributed(init_deepspeed_dist) class Engine(DeepSpeedEngine): def __init__(self, *args, **kwargs): - kwargs['config'] = cfg.trainer.deepspeed.get_ds_cfg(model=kwargs['model']) + self.hyper_config = None + if 'hyper_config' in kwargs: + self.hyper_config = kwargs['hyper_config'] + kwargs.pop("hyper_config") + + kwargs['config'] = cfg.trainer.deepspeed.ds_cfg kwargs['config_class'] = DeepSpeedConfig(kwargs['config']) + stats = { + "global_step": 0, + "micro_step": 0, + "global_samples": 0, + "tokens_processed": 0, + } + + # kwargs['stats'] = None will return None when popped + maybe_stats = kwargs.pop('stats', stats) + if maybe_stats is not None: + stats = maybe_stats + super().__init__(None, *args, **kwargs) self._frozen_params = set() - def freeze(self): - for p in self.module.parameters(): - if p.requires_grad: - p.requires_grad_(False) - self._frozen_params.add(p) + self.global_steps = stats["global_step"] + self.micro_steps = stats["micro_step"] + self.global_samples = stats["global_samples"] + self.tokens_processed = stats["tokens_processed"] + + self.max_nan_losses = 8 + self.current_batch_size = 0 + + def freeze(self, freeze_all=True): + # freeze non-LoRA params if requested + if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None: + frozen_params = freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings ) + for param in frozen_params: + self._frozen_params.add( param ) + + return + + if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"): + raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None") + + for name, param in self.module.named_parameters(): + if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params): + param.requires_grad_(False) + self._frozen_params.add(param) def unfreeze(self): - for p in self._frozen_params: - p.requires_grad_(True) + for param in self._frozen_params: + param.requires_grad_(True) self._frozen_params.clear() + + @property + def _training(self): + return self.hyper_config.training @property def global_step(self): @@ -54,7 +97,11 @@ class Engine(DeepSpeedEngine): @property def micro_step(self): - return self.micro_steps + return self.micro_steps + + @property + def batch_size(self): + return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size def gather_attribute(self, *args, **kwargs): return gather_attribute(self.module, *args, **kwargs) @@ -66,17 +113,40 @@ class Engine(DeepSpeedEngine): try: if hasattr(self.optimizer, 'param_groups'): for param_group in self.optimizer.param_groups: - param_group['lr'] = lr + param_group["d_coeff" if "d_coeff" in param_group else "lr"] = lr else: self.optimizer.set_lr(lr) except Exception as e: - print(str(e)) + _logger.warning(str(e)) + + # we'll just have to live with the LoRA weights living within our main weights + # they're easy to extract anyways + def load_checkpoint(self, load_dir, **kwargs ): + # override to load the lora instead + if cfg.lora is not None: + load_dir = cfg.ckpt_dir / cfg.lora.full_name + + return 
super().load_checkpoint( load_dir, **kwargs ) + + def save_checkpoint(self, save_dir, **kwargs ): + # override to save the lora instead + if cfg.lora is not None: + save_dir = cfg.ckpt_dir / cfg.lora.full_name + + return super().save_checkpoint( save_dir, **kwargs ) def traverse(self, *args, **kwargs): - self.forward(*args, **kwargs) + with ml.autocast(): + self.forward(*args, **kwargs) + losses = self.gather_attribute("loss") loss = torch.stack([*losses.values()]).sum() + if torch.isnan(loss).any(): + self.max_nan_losses = self.max_nan_losses - 1 + if self.max_nan_losses < 0: + raise RuntimeError("Too many NaN losses detected.") + stats = {} stats |= {k: v.item() for k, v in losses.items()} stats |= self.gather_attribute("scalar") diff --git a/image_classifier/export.py b/image_classifier/export.py index 8ec2dfe..176c2f4 100755 --- a/image_classifier/export.py +++ b/image_classifier/export.py @@ -1,31 +1,67 @@ import argparse import torch +import torch.nn from .data import get_symmap -from .train import load_engines +from .engines import load_engines +from .config import cfg +from .models.lora import lora_get_state_dict +from .utils.io import torch_save, torch_load -def load_models(): - models = {} - engines = load_engines() - for name in engines: - model = engines[name].module.cpu() - models[name] = model +# yanks a LoRA from the training checkpoint +def extract_lora( state_dict, config = None, save_path = None, dtype = None ): + if dtype is None: + dtype = cfg.inference.dtype - return models + format = save_path.stem[1:] + + lora = state_dict["lora"] if "lora" in state_dict else None + # should always be included, but just in case + if lora is None and "module" in state_dict: + lora, module = lora_get_state_dict( state_dict["module"], split = True ) + state_dict["module"] = module + + if "lora" in state_dict: + state_dict["lora"] = None + + # should raise an exception since there's nothing to extract, or at least a warning + if not lora: + return state_dict + + # save lora specifically + # should probably export other attributes, similar to what SD LoRAs do + save_path = save_path.parent / f"lora.{format}" + torch_save( { + "module": lora, + "config": cfg.lora.__dict__ if cfg.lora is not None else None, + }, save_path ) + + return state_dict def main(): parser = argparse.ArgumentParser("Save trained model to path.") - parser.add_argument("path") - args = parser.parse_args() + parser.add_argument("--module-only", action='store_true') + parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to + parser.add_argument("--format", type=str, default="pth") # set target format to export weights under + args, unknown = parser.parse_known_args() - models = load_models() - for name in models: - model = models[name] + if args.format.lower() not in ["sft", "safetensors", "pt", "pth"]: + raise Exception(f"Unknown requested format: {args.format}") - outpath = f'{args.path}/{name}.pt' - torch.save(model, outpath) - print(f"Exported {name} to {outpath}") + if args.module_only: + cfg.trainer.load_module_only = True + + if args.dtype != "auto": + cfg.trainer.weight_dtype = args.dtype + + # necessary to ensure we are actually exporting the weights right + cfg.inference.backend = cfg.trainer.backend + + engines = load_engines(training=False) # to ignore loading optimizer state + + callback = None + engines.export(userdata={"symmap": get_symmap()}, callback=callback, format=args.format) if __name__ == "__main__": main() \ No newline at end of file diff --git 
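# --- illustrative usage, not part of the patch, and assuming the shared --yaml config flag is
# honored by the exporter's config loader (the other flags are the ones parsed above):
#   python3 -m image_classifier.export --yaml="./data/config.yaml" --format=sft --dtype=float32
#   python3 -m image_classifier.export --yaml="./data/config.yaml" --module-only
# --module-only drops optimizer state; --format accepts pth/pt/sft/safetensors.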
a/image_classifier/inference.py b/image_classifier/inference.py index 0e70703..dba3f93 100755 --- a/image_classifier/inference.py +++ b/image_classifier/inference.py @@ -1,53 +1,103 @@ import torch +import torchaudio +import soundfile +import time +import logging +_logger = logging.getLogger(__name__) + +from torch import Tensor +from einops import rearrange +from pathlib import Path + +from .utils import to_device, set_seed, wrapper as ml +from PIL import Image, ImageDraw import torchvision.transforms as transforms -from .config import cfg -from .export import load_models -from .data import get_symmap, _get_symbols +from .config import cfg, Config +from .models import get_models +from .engines import load_engines, deepspeed_available +from .data import get_symmap, tokenize + +if deepspeed_available: + import deepspeed class Classifier(): - def __init__( self, width=300, height=80, config=None, ckpt=None, device=cfg.get_device(), dtype="float32" ): + def __init__( self, config=None, device=None, amp=None, dtype=None, attention=None ): + self.loading = True + + # yes I can just grab **kwargs and forward them here + self.load_config( config=config, device=device, amp=amp, dtype=dtype, attention=attention ) + self.load_model() + + self.loading = False + + def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ): if config: + _logger.info(f"Loading YAML: {config}") cfg.load_yaml( config ) - self.loading = True + try: + cfg.format( training=False ) + cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing + except Exception as e: + raise e # throw an error because I'm tired of silent errors messing things up for me + + if amp is None: + amp = cfg.inference.amp + if dtype is None or dtype == "auto": + dtype = cfg.inference.weight_dtype + if device is None: + device = cfg.device + + cfg.device = device + cfg.mode = "inferencing" + cfg.trainer.backend = cfg.inference.backend + cfg.trainer.weight_dtype = dtype + cfg.inference.weight_dtype = dtype + self.device = device - - if ckpt: - self.load_model_from_ckpt( ckpt ) - else: - self.load_model_from_cfg( config ) + self.dtype = cfg.inference.dtype + self.amp = amp - self.model.eval() + self.model_kwargs = {} - self.width = width - self.height = height + def load_model( self ): + load_engines.cache_clear() + + self.engines = load_engines(training=False, **self.model_kwargs) + for name, engine in self.engines.items(): + if self.dtype != torch.int8: + engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32) + + self.engines.eval() + self.symmap = get_symmap() + + self.width = 300 + self.height = 80 self.transform = transforms.Compose([ transforms.Resize((self.height, self.width)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) + + _logger.info("Loaded model") - self.loading = False + @torch.inference_mode() + def inference( self, image, temperature=1.0 ): + model = None - def load_model_from_ckpt( self, ckpt ): - self.ckpt = ckpt - - self.model = torch.load(self.ckpt).to(self.device) - - def load_model_from_cfg( self, config_path ): - - models = load_models() - for name in models: - model = models[name] - self.model = model.to(self.device) + for name, engine in self.engines.items(): + model = engine.module break - def inference( self, image, temperature=1.0 ): - image = self.transform(image).to(self.device) + image = self.transform(image).to(self.device).to(self.dtype) + + with torch.autocast("cuda", 
dtype=self.dtype, enabled=self.amp): + answer = model( image=[image], sampling_temperature=temperature ) + + answer = [ "".join(answer) ] - answer = self.model( image=[image], sampling_temperature=temperature ) answer = answer[0].replace('', "").replace("", "") # it would be better to just slice between these, but I can't be assed return answer \ No newline at end of file diff --git a/image_classifier/models/__init__.py b/image_classifier/models/__init__.py old mode 100755 new mode 100644 index acfca70..49f5f1b --- a/image_classifier/models/__init__.py +++ b/image_classifier/models/__init__.py @@ -1,18 +1,19 @@ from .base import Model -def get_model(cfg): +def get_model(cfg, training=False): name = cfg.name model = Model( n_tokens=cfg.tokens, n_len=cfg.len, d_model=cfg.dim, + d_resnet=cfg.resnet, ) - model._cfg = cfg + model.config = cfg print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") return model -def get_models(models): - return { model.full_name: get_model(model) for model in models } +def get_models(models, training=False): + return { model.full_name: get_model(model, training=training) for model in models } diff --git a/image_classifier/models/base.py b/image_classifier/models/base.py index 99b0322..a1af8c0 100755 --- a/image_classifier/models/base.py +++ b/image_classifier/models/base.py @@ -12,7 +12,7 @@ from torch.distributions import Categorical from torch.nn.utils.rnn import pad_sequence from torch.utils.checkpoint import checkpoint from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision -from torchvision.models import resnet18 +from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152 from ..data import get_symmap @@ -20,12 +20,12 @@ class Model(nn.Module): def __init__( self, n_tokens: int = 0, # number of token types - n_len: int = 6, # how long a sequence can be + n_len: int = 12, # how long a sequence can be d_model: int = 512, + d_resnet: int = 18, ): super().__init__() - _symmap = get_symmap() self.symmap = { f'{v}': k for k, v in _symmap.items() } self.symmap['0'] = "" @@ -36,8 +36,26 @@ class Model(nn.Module): self.n_tokens = n_tokens self.n_len = n_len + 2 # start/stop tokens self.d_model = d_model + self.d_resnet = d_resnet - self.resnet = resnet18(pretrained=False) + ResNet = resnet18 + if d_resnet == 18: + print("Using resnet18") + ResNet = resnet18 + elif d_resnet == 34: + print("Using resnet34") + ResNet = resnet34 + elif d_resnet == 50: + print("Using resnet50") + ResNet = resnet50 + elif d_resnet == 101: + print("Using resnet101") + ResNet = resnet101 + elif d_resnet == 152: + print("Using resnet152") + ResNet = resnet152 + + self.resnet = ResNet(pretrained=False) self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len ) self.accuracy_metric = MulticlassAccuracy( @@ -61,33 +79,29 @@ class Model(nn.Module): sampling_temperature: float = 1.0, ): - x_list = torch.stack( image, dim=0 ) - - x = self.resnet( x_list ) - y = x.view(x.size(0), self.n_len, self.n_tokens) + logits = self.resnet( torch.stack( image, dim=0 ) ) + logits = logits.view(logits.size(0), self.n_len, self.n_tokens).permute(1, 0, 2) - # either of these should do, but my VALL-E forward pass uses this, so might as well keep to it - # pred = y.argmax(dim=2) - pred = Categorical(logits=y / sampling_temperature).sample() - - answer = [ "".join([ self.symmap[f'{x.item()}'] for x in t ]) for t in pred ] + pred = logits.argmax(dim=2) if text is not None: - y_list = 
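# --- illustrative sketch, not part of the patch: end-to-end use of the Classifier wrapper
# above, matching its __init__(config=...) and inference(image, temperature=...) signatures.
from PIL import Image
from image_classifier.inference import Classifier

classifier = Classifier(config="./data/config.yaml")  # loads YAML, engines, and weights
image = Image.open("./data/path-to-your-image.png").convert("RGB")
print(classifier.inference(image))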
rearrange(pad_sequence(text), "t b -> b t") - - loss = 0 + labels = rearrange(pad_sequence(text), "t b -> b t").permute(1, 0) + loss = [] for i in range(self.n_len): - if i >= y_list.shape[1]: + if i >= labels.shape[0]: break - loss += F.cross_entropy( y[:, i], y_list[:, i] ) - + loss.append( F.cross_entropy(logits[i], labels[i]) ) + self.loss = dict( - nll=loss + nll = sum( loss ) / len( loss ), ) self.stats = dict( - acc = self.accuracy_metric( pred, y_list ), - precision = self.precision_metric( pred, y_list ), + acc = self.accuracy_metric( pred, labels ), + precision = self.precision_metric( pred, labels ), ) + + + answer = [ "".join([ self.symmap[f'{x.item()}'] for x in t ]) for t in pred ] return answer \ No newline at end of file diff --git a/image_classifier/models/lora.py b/image_classifier/models/lora.py new file mode 100644 index 0000000..87e6d7c --- /dev/null +++ b/image_classifier/models/lora.py @@ -0,0 +1,214 @@ +# Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +from functools import partial +import torch +import torch.nn.functional as F +import torch.nn.utils.parametrize as parametrize + +from transformers.pytorch_utils import Conv1D + +from torch import Tensor, nn + +import math +from typing import Optional, List + +from ..utils import passes_policy + +# LoRA Linear for replacement +# Pros: simple, just needs to reuse the replace_linear and copy weights +# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you) +class LoRALinear(nn.Linear): + def __init__( + self, + + in_features: int, + out_features: int, + bias: bool = True, + + rank: int = 4, + alpha: int = 1, + + dropout: float = 0.1, + merge_weights: bool = False, + **kwargs, + ): + super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs) + + self.rank = rank + self.alpha = alpha + self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x + self.merge_weights = merge_weights + self.merged = False + self.enabled = True + + self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) ) + self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) ) + self.scaling = self.alpha / self.rank + + self.weight.requires_grad = False + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + # super silly but necessary because nn.Linear's constructor calls this + if hasattr(self, 'lora_A'): + nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) ) + nn.init.zeros_( self.lora_B ) + + def train(self, mode: bool = True): + super().train(mode) + + # training, separate lora from base weights + if mode and self.merge_weights and self.merged: + self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling + self.merged = False + + # not training, merge lora to base weights + if not mode and self.merge_weights and not self.merged: + self.weight.data += (self.lora_B @ self.lora_A) * self.scaling + self.merged = True + + def forward(self, x: torch.Tensor): + if not self.merged and self.enabled: + result = F.linear(x, self.weight, bias=self.bias) + result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling + return result + + return F.linear(x, self.weight, bias=self.bias) + + @classmethod + def from_linear( cls, layer, device = None, dtype = None, **kwargs ): + if device is None: + device = layer.weight.device + if dtype is None: + dtype = layer.weight.dtype + return cls( in_features = 
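# --- illustrative sketch, not part of the patch: the classification-head reshape above. The
# ResNet fc is swapped for nn.Linear(d_model, n_tokens * n_len), so one forward pass emits all
# per-position logits, viewed as (batch, n_len, n_tokens) and permuted to (n_len, batch,
# n_tokens) for the per-position cross-entropy. Note that torchvision's resnet18/34 expose
# fc.in_features == 512 while resnet50/101/152 use 2048, so d_model has to match the depth.
import torch

batch, n_len, n_tokens = 2, 8, 40  # n_len here already includes the start/stop tokens
flat = torch.randn(batch, n_len * n_tokens)  # what the replaced fc emits
logits = flat.view(batch, n_len, n_tokens).permute(1, 0, 2)
assert logits.shape == (n_len, batch, n_tokens)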
layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype) + +# Uses parametrization to inject LoRA weights +# Pros: should work with any Linears +# Cons: TBD +class ParameterizedLoRA(nn.Module): + def __init__( + self, + + in_features: int, + out_features: int, + bias: bool = True, + + rank: int = 4, + alpha: int = 1, + + dropout: float = 0.1, + + device = None, + dtype = None + ): + super().__init__() + self.rank = rank + self.alpha = alpha + self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x + + self.lora_B = nn.Parameter( torch.zeros( (out_features, rank) ) ).to( device=device, dtype=dtype ) + self.lora_A = nn.Parameter( torch.zeros( (rank, in_features) ) ).to( device=device, dtype=dtype ) + self.scaling = self.alpha / self.rank + self.enabled = True + + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) ) + nn.init.zeros_( self.lora_B ) + + def forward(self, x: torch.Tensor): + if self.enabled: + return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling + return x + + @classmethod + def from_linear( cls, layer, device = None, dtype = None, **kwargs ): + if device is None: + device = layer.weight.device + if dtype is None: + dtype = layer.weight.dtype + # swap because we're feeding the output as our input + # M$'s LoRA class arranges things to where this isn't necessary + return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype) + + @classmethod + def from_conv1d( cls, layer, device = None, dtype = None, **kwargs ): + if device is None: + device = layer.weight.device + if dtype is None: + dtype = layer.weight.dtype + + in_channels, out_channels = layer.weight.shape + # swap because we're feeding the output as our input + # M$'s LoRA class arranges things to where this isn't necessary + return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype) + +def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ): + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + + modules = [ k.split('.') for k, m in model.named_modules() if passes_policy( policy, k ) ] + + for *parent, k in modules: + name = '.'.join(parent) + layer = getattr( model.get_submodule(name), k ) + + if isinstance( layer, nn.Linear ): + target = nn.Linear + klass = ParameterizedLoRA if use_parametrize else LoRALinear + replacer = klass.from_linear + elif isinstance( layer, nn.Conv1d ): + target = nn.Conv1d + klass = ParameterizedLoRA if use_parametrize else LoRAConv1d + replacer = klass.from_conv1d + elif isinstance( layer, Conv1D ): + target = Conv1D + klass = ParameterizedLoRA if use_parametrize else LoRAConv1d + replacer = klass.from_conv1d + else: + continue + + replacement = replacer( layer, device=device, dtype=dtype, **kwargs ) + + if use_parametrize: + parametrize.register_parametrization( layer, "weight", replacement ) + else: + setattr( model.get_submodule(name), k, replacement ) + + return enable_lora( model ) + +def enable_lora( model, mode = True ): + for name, module in model.named_modules(): + if not isinstance( module, ParameterizedLoRA ) and not isinstance( module, LoRALinear ): + continue + module.enabled = mode + return model + +def disable_lora( model ): + return enable_lora( 
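# --- illustrative sketch, not part of the patch: the low-rank update LoRALinear's forward
# computes above: the frozen base projection plus scaling * (x @ A^T @ B^T), with
# scaling = alpha / rank. Because lora_B starts at zero, a freshly applied adapter is a no-op.
import torch

x = torch.randn(4, 16)
W = torch.randn(8, 16)  # frozen base weight (out_features, in_features)
A = torch.randn(2, 16)  # lora_A (rank, in_features)
B = torch.zeros(8, 2)   # lora_B (out_features, rank), zero-initialized
scaling = 1 / 2         # alpha / rank

y = x @ W.t() + (x @ A.t() @ B.t()) * scaling
assert torch.allclose(y, x @ W.t())  # exact no-op until lora_B is trained away from zero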
model, False ) + +def freeze_non_lora_weights( model, embeddings = False ): + frozen_params = [] + + for name, param in model.named_parameters(): + should = 'lora_' in name or (embeddings and "_emb" in name) + + param.requires_grad_(should) + + if not should: + frozen_params.append( param ) + + return frozen_params + +def lora_get_state_dict( state_dict, split = True ): + lora = { name: param for name, param in state_dict.items() if "lora_" in name } + if not split: + return lora + + return lora, { name: param for name, param in state_dict.items() if "lora_" not in name } + +def lora_load_state_dict( model, state_dict ): + return model.load_state_dict( state_dict, strict = False ) \ No newline at end of file diff --git a/image_classifier/plot.py b/image_classifier/plot.py new file mode 100644 index 0000000..4eef13f --- /dev/null +++ b/image_classifier/plot.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 + +import argparse +import json +import re +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd + +from .config import cfg + +def plot(paths, args): + dfs = [] + + for path in paths: + with open(path, "r") as f: + text = f.read() + + rows = [] + + pattern = r"(\{.+?\})\.\n" + + for row in re.findall(pattern, text, re.DOTALL): + try: + row = json.loads(row) + except Exception as e: + continue + + for model in args.models: + if f'{model.name}.{args.xs}' not in row: + continue + rows.append(row) + break + + df = pd.DataFrame(rows) + + if "name" in df: + df["name"] = df["name"].fillna("train") + else: + df["name"] = "train" + + df["group"] = str(path.parents[args.group_level]) + df["group"] = df["group"] + "/" + df["name"] + + dfs.append(df) + + df = pd.concat(dfs) + + if args.min_x is not None: + for model in args.models: + df = df[args.min_x < df[f'{model.name}.{args.xs}']] + + if args.max_x is not None: + for model in args.models: + df = df[df[f'{model.name}.{args.xs}'] < args.max_x] + + for gtag, gdf in sorted( + df.groupby("group"), + key=lambda p: (p[0].split("/")[-1], p[0]), + ): + for model in args.models: + x = f'{model.name}.{args.xs}' + for ys in args.ys: + y = f'{model.name}.{ys}' + + if gdf[y].isna().all(): + continue + + if args.min_y is not None: + gdf = gdf[args.min_y < gdf[y]] + if args.max_y is not None: + gdf = gdf[gdf[y] < args.max_y] + + gdf[y] = gdf[y].ewm(10).mean() + + gdf.plot( + x=x, + y=y, + label=f"{y}", + ax=plt.gca(), + marker="x" if len(gdf) < 100 else None, + alpha=0.7, + ) + + plt.gca().legend( + loc="center left", + fancybox=True, + shadow=True, + bbox_to_anchor=(1.04, 0.5), + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--xs", default="engine_step") + parser.add_argument("--ys", nargs="+", default="") + parser.add_argument("--model", nargs="+", default="*") + + parser.add_argument("--min-x", type=float, default=-float("inf")) + parser.add_argument("--min-y", type=float, default=-float("inf")) + parser.add_argument("--max-x", type=float, default=float("inf")) + parser.add_argument("--max-y", type=float, default=float("inf")) + + parser.add_argument("--filename", default="log.txt") + parser.add_argument("--group-level", default=1) + args, unknown = parser.parse_known_args() + + path = cfg.rel_path / "logs" + paths = path.rglob(f"./*/{args.filename}") + + args.models = [ model for model in cfg.model.get() if model.training and (args.model == "*" or model.name in args.model) ] + + if args.ys == "": + args.ys = ["loss"] + + plot(paths, args) + + out_path = cfg.rel_path / "metrics.png" + 
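# --- illustrative sketch, not part of the patch: the "lora_" key split lora_get_state_dict
# performs above, which is what lets adapters be saved and shipped separately from the base.
state_dict = {
    "resnet.fc.weight": "w",  # stand-ins for tensors
    "resnet.fc.lora_A": "a",
    "resnet.fc.lora_B": "b",
}
lora = {k: v for k, v in state_dict.items() if "lora_" in k}
base = {k: v for k, v in state_dict.items() if "lora_" not in k}
assert set(lora) == {"resnet.fc.lora_A", "resnet.fc.lora_B"}
assert set(base) == {"resnet.fc.weight"}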
plt.savefig(out_path, bbox_inches="tight")
\ No newline at end of file
diff --git a/image_classifier/samplers.py b/image_classifier/samplers.py
new file mode 100644
index 0000000..74ef9f0
--- /dev/null
+++ b/image_classifier/samplers.py
@@ -0,0 +1,204 @@
+import math
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+from torch import Tensor, einsum, nn
+
+# Simple filter that penalizes a token's logits if the token already appears in the sequence
+# `one_time` will only apply the penalty once per unique token
+# `decay` is a factor that exponentially weakens the penalty with distance
+def repetition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ):
+	if factor == 1.0 or previous is None:
+		return logits
+
+	unique = set()
+	priors = reversed(previous)
+	for distance, token in enumerate(priors):
+		# skip if we're only applying the penalty once
+		if one_time and token in unique:
+			continue
+
+		distance += 1
+		logits[:, token] /= factor * (distance ** decay)
+
+		# add to set if we care about it
+		if one_time:
+			unique.add(token)
+
+	return logits
+
+# Simple "filter" that modifies the logit for the stop token, based on the current sequence length
+# `length` is the current length of the sequence
+# `factor` is the power the length is raised to: values > 0 will yield longer sequences, values < 0 will yield shorter sequences
+# `token` is the stop token.
+def length_penalize( logits, length, factor=0.0, token=-1 ):
+	if factor == 0.0:
+		return logits
+
+	logits[:, token] /= (length ** factor)
+	return logits
+
+# Simple way to ban tokens
+def ban_tokens( logits, tokens ):
+	for token in tokens:
+		# skip tokens that aren't in the logits
+		if token >= logits.shape[-1]:
+			continue
+		logits[:, token] = -float("inf")
+	return logits
+
+# Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
+	"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+	Args:
+		logits: logits distribution shape (batch size, vocabulary size)
+		if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+		if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+			Nucleus filtering is described in Holtzman et al.
(http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens per batch example in the output + """ + if top_k > 0: + top_k = min(max(top_k, min_tokens), logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens > 1: + # Keep at least min_tokens (set to min_tokens-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = filter_value + + return logits + +# credit to https://github.com/LostRuins/koboldcpp/pull/464 // https://github.com/kalomaze/koboldcpp/tree/dynamic-temp +def dynamic_temperature( logits, temperature=1.0, min_temperature = 0.0, k = 10, sigmoidCenterPoint = 0.5 ): + # loop over logits[:], as the NAR will have logits.shape[0] > 1 + for i in range(logits.shape[0]): + sum_exp = 0.0 + maximum = torch.max( logits[i] ) + for logit in logits[i]: + sum_exp += math.exp( logit - maximum ) + + prob_max_token_before_temp = 1.0 / sum_exp + dynamic_temperature = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (prob_max_token_before_temp - sigmoidCenterPoint))) + + logits[i] /= dynamic_temperature + + return logits + + + +# picks the top K tokens amongst a batch of logits +# logits: [Tensor] list of logits +# candidates: [(batch, token)] list, where batch indicates the index of the logits the given token is from +def top_k_logits_list( logits_list, k ): + # ( batch, tokens ) => ( batch x tokens ) + logits = torch.cat( logits_list ) + candidates = list(torch.topk(logits.flatten(), k).indices.tolist()) # perform top-k across all logits + for i, index in enumerate(candidates): + t = [] + N = np.prod(logits.size()) + for n in logits.size(): + N //= n + t.append(index // N) + index %= N + candidates[i] = tuple(t) + return candidates + + +# Credit to: https://github.com/basusourya/mirostat/ +# performs mirostat-based sampling +# logits: Tensor of logit probabilities +# state: the mirostat state +def mirostat_sample( logits, state = None ): + def compute_k(prob, n, tau): + num = 0 + den = 0 + for i in range(100): + b = prob[i]/prob[i+1] + t = (i+2)/(i+1) + num += math.log(b)*math.log(t) + den += math.log(t)**2 + + s = num/den + eps = s-1 + k = ((eps*(2**(tau)))/(1-n**(-eps)))**(1/s) + k = round(k) + return k + + if "max_surprise" not in state: + state["max_surprise"] = state["tau"] * 2 + + if "error_surprise" not in state: + state["error_surprise"] = 0 + + if "running_total_surprise" not in state: + state["running_total_surprise"] = 0 + + sorted_logits, sorted_indices = torch.sort( logits[-1, :], descending=True ) + prob_original = torch.softmax( sorted_logits, dim=-1 ).tolist() + + k = compute_k(prob_original, state["n"], state["max_surprise"]) + 1 + + sorted_logits = 
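# --- illustrative sketch, not part of the patch: composing the filters above into a single
# sampling step: temperature first, then top-k/top-p masking, then a multinomial draw.
import torch
import torch.nn.functional as F
from image_classifier.samplers import top_k_top_p_filtering

logits = torch.randn(1, 40)  # (batch, vocab)
logits = top_k_top_p_filtering(logits / 0.8, top_k=10, top_p=0.9)
token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)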
sorted_logits[0:k] + sorted_indices = sorted_indices[0:k] + prob_topk = torch.softmax(sorted_logits, dim = 0) + prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True) + + state["index_surprise"] = math.log2(1/prob_original[prev_i]) + state["running_total_surprise"] += state["index_surprise"] + state["error_surprise"] = state["index_surprise"] - state["tau"] + state["max_surprise"] -= state["eta"] * state["error_surprise"] + state["token"] = sorted_indices[prev_i] + + return state + +# Credits to: https://github.com/oobabooga/text-generation-webui/pull/5677 +# performs DRY sampling +# * (honestly it looks close to rep pen anyways but what do I know) +# `logits` are the scores used to sample against +# `previous` are the prior tokens to penalize with +# `factor` is the scalar multiplier +# `base` is the base number to raise to the (length - allowed_length)th power +# `allowed_length` limits the range to apply DRY to +def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2 ): + if factor == 0.0 or previous is None: + return logits + + lengths = {} + for i, token in enumerate( previous ): + length = 1 + while length < max(allowed_length, 50): + j = i - length + + # Start of input reached. + if j < 0: + break + + # Start of match reached. + if previous[j] != previous[-length-1]: + break + + length += 1 + + lengths[token] = max(length, lengths[token]) if token in lengths else length + + for token, length in lengths.items(): + if length < allowed_length: + break + logits[:, token] -= factor * base ** (length - allowed_length) + + return logits \ No newline at end of file diff --git a/image_classifier/train.py b/image_classifier/train.py index 70de575..710e02b 100755 --- a/image_classifier/train.py +++ b/image_classifier/train.py @@ -4,7 +4,7 @@ from .config import cfg from .data import create_train_val_dataloader from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc -from .utils.trainer import load_engines +from .utils.distributed import is_global_leader import json import logging @@ -12,53 +12,43 @@ import random import torch import torch.nn.functional as F import traceback +import shutil from collections import defaultdict -from PIL import Image + from tqdm import tqdm +import argparse +from PIL import Image, ImageDraw _logger = logging.getLogger(__name__) + def train_feeder(engine, batch): - engine( image=batch["image"], text=batch["text"] ) + with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp): + batch_size = len(batch["text"]) + engine.current_batch_size = batch_size - losses = engine.gather_attribute("loss") - stat = engine.gather_attribute("stats") + engine( image=batch["image"], text=batch["text"] ) - loss = torch.stack([*losses.values()]).sum() + losses = engine.gather_attribute("loss") + stat = engine.gather_attribute("stats") + + loss = torch.stack([*losses.values()]).sum() stats = {} stats |= {k: v.item() for k, v in losses.items()} stats |= {k: v.item() for k, v in stat.items()} + engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ]) + return loss, stats @torch.inference_mode() def run_eval(engines, eval_name, dl): - engines_stats = { - 'eval': eval_name - } - - model = None - names = [] - for name, engine in engines.items(): - names.append(name) - model = engine - break - - stats = defaultdict(list) stats['loss'] = [] - def process( name, batch, resps_list ): - for path, ref, hyp in zip(batch["path"], batch["text"], hyp): - continue - - for batch in tqdm(dl): - batch: dict = 
to_device(batch, cfg.device) - - res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature ) - + def process( name, batch, res, loss ): for path, ref, hyp in zip(batch["path"], batch["text"], res): hyp = hyp.replace('', "").replace("", "") hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / hyp).with_suffix(".png") @@ -67,36 +57,74 @@ def run_eval(engines, eval_name, dl): image = Image.open(path).convert('RGB') image.save(hyp_path) - losses = engine.gather_attribute("loss") - loss = torch.stack([*losses.values()]).sum().item() - stats['loss'].append(loss) + + processed = 0 + while processed < cfg.evaluation.size: + batch = to_device(next(iter(dl)), cfg.device) + + # limit to eval batch size in the event we somehow have a weird dataloader + for key in batch.keys(): + batch[key] = batch[key][:cfg.evaluation.batch_size] + + processed += len(batch["text"]) + + for name in engines: + engine = engines[name] + + res = engine( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature ) + losses = engine.gather_attribute("loss") + loss = torch.stack([*losses.values()]).sum().item() + + process( name, batch, res, loss ) + stats = {k: sum(v) / len(v) for k, v in stats.items()} - engines_stats.update(flatten_dict({ name: stats })) - - iteration = engines.global_step - engines_stats['it'] = iteration - engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl) + engines_stats = { + f'{name}.{eval_name}': stats, + "it": engines.global_step, + } + #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl) _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.") -def main(): +def train(): + parser = argparse.ArgumentParser("ResNet Image Classifier") + parser.add_argument("--eval", action="store_true", default=None) + args, unknown = parser.parse_known_args() + + # create log folder setup_logging(cfg.log_dir) + # copy config yaml to backup + if cfg.yaml_path is not None and is_global_leader(): + shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" ) train_dl, subtrain_dl, val_dl = create_train_val_dataloader() def eval_fn(engines): + do_gc() + engines.eval() + # wrapped in a try block because it's sometimes prone to breaking try: run_eval(engines, "subtrain", subtrain_dl) run_eval(engines, "val", val_dl) except Exception as e: - print("Error occurred while performing eval:", str(e)) - print(traceback.format_exc()) + _logger.warning(f"Error occurred while performing eval: {str(e)}") + _logger.warning(traceback.format_exc()) + engines.train() do_gc() + if args.eval: + return eval_fn(engines=trainer.load_engines()) + + """ + if cfg.trainer.load_webui: + from .webui import start + start(lock=False) + """ + trainer.train( train_dl=train_dl, train_feeder=train_feeder, @@ -104,4 +132,5 @@ def main(): ) if __name__ == "__main__": - main() \ No newline at end of file + # to-do: for DDP, spawn multiprocess instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"` + train() diff --git a/image_classifier/utils/__init__.py b/image_classifier/utils/__init__.py index 96929f3..d79e335 100755 --- a/image_classifier/utils/__init__.py +++ b/image_classifier/utils/__init__.py @@ -7,4 +7,7 @@ from .utils import ( to_device, tree_map, do_gc, + set_seed, + passes_policy, + get_devices ) \ No newline at end of file diff --git a/image_classifier/utils/distributed.py b/image_classifier/utils/distributed.py 
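One behavioural note from the rewritten entry point: `--eval` short-circuits into a single validation pass (`python3 -m image_classifier.train --yaml='./data/config.yaml' --eval`), and `run_eval` now draws batches until `cfg.evaluation.size` samples have been seen, trimming each draw to the eval batch size. A minimal sketch of that accumulation pattern, with generic names rather than the repo's exact API:

```python
# Sketch of run_eval's accumulation loop; `dl` is any dataloader whose batches
# are dicts containing a "text" key, as in this repo.
def take_eval_batches(dl, size, batch_size):
    processed = 0
    while processed < size:
        batch = next(iter(dl))  # note: a fresh iterator per draw, as in the hunk above
        batch = {k: v[:batch_size] for k, v in batch.items()}
        processed += len(batch["text"])
        yield batch
```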
index e80b0dd..2400fdd 100755 --- a/image_classifier/utils/distributed.py +++ b/image_classifier/utils/distributed.py @@ -8,6 +8,10 @@ import socket from functools import cache, wraps from typing import Callable +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + def get_free_port(): sock = socket.socket() sock.bind(("", 0)) @@ -15,13 +19,18 @@ def get_free_port(): _distributed_initialized = False -def init_distributed( fn ): - fn() +def init_distributed( fn, *args, **kwargs ): + torch.cuda.set_device(local_rank()) + fn(*args, **kwargs) _distributed_initialized = True def distributed_initialized(): return _distributed_initialized +def cleanup_distributed(): + dist.barrier() + dist.destroy_process_group() + @cache def fix_unset_envs(): envs = dict( @@ -44,10 +53,12 @@ def fix_unset_envs(): def local_rank(): return int(os.getenv("LOCAL_RANK", 0)) - def global_rank(): return int(os.getenv("RANK", 0)) +def world_size(): + return int(os.getenv("WORLD_SIZE", 1)) + def is_local_leader(): return local_rank() == 0 @@ -86,4 +97,7 @@ def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable: if fn is None: return wrapper - return wrapper(fn) \ No newline at end of file + return wrapper(fn) + +def ddp_model(model): + return DDP(model.to(device='cuda'), [local_rank()], find_unused_parameters=True) \ No newline at end of file diff --git a/image_classifier/utils/io.py b/image_classifier/utils/io.py new file mode 100644 index 0000000..afc2033 --- /dev/null +++ b/image_classifier/utils/io.py @@ -0,0 +1,88 @@ +import torch +import json + +from pathlib import Path +from safetensors import safe_open as sft_load +from safetensors.torch import save_file as sft_save + +def coerce_path( path ): + return path if isinstance( path, Path ) else Path(path) + +def pick_path( path, *suffixes ): + suffixes = [*suffixes] + + for suffix in suffixes: + p = path.with_suffix( suffix ) + if p.exists(): + return p + + return path + +def is_dict_of( d, t ): + if not isinstance( d, dict ): + return False + + return all([ isinstance(v, torch.Tensor) for v in d.values() ]) + +# handles converting the usual pth state_dict into just the dict with the tensors + a dict of JSON strings, for safetensors +def state_dict_to_tensor_metadata( data: dict, module_key=None ): + metadata = None + + # is a state_dict, no need to coerce + if is_dict_of( data, torch.Tensor ): + return data, metadata + + # is maybe a dict with a state dict + metadata, coerce it + metadata = {} + target = module_key + if not target: + for k, v in data.items(): + # is a dict of tensors, our target + if is_dict_of( v, torch.Tensor ): + target = k + continue # continue to iterate to grab other metadata + + # not a dict of tensors, put it as metadata + try: + metadata[k] = json.dumps(v) + except Exception as e: + pass + + if not target: + raise Exception(f'Requesting to save safetensors of a state dict, but state dict contains no key of torch.Tensor: {path}') + + return data[target], metadata + +def torch_save( data, path, module_key=None ): + path = coerce_path(path) + ext = path.suffix + + if ext in [".safetensor", ".sft"]: + data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key ) + + return sft_save( data, path, metadata ) + + return torch.save( data, path ) + +def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=True, module_key="module" ): + path = coerce_path(path) + ext = path.suffix + + if ext in [".safetensor", ".sft"]: + state_dict = 
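The new distributed helpers compose into a small per-worker lifecycle. A hedged sketch of one DDP worker using them, assuming the usual `RANK`/`LOCAL_RANK`/`WORLD_SIZE` environment variables come from `torchrun`:

```python
import torch.distributed as dist
from image_classifier.utils.distributed import (
    init_distributed, ddp_model, cleanup_distributed,
)

def worker(model):
    # init_distributed pins cuda:LOCAL_RANK before calling the passed-in fn
    init_distributed(dist.init_process_group, backend="nccl")
    model = ddp_model(model)  # DDP on the local rank, find_unused_parameters=True
    # ... training steps ...
    cleanup_distributed()     # barrier, then destroy_process_group
```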
{} + with sft_load(path, framework=framework, device=device) as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + + if load_metadata: + metadata = f.metadata() + for k, v in metadata.items(): + try: + metadata[k] = json.loads( v ) + except Exception as e: + pass + state_dict = { module_key: state_dict } | metadata + + return state_dict + + return torch.load( path, map_location=torch.device(device), weights_only=not unsafe ) \ No newline at end of file diff --git a/image_classifier/utils/sampler.py b/image_classifier/utils/sampler.py old mode 100755 new mode 100644 index 5db9606..d5b6b76 --- a/image_classifier/utils/sampler.py +++ b/image_classifier/utils/sampler.py @@ -1,48 +1,164 @@ -""" -A sampler that balances data by key_fns. - -MIT License - -Copyright (c) 2023 Zhe Niu - -niuzhe.nz@outlook.com -""" - -import random - - -class Sampler: - def __init__(self, l, key_fns): - self.tree = self._build(l, key_fns) - - def _build(self, l, key_fns) -> dict[dict, list]: - if not key_fns: - return l - - tree = {} - - key_fn, *key_fns = key_fns - - for x in l: - k = key_fn(x) - - if k in tree: - tree[k].append(x) - else: - tree[k] = [x] - - for k in tree: - tree[k] = self._build(tree[k], key_fns) - - return tree - - def _sample(self, tree: dict | list): - if isinstance(tree, list): - ret = random.choice(tree) - else: - key = random.choice([*tree.keys()]) - ret = self._sample(tree[key]) - return ret - - def sample(self): - return self._sample(self.tree) \ No newline at end of file +from dataclasses import dataclass +from typing import Any +import random + +import torch +from torch.utils.data import Sampler + +from .distributed import global_rank, local_rank, world_size + +# Randomly picks an index from an array of indices +class PoolSampler(): + def __init__( self, pool = [], keep_all = False, shuffle = False ): + self.length = len(pool) + self.shuffle = shuffle + self.global_pool = pool if keep_all else None + self.global_indices = [ i for i in range(self.length) ] + self.reset() + + def reset(self): + self.current_pool = [ i for i in self.global_indices ] + if self.shuffle: + random.shuffle(self.current_pool) + + def sample(self, pool = None): + if pool is None: + pool = self.global_pool + # check if we need to reset + index = random.choice( self.current_pool ) + # remove from pool + self.current_pool.remove(index) + # reset if needed + if len(self.current_pool) == 0: + self.reset() + # map indices to our real values + return pool[index] if pool is not None else index + + def __len__(self): + return self.length # len(self.current_pool) + + def __iter__(self): + while len(self.current_pool) > 0: + yield self.sample() + + def __call__(self, *args, **kwargs): + return self.sample(*args, **kwargs) + + def get_state(self): + return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool } + + def set_state(self, state): + self.length = state["length"] + self.global_pool = state["global_pool"] + self.global_indices = state["global_indices"] + self.current_pool = state["current_pool"] + +# "Samples" through a fixed sequence from 0 to length +# Necessary for our "shuffle+sort by duration+interleave" sampling method +# Allows saving and loading state +class OrderedSampler(Sampler): + def __init__( self, length ): + self.position = 0 + self.length = length + + def __len__(self): + return self.length + + def __iter__(self): + if self.position >= self.length: + self.position = 0 + + while self.position < self.length: + yield 
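The io helpers above key off the file suffix: `.sft`/`.safetensor` goes through safetensors with non-tensor entries JSON-encoded into metadata, anything else falls back to `torch.save`/`torch.load`. A round-trip sketch, assuming the helpers behave as written; the filename is illustrative:

```python
import torch
from image_classifier.utils.io import torch_save, torch_load

ckpt = {"module": {"weight": torch.randn(4, 4)}, "symmap": {"a": 0}}
torch_save(ckpt, "model.sft")    # tensors saved, "symmap" JSON-encoded as metadata
state = torch_load("model.sft")  # -> {"module": {...}} merged with decoded metadata
assert "module" in state and "symmap" in state
```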
self.position + self.position += 1 + + def get_state(self): + return { "position": self.position, "length": self.length } + + def set_state(self, state): + self.position = state["position"] + self.length = state["length"] + +# Like the above, but will batch based on token count +class BatchedOrderedSampler(Sampler): + def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ): + self.position = 0 + self.batches = [] + self.shuffle = shuffle + + assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0" + + current_batch = [] + current_size = 0 + current_index = 0 + for key, bucket in buckets.items(): + for path, duration in bucket: + # flush + should_flush = False + if max_duration > 0 and current_size + duration > max_duration: + should_flush = True + elif max_batch_size > 0 and len(current_batch) >= max_batch_size: + should_flush = True + + if should_flush and len(current_batch) > 0: + self.batches.append( current_batch ) + current_batch = [] + current_size = 0 + + current_batch.append( current_index ) + current_index += 1 + current_size += duration + + if self.shuffle: + random.shuffle(self.batches) + + def __len__(self): + return len(self.batches) + + def __iter__(self): + if self.position >= len(self.batches): + self.position = 0 + if self.shuffle: + random.shuffle(self.batches) + + while self.position < len(self.batches): + yield self.batches[self.position] + self.position += 1 + + def get_state(self): + return { "position": self.position, "batches": self.batches } + + def set_state(self, state): + self.position = state["position"] + self.batches = state["batches"] + +# Randomly samples indices from a given sequence from 0 to length +# Allows saving and loading state +class RandomSampler(Sampler): + def __init__( self, length ): + self.position = 0 + self.length = length + + self.generator = torch.Generator() + self.perm = torch.randperm(self.length, generator=self.generator) + + def __len__(self): + return self.length + + def __iter__(self): + if self.position >= self.length: + self.position = 0 + self.perm = torch.randperm(self.length, generator=self.generator) + + while self.position < self.length: + yield self.perm[self.position] + self.position += 1 + + def get_state(self): + return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() } + + def set_state(self, state): + self.position = state["position"] + self.length = state["length"] + self.perm = state["perm"] + self.generator.set_state(state["generator"]) \ No newline at end of file diff --git a/image_classifier/utils/trainer.py b/image_classifier/utils/trainer.py index 88f6358..7c64470 100755 --- a/image_classifier/utils/trainer.py +++ b/image_classifier/utils/trainer.py @@ -4,12 +4,13 @@ import humanize import json -import os import logging +import numpy as np import random import selectors import sys import torch +import os from functools import cache from torch.distributed import broadcast_object_list @@ -18,9 +19,10 @@ from tqdm import tqdm from typing import Protocol from ..config import cfg -from .distributed import init_distributed, distributed_initialized from .distributed import ( - fix_unset_envs, + init_distributed, + distributed_initialized, + world_size, global_leader_only, global_rank, is_global_leader, @@ -28,73 +30,15 @@ from .distributed import ( local_leader_only, ) -from ..engines import Engine, Engines, TrainFeeder, default_feeder -from ..models import get_models +from ..engines import 
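A toy run of `BatchedOrderedSampler` above, assuming it is imported from `image_classifier.utils.sampler`; buckets map a key to `(path, duration)` pairs, and a batch flushes once adding another item would blow the duration budget:

```python
from image_classifier.utils.sampler import BatchedOrderedSampler

buckets = {"short": [("a.png", 2), ("b.png", 2), ("c.png", 2)]}
# note: the assert as written requires *both* limits to be nonzero,
# despite the message saying they only cannot both be 0
sampler = BatchedOrderedSampler(buckets, max_duration=4, max_batch_size=8)
print(list(sampler))  # [[0, 1]] -- as written, a trailing partial batch is dropped
```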
Engine, Engines, TrainFeeder, default_feeder, load_engines -from .utils import to_device, do_gc +from .utils import to_device, do_gc, truncate_json from ..utils import wrapper as ml +from ..data import get_symmap # should decouple from this trainer script _logger = logging.getLogger(__name__) -_engines: Engines _command: str -def get_global_step(): - try: - return _engines.global_step - except: - return None - -def get_micro_step(): - try: - return _engines.micro_step - except: - return None - -def get_cmd(): - try: - return _command - except: - raise RuntimeError("Trainer has not been setup. Have you called trainer.train?") - - -get_iteration = get_global_step - -def load_engines(): - models = get_models(cfg.models.get()) - engines = dict() - - for name in models: - model = models[name] - - optimizer = None - lr_scheduler = None - - if cfg.hyperparameters.optimizer.lower() == "adamw": - optimizer = ml.AdamW( - model.parameters(), - lr=cfg.hyperparameters.learning_rate, - betas=(0.9, 0.96), - eps=1e-07, - weight_decay=0.01, - ) - - if cfg.trainer.load_state_dict: - load_path = cfg.ckpt_dir / name / "fp32.pth" - model.load_state_dict(torch.load(load_path)) - - engines[name] = Engine( - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - ) - - engines = Engines(engines) - engines.setup() - - if not cfg.trainer.load_state_dict: - engines.load_checkpoint() - - return engines class EvalFn(Protocol): def __call__(self, *, engines: Engines): @@ -151,17 +95,16 @@ def _non_blocking_input(): l[0] = s - if distributed_initialized(): + if world_size() > 1: broadcast_object_list(l, src=0) _command = l[0] return _command - def _make_infinite_epochs(dl): while True: - _logger.info("New epoch starts.") - yield from tqdm(dl, "Epoch progress", dynamic_ncols=True) + #_logger.info("New epoch starts.") + yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) @local_leader_only(default=None) @@ -172,30 +115,32 @@ def logger(data): def seed(seed): # Set up random seeds, after fork() random.seed(seed + global_rank()) - #np.random.seed(seed + global_rank()) + np.random.seed(seed + global_rank()) torch.manual_seed(seed + global_rank()) - def train( train_dl: DataLoader, train_feeder: TrainFeeder = default_feeder, eval_fn: EvalFn = lambda x: ..., logger: Logger = logger, ): - fix_unset_envs() - engines = load_engines() + # validate if there's at least one model to train + found = False + for name, engine in engines.items(): + if engine.training: + found = True + break + if not found: + raise Exception('Training, but no model loaded set to train...') + """ if is_local_leader(): cfg.dump() _logger.info(cfg) """ - # Setup global engines - global _engines - _engines = engines - events = [] eval_fn = global_leader_only(eval_fn) @@ -203,15 +148,20 @@ def train( # Pre-loop command command = _non_blocking_input() if command in ["eval", "eval_quit"]: - engines.eval() eval_fn(engines=engines) - engines.train() + if command in ["quit", "eval_quit"]: + engines.quit() return last_save_step = engines.global_step last_eval_step = 0 + """ + if cfg.distributed: + train_dl.sampler.set_epoch(int(engines.global_samples / len(train_dl.dataset.paths))) + """ + # Training loop for batch in _make_infinite_epochs(train_dl): if engines.global_step >= cfg.trainer.iterations: @@ -219,17 +169,15 @@ def train( #batch = to_device(batch, torch.cuda.current_device()) stats = engines.step(batch=batch, feeder=train_feeder) - - iteration = stats['global_step'] # * 
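The trainer keeps its interactive console: typing `save`, `export`, `eval`, `eval_quit`, `quit`, or `lr <rate>` mid-run is picked up between steps, and with multiple ranks the leader broadcasts the command. A simplified stand-in for the non-blocking stdin poll, not the exact helper:

```python
import selectors
import sys

# Returns one pending console line, or "" when nothing was typed; the real
# _non_blocking_input additionally broadcasts the result when world_size() > 1.
def poll_command(timeout=0.0):
    sel = selectors.DefaultSelector()
    sel.register(sys.stdin, selectors.EVENT_READ)
    events = sel.select(timeout)
    sel.close()
    return events[0][0].fileobj.readline().strip() if events else ""
```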
cfg.hyperparameters.gradient_accumulation_steps - stats['it'] = iteration - stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl) - - del stats['batch_size'] - del stats['wall_time'] - del stats['global_step'] + stats['epoch'] = engines.global_samples / (len(train_dl.dataset.paths) * world_size()) elapsed_time = stats.get("elapsed_time", 0) - _logger.info(f"Training Metrics: {json.dumps(stats)}.") + try: + metrics = json.dumps(stats) + except Exception as e: + metrics = str(stats) + + _logger.info(f"Training Metrics: {truncate_json(metrics)}.") command = _non_blocking_input() @@ -267,29 +215,48 @@ def train( if "lr" in command: rate = float(command.split(" ")[-1]) - engines.set_lr(rate) - print("Updating LR to:", rate) + try: + engines.set_lr(rate) + _logger.info(f"Updating LR to: {rate}") + except Exception as e: + _logger.warning(f"Failed to set LR rate to: {rate}, {str(e)}") + + if "export" in command: + train_dl.dataset.save_state_dict() + engines.save_checkpoint() + last_save_step = engines.global_step + + if is_global_leader(): + engines.export(userdata={"symmap": get_symmap()}) save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency saving_commands = ["save"] + export_commands = ["export"] if cfg.trainer.save_on_quit: saving_commands.append("quit") + if cfg.trainer.export_on_quit: + export_commands.append("quit") + + if cfg.trainer.export_on_save: + export_commands.append("save") + if engines.global_step != last_save_step: if engines.global_step % save_ckpt_every == 0 or command in saving_commands: + train_dl.dataset.save_state_dict() engines.save_checkpoint() last_save_step = engines.global_step + + if command in export_commands and is_global_leader(): + engines.export(userdata={"symmap": get_symmap()}) if engines.global_step != last_eval_step: if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]: - do_gc() - - engines.eval() - eval_fn(engines=engines) - engines.train() last_eval_step = engines.global_step + eval_fn(engines=engines) if command in ["quit"]: - return \ No newline at end of file + engines.quit() + return diff --git a/image_classifier/utils/utils.py b/image_classifier/utils/utils.py index 988f595..a253848 100755 --- a/image_classifier/utils/utils.py +++ b/image_classifier/utils/utils.py @@ -7,8 +7,16 @@ from .distributed import global_rank, local_rank, global_leader_only import gc import logging import pandas as pd +import numpy as np import re import torch +import random +import time +import psutil +import math +import logging + +_logger = logging.getLogger(__name__) from coloredlogs import ColoredFormatter from logging import StreamHandler @@ -16,9 +24,16 @@ from pathlib import Path from torch import Tensor, nn from tqdm.auto import tqdm from typing import Callable, TypeVar, overload - +from contextlib import contextmanager T = TypeVar("T") +def truncate_json( str ): + + def fun( match ): + return "{:.4f}".format(float(match.group())) + + return re.sub(r"\d+\.\d{8,}", fun, str) + def do_gc(): gc.collect() torch.cuda.empty_cache() @@ -28,6 +43,14 @@ def flatten_dict(d): return records[0] if records else {} +def set_seed(seed=None): + if not seed: + seed = int(time.time()) + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + def _get_named_modules(module, attrname): for name, module in module.named_modules(): if hasattr(module, attrname): @@ -155,5 +178,363 @@ def tree_map(fn: Callable, x): return x -def to_device(x: T, device) -> T: - return tree_map(lambda t: 
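A behaviour check for `truncate_json` above: floats with eight or more decimal places get reformatted to four places before the metrics line is logged. Reproduced standalone (parameter renamed, since the original shadows the `str` builtin):

```python
import re

def truncate_json(s):  # same regex and format as the helper above
    return re.sub(r"\d+\.\d{8,}", lambda m: "{:.4f}".format(float(m.group())), s)

print(truncate_json('{"loss": 2.345678901234}'))  # {"loss": 2.3457}
```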
t.to(device), x) +def to_device(x: T | None, *args, **kwargs) -> T: + if x is None: + return + + return tree_map(lambda t: t.to(*args, **kwargs), x) + +def coalese( *arg, return_last=True ): + return [ x for x in arg if x is not None ][-1 if return_last else 0] + +# checks if a module name is within a given whitelist/blacklist policy dict +def passes_policy( policy, name ): + if policy is None: + return True + + if "exclude" in policy: + for term in policy["exclude"]: + if term in name: + return False + + if "include" in policy: + for term in policy["include"]: + if term in name: + return True + + return False + +# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16) +@contextmanager +def autocast(input, from_dtype, to_dtype): + if input.dtype == from_dtype: + input = input.to(to_dtype) + yield input + input = input.to(from_dtype) + else: + yield input + +@contextmanager +def autocasts(input, from_dtype, to_dtype): + if input.dtype in from_dtype: + from_dtype = input.dtype + input = input.to(to_dtype) + yield input + input = input.to(from_dtype) + else: + yield input + +# handles temporarily upcasting 'index tensors' so torch will stop bitching +def autocast_forward( func ): + def wrapper( self, input, *args, **kwargs ): + with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k: + return func( self, k, *args, **kwargs ) + return wrapper + +# handles migrating an input tensor to a given devicve +def auto_align_inputs_forward( module, device=None, name = None ): + func = module.forward + + if device is None: + if hasattr( module, 'device' ): + device = module.device + else: + try: + device = next(module.parameters() if [*module.parameters()] else module.buffers()).device + except Exception as e: + return func + + + def wrapper( *args, **kwargs ): + args = [*args] + # search through args and kwargs for any Tensor arguments + for i, arg in enumerate(args): + if not isinstance( arg, torch.Tensor ): + continue + args[i] = arg.to( device=device ) + + for k, v in kwargs.items(): + if not isinstance( v, torch.Tensor ): + continue + kwargs[k] = v.to( device=device ) + + # disgusting patch + if "position_embeddings" in kwargs: + kwargs["position_embeddings"] = tuple([ t.to(device=device) for t in kwargs["position_embeddings"] ]) + + return func( *args, **kwargs ) + return wrapper + +# disgusting kludge, but it works (just realized BitNet has its own replacement routine) +# generalizing this would be super sugoi but the there's no catch all for arguments +def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ): + bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet + + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)] + + for *parent, k in modules: + name = '.'.join(parent) + + m = getattr( model.get_submodule(name), k ) + + if isinstance(m, klass): + continue + + kwargs = dict( + in_features = m.in_features, + out_features = m.out_features, + bias = m.bias is not None, + ) if not bnb else dict( + input_features=m.in_features, + output_features=m.out_features, + bias=m.bias is not None, + ) + + # overwrite + setattr( + model.get_submodule(name), k, + klass( **kwargs ).to(device=device, dtype=dtype) + ) + + if verbose: + _logger.info(f"Replacing {name}.{k} to: {klass}") + + return model + +def 
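`passes_policy` above resolves include/exclude dicts with exclusions winning first, inclusions second, and rejection as the default. A quick demonstration via the re-export added to `image_classifier/utils/__init__.py`:

```python
from image_classifier.utils import passes_policy

policy = {"include": ["model"], "exclude": ["embed"]}
print(passes_policy(policy, "model.layers.0"))   # True  (matches an include term)
print(passes_policy(policy, "model.embed"))      # False (exclude is checked first)
print(passes_policy(policy, "classifier.head"))  # False (no include match, default reject)
```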
replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ): + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)] + + for *parent, k in modules: + name = '.'.join(parent) + + m = getattr( model.get_submodule(name), k ) + + if isinstance(m, klass): + continue + + kwargs = dict( + num_embeddings=m.num_embeddings, + embedding_dim=m.embedding_dim, + padding_idx=m.padding_idx, + max_norm=m.max_norm, + norm_type=m.norm_type, + scale_grad_by_freq=m.scale_grad_by_freq, + sparse=m.sparse, + ) + + # overwrite + setattr( + model.get_submodule(name), k, + klass( **kwargs ).to(device=device, dtype=dtype) + ) + + if verbose: + _logger.info(f"Replacing {name}.{k} to: {klass}") + + return model + +# cannot feasibly do default arguments here sad +def replace_attention( model, klass, target, mode="math", verbose=False ): + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)] + + for *parent, k in modules: + name = '.'.join(parent) + + m = getattr( model.get_submodule(name), k ) + + if isinstance(m, klass): + continue + + kwargs = dict( + config = m.config, + layer_idx = m.layer_idx, + mode = mode, + ) + # overwrite + setattr( + model.get_submodule(name), k, + klass( **kwargs ).to(device=device, dtype=dtype) + ) + + if verbose: + _logger.info(f"Replacing {name}.{k} to: {klass}") + + return model + +# trim/expand a tensor (for example, in a state dict) +def resize_weight( weight, target, dim=0, random=True ): + # trim + if target < weight.shape[dim]: + return weight[:target] + # expand + if target > weight.shape[dim]: + fn = torch.rand if random else torch.zeros + return torch.stack( + [ x for x in weight ] + + [ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ] + ) + + return weight + +def get_devices(): + return [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu'] + +# grabs the memory properties of a given device +def get_device_properties( device ): + if 'cuda' in device: + props = torch.cuda.get_device_properties(device) + free, total = torch.cuda.mem_get_info(device) + else: + props = psutil.virtual_memory() + free, total = props.available, props.total + + return {"name": device, "total": total, "free": free, "props": props} + +# gets the rough size for a given module's parameters +def get_module_size( module ): + param_size = sum([p.nelement() * p.element_size() for p in module.parameters()]) + buffer_size = sum([b.nelement() * b.element_size() for b in module.buffers()]) + return param_size + buffer_size + +# to-do: rewrite all this shit, I don't know what I was thinking when implementing it this way +# it'd be better to just attach to layers itself rather than every single module + +# assigns modules to requested devices for a given policy +def get_model_offload_policy(module, policy=None): + # handle any other weird values this is set to + if not isinstance(policy, dict): + policy = {} + + # default to only include the core model, and not the other modules (embeddings) in the splitting policy + if "include" not in policy: + policy["include"] = ["model"] + + if "limits" not in policy: + policy["limits"] = [] + + if "assign" not in policy: + policy["assign"] = [] + + if "devices" not in policy: + policy["devices"] = get_devices() # + cpu to spill the remainder on CPU if 
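`resize_weight` above trims or pads along dim 0, with padded rows drawn randomly (or zeroed when `random=False`). A toy shape check, assuming a direct import from the module:

```python
import torch
from image_classifier.utils.utils import resize_weight

w = torch.ones(4, 8)
print(resize_weight(w, 2).shape)                # torch.Size([2, 8]) -- trimmed
print(resize_weight(w, 6, random=False).shape)  # torch.Size([6, 8]) -- zero-padded rows
```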
overbudget + + # create initial device info + devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ] + modules = [ (name, get_module_size(module)) for name, module in module.named_modules() if not [*module.named_children()] and passes_policy( policy, name ) ] + # filter + modules = [ (name, size) for name, size in modules if name and size ] + + total_size = sum([size for name, size in modules]) + + # set caps if requested in the policy + for i, cap in enumerate(policy["limits"]): + # no limit, skip + if cap <= 0: + continue + # is fractional, scale to total size + if cap < 1: + cap = math.floor(total_size * cap) + # available space is below cap, don't set + if devices[i]["free"] < cap: + continue + # cap to requested size + devices[i]["free"] = cap + + # assign if specific parts of the model are requested for assignment + if policy["assign"]: + discarded = [] + # yuck, there has to be a better way + for device_index, includes in enumerate( policy["assign"] ): + device = devices[device_index] + + buffered_modules = [] + buffered_size = device["free"] + + # iterate through list of modules to compare against includes + for name, size in modules: + # doesn't pass policy + if not passes_policy( {"include": includes}, name ): + continue + # check if within budget + if buffered_size - size >= 0: + # add to buffer + buffered_modules.append( (name, size) ) + buffered_size -= size + # budget exceeded, flush buffer + else: + discarded += buffered_modules + buffered_modules = [] + buffered_size = 0 + break + + if buffered_modules and buffered_size: + device["modules"] += [ name for name, size in buffered_modules ] + device["free"] = buffered_size + + modules = discarded + + device_index = 0 + module_index = 0 + # assign modules to each device + while module_index < len(modules) and device_index < len(devices): + device = devices[device_index] + name, size = modules[module_index] + + # fits within budget + if device["free"] - size >= 0: + device["modules"].append( name ) + device["free"] -= size + module_index += 1 + # does not fit in budget, increase device index + else: + device_index += 1 + _logger.info(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB") + + # to-do: check that all modules are exhausted + assert module_index >= len(modules) + + # only return devices with modules assigned + return [ device for device in devices if device["modules"] ] + +# handles naively splitting a model's layers across multiple devices +# this apparently works for training too? the test trainer seemed fine with it split between GPU and CPU +def offload_model( model, policy=None ): + policy = get_model_offload_policy(model, policy=policy) + + # move modules to respective devices + for i, device in enumerate( policy ): + # nothing assigned, skip + if not device["modules"]: + continue + + for name in device["modules"]: + module = model.get_submodule(name) + module = module.to( device["name"] ) + module.device = device['name'] + + # wrap modules with forward to ensure all inputs are matched to its device + for name, module in model.named_modules(): + if not hasattr( module, 'forward' ): + continue + + module.forward = auto_align_inputs_forward(module) + + """ + # Validate that the layers are all in the right spot + for name, module in model.named_modules(): + if not not [*module.named_children()]: + continue + try: + _logger.info( name, next(module.parameters()).device ) + except Exception as e: + _logger.info( name, "?" 
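A hypothetical invocation of `offload_model` above, assuming at least one CUDA device is present; the policy values here are illustrative, not recommendations. `limits` are byte caps (fractional values scale against total model size, values at or below zero leave the device uncapped), and an empty include term matches every module name:

```python
import torch
from image_classifier.utils.utils import offload_model

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
model = offload_model(model, policy={
    "devices": ["cuda:0", "cpu"],   # spill order: GPU first, remainder to CPU
    "limits": [6 * 1024 ** 3, 0],   # cap cuda:0 at ~6 GiB; 0 = uncapped
    "include": [""],                # empty term passes every module name
})
```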
) + pass + """ + + return model \ No newline at end of file diff --git a/image_classifier/utils/wrapper.py b/image_classifier/utils/wrapper.py index 040762d..2cbd14b 100755 --- a/image_classifier/utils/wrapper.py +++ b/image_classifier/utils/wrapper.py @@ -1,20 +1,39 @@ from contextlib import contextmanager +import math import torch import torch.nn.functional as F +import logging + from ..config import cfg +_logger = logging.getLogger(__name__) + Embedding = torch.nn.Embedding Linear = torch.nn.Linear -if cfg.bitsandbytes.enabled: - import bitsandbytes as bnb - - if cfg.bitsandbytes.linear: - Linear = bnb.nn.Linear8bitLt +Adam = torch.optim.Adam +AdamW = torch.optim.AdamW +SGD = torch.optim.SGD +Adagrad = torch.optim.Adagrad - if cfg.bitsandbytes.embedding: - Embedding = bnb.nn.StableEmbedding +# https://github.com/kyegomez/BitNet +if cfg.optimizations.bitnet: + from bitnet import BitLinear + +if cfg.optimizations.bitsandbytes: + import bitsandbytes as bnb + + if cfg.optimizations.linear: + + if cfg.optimizations.bitnet: + Linear = BitLinear + else: + Linear = bnb.nn.Linear8bitLt + + if cfg.optimizations.embedding: + Embedding = bnb.nn.modules.Embedding + """ Embedding.forward = lambda self, input: ( self.norm(F.embedding( input, self.weight, @@ -24,52 +43,101 @@ if cfg.bitsandbytes.enabled: self.scale_grad_by_freq, self.sparse, )).to(self.weight.dtype) ) - -Adam = torch.optim.Adam -AdamW = torch.optim.AdamW - -if cfg.bitsandbytes.enabled: - import bitsandbytes as bnb - - Adam = bnb.optim.Adam - AdamW = bnb.optim.AdamW - -# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16) -@contextmanager -def autocast(input, from_dtype, to_dtype): - if input.dtype == from_dtype: - input = input.to(to_dtype) - yield input - input = input.to(from_dtype) - else: - yield input - -@contextmanager -def autocasts(input, from_dtype, to_dtype): - if input.dtype in from_dtype: - from_dtype = input.dtype - input = input.to(to_dtype) - yield input - input = input.to(from_dtype) - else: - yield input - -# handles temporarily upcasting 'index tensors' so torch will stop bitching -def autocast_forward( func ): - def wrapper( self, input, *args, **kwargs ): - with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k: - return func( self, k, *args, **kwargs ) """ - if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8: - return func( self, input.to(torch.int32), *args, **kwargs ) - return func( self, input, *args, **kwargs ) - """ - return wrapper -Embedding.forward = autocast_forward(Embedding.forward) -if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled: - torch.nn.Linear = Linear - torch.nn.Embedding = Embedding + if cfg.optimizations.optimizers: + Adam = bnb.optim.Adam8bit + AdamW = bnb.optim.AdamW8bit + SGD = bnb.optim.SGD8bit + Adagrad = bnb.optim.Adagrad8bit - torch.optim.Adam = Adam - torch.optim.AdamW = AdamW \ No newline at end of file +elif cfg.optimizations.dadaptation: + import dadaptation + + if cfg.optimizations.optimizers: + Adam = dadaptation.DAdaptAdam + AdamW = dadaptation.DAdaptAdam + SGD = dadaptation.DAdaptSGD + AdaGrad = dadaptation.DAdaptAdaGrad + +if cfg.optimizations.fp8: + import transformer_engine.pytorch as te + + Linear = te.Linear + + @contextmanager + def autocast(): + yield te.fp8_autocast(enabled=True) +else: + @contextmanager + def autocast(): + yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp) + +if cfg.optimizations.injects: + if 
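The point of these module-level aliases is that call sites import the wrapper and stay agnostic to which implementation is active (stock, bitsandbytes 8-bit, or D-Adaptation), exactly as the trainer's `from ..utils import wrapper as ml` does:

```python
import torch
from image_classifier.utils import wrapper as ml

model = torch.nn.Linear(4, 2)
# resolves to torch.optim.AdamW, bnb.optim.AdamW8bit, or DAdaptAdam
# depending on cfg.optimizations at import time
optimizer = ml.AdamW(model.parameters(), lr=1.0e-3)
```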
cfg.optimizations.linear: + torch.nn.Linear = Linear + + if cfg.optimizations.embedding: + torch.nn.Embedding = Embedding + + if cfg.optimizations.optimizers: + torch.optim.Adam = Adam + torch.optim.AdamW = AdamW + torch.optim.SGD = SGD + +AVAILABLE_COMPILE_BACKENDS = [] + +try: + AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends() +except Exception as e: + pass + + +if cfg.optimizations.tensorrt: + try: + import torch_tensorrt + AVAILABLE_COMPILE_BACKENDS.append("tensorrt") + except Exception as e: + _logger.warning(f'Error while importing TensorRT: {str(e)}') + pass + +def compile_model(model, backend="auto"): + if not backend or backend == "auto": + backend = AVAILABLE_COMPILE_BACKENDS[0] + + if backend not in AVAILABLE_COMPILE_BACKENDS: + return torch.compile(model) + + return torch.compile(model, backend=backend) + +# https://github.com/konstmish/prodigy +try: + from prodigyopt import Prodigy +except Exception as e: + _logger.warning(f'Error while importing Prodigyopt: {str(e)}') + pass + +# https://github.com/facebookresearch/schedule_free/ +try: + import schedulefree +except Exception as e: + _logger.warning(f'Error while importing Schedule_Free: {str(e)}') + pass + +# backwards compat +from .utils import ( + autocast_forward, + replace_linear as replace_linear_old, + replace_embedding as replace_embedding_old, + replace_attention, + resize_weight, + offload_model, +) + +# wrapped here so we can maintain default args +def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ): + return replace_linear_old( model, klass, target, verbose ) +def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ): + return replace_embedding_old( model, klass, target, verbose ) + +Embedding.forward = autocast_forward(Embedding.forward) \ No newline at end of file diff --git a/image_classifier/webui.py b/image_classifier/webui.py new file mode 100644 index 0000000..4e0ce15 --- /dev/null +++ b/image_classifier/webui.py @@ -0,0 +1,220 @@ +import os +import re +import argparse +import random +import tempfile +import functools +from datetime import datetime + +import gradio as gr + +from time import perf_counter +from pathlib import Path +from PIL import Image + +from .inference import Classifier, cfg +from .train import train +from .utils import get_devices + +classifier = None + +layout = {} +layout["inference"] = {} +layout["training"] = {} +layout["settings"] = {} + +for k in layout.keys(): + layout[k]["inputs"] = { "progress": None } + layout[k]["outputs"] = {} + layout[k]["buttons"] = {} + +# there's got to be a better way to go about this +def gradio_wrapper(inputs): + def decorated(fun): + @functools.wraps(fun) + def wrapped_function(*args, **kwargs): + for i, key in enumerate(inputs): + kwargs[key] = args[i] + try: + return fun(**kwargs) + except Exception as e: + raise gr.Error(str(e)) + return wrapped_function + return decorated + +class timer: + def __init__(self, msg="Elapsed time:"): + self.msg = msg + + def __enter__(self): + self.start = perf_counter() + return self + + def __exit__(self, type, value, traceback): + msg = f'{self.msg} {(perf_counter() - self.start):.3f}s' + + gr.Info(msg) + print(f'[{datetime.now().isoformat()}] {msg}') + +# returns a list of models, assuming the models are placed under ./training/ or ./models/ +def get_model_paths( paths=[Path("./data/"), Path("./training/"), Path("./models/")] ): + yamls = [] + + for path in paths: + if not path.exists(): + continue + + for yaml in path.glob("**/*.yaml"): + if 
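`gradio_wrapper` above re-keys Gradio's positional component values into kwargs by the declared input order, which is what lets the layout dicts drive the handlers. A stripped-down standalone equivalent of that flow:

```python
# Toy version of the re-keying trick; names here are illustrative only.
def rekey(inputs, fun):
    def wrapped(*args):
        return fun(**{key: args[i] for i, key in enumerate(inputs)})
    return wrapped

do_infer = rekey(["image", "temp"], lambda image, temp: (image, temp))
print(do_infer("cat.png", 0.95))  # ('cat.png', 0.95)
```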
"/logs/" in str(yaml): + continue + + yamls.append( yaml ) + + return yamls + +def get_dtypes(): + return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"] + + +#@gradio_wrapper(inputs=layout["settings"]["inputs"].keys()) +def load_model( yaml, device, dtype ): + gr.Info(f"Loading: {yaml}") + try: + init_classifier( yaml=Path(yaml), restart=True, device=device, dtype=dtype ) + except Exception as e: + raise gr.Error(e) + gr.Info(f"Loaded model") + +def init_classifier(yaml=None, restart=False, device="cuda", dtype="auto"): + global classifier + + if classifier is not None: + if not restart: + return classifier + + del classifier + classifier = None + + parser = argparse.ArgumentParser(allow_abbrev=False) + parser.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too + parser.add_argument("--device", type=str, default=device) + parser.add_argument("--amp", action="store_true") + parser.add_argument("--dtype", type=str, default=dtype) + args, unknown = parser.parse_known_args() + + classifier = Classifier( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp ) + return classifier + +@gradio_wrapper(inputs=layout["inference"]["inputs"].keys()) +def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): + if not cfg.yaml_path: + raise Exception("No YAML loaded.") + + parser = argparse.ArgumentParser(allow_abbrev=False) + # I'm very sure I can procedurally generate this list + parser.add_argument("--image", type=str, default=kwargs["image"]) + parser.add_argument("--temp", type=float, default=kwargs["temp"]) + args, unknown = parser.parse_known_args() + + classifier = init_classifier() + + args.image = Image.open(args.image).convert('RGB') + + gr.Info("Inferencing...") + with timer("Inferenced in") as t: + answer = classifier.inference( + image=args.image, + temperature=args.temp, + ) + + return answer + +# setup args +parser = argparse.ArgumentParser(allow_abbrev=False) +parser.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', None)) # os environ so it can be specified in a HuggingFace Space too +parser.add_argument("--listen", default=None, help="Path for Gradio to listen on") +parser.add_argument("--share", action="store_true") +parser.add_argument("--render_markdown", action="store_true", default="CLASSIFIER_YAML" in os.environ) +args, unknown = parser.parse_known_args() + +args.listen_host = None +args.listen_port = None +args.listen_path = None +if args.listen: + try: + match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0] + + args.listen_host = match[0] if match[0] != "" else "127.0.0.1" + args.listen_port = match[1] if match[1] != "" else None + args.listen_path = match[2] if match[2] != "" else "/" + except Exception as e: + pass + +if args.listen_port is not None: + args.listen_port = int(args.listen_port) + if args.listen_port == 0: + args.listen_port = None + +# setup gradio +ui = gr.Blocks() +with ui: + with gr.Tab("Inference"): + with gr.Row(): + with gr.Column(scale=4): + layout["inference"]["inputs"]["image"] = gr.Image(label="Input Image", sources=["upload"], type="filepath") + layout["inference"]["outputs"]["output"] = gr.Textbox(label="Output") + with gr.Column(scale=4): + with gr.Row(): + layout["inference"]["inputs"]["temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature", info="Modifies the 
randomness from the samples. (0 to greedy sample)") + layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference") + + layout["inference"]["buttons"]["inference"].click( + fn=do_inference, + inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None], + outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None] + ) + + """ + with gr.Tab("Training"): + with gr.Row(): + with gr.Column(scale=1): + layout["training"]["outputs"]["console"] = gr.Textbox(lines=8, label="Console Log") + with gr.Row(): + with gr.Column(scale=1): + layout["training"]["buttons"]["train"] = gr.Button(value="Train") + + layout["training"]["buttons"]["train"].click( + fn=do_training, + outputs=[ x for x in layout["training"]["outputs"].values() if x is not None], + ) + """ + + with gr.Tab("Settings"): + with gr.Row(): + with gr.Column(scale=7): + with gr.Row(): + layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model") + layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device") + layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision") + with gr.Column(scale=1): + layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model") + + layout["settings"]["buttons"]["load"].click( + fn=load_model, + inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None], + outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None], + ) + + if os.path.exists("README.md") and args.render_markdown: + md = open("README.md", "r", encoding="utf-8").read() + # remove HF's metadata + if md.startswith("---\n"): + md = "".join(md.split("---")[2:]) + gr.Markdown(md) + +def start( lock=True ): + ui.queue(max_size=8) + ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock) + +if __name__ == "__main__": + start() \ No newline at end of file diff --git a/scripts/run.sh b/scripts/run.sh old mode 100755 new mode 100644 diff --git a/setup.py b/setup.py index c719379..4c55a3b 100755 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ import subprocess - +import sys from pathlib import Path from datetime import datetime from setuptools import setup, find_packages @@ -8,7 +8,6 @@ def shell(*args): out = subprocess.check_output(args) return out.decode("ascii").strip() - def write_version(version_core, pre_release=True): if pre_release: time = shell("git", "log", "-1", "--format=%cd", "--date=iso") @@ -23,8 +22,7 @@ def write_version(version_core, pre_release=True): return version - -with open("README.md", "r", encoding="utf-8") as f: +with open("README.md", "r") as f: long_description = f.read() setup( @@ -37,17 +35,37 @@ setup( long_description=long_description, long_description_content_type="text/markdown", packages=find_packages(), - install_requires=[ + install_requires=( + # training backends + ["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else []) + + [ + # logging niceties "coloredlogs>=15.0.1", + "humanize>=4.4.0", + "matplotlib>=3.6.0", + "pandas>=1.5.0", + + # boiler plate niceties "diskcache>=5.4.0", "einops>=0.6.0", - "omegaconf==2.0.6", - "tqdm>=4.64.1", - "humanize>=4.4.0", + "tqdm", - "pandas>=1.5.0", + # HF bloat + "tokenizers", + "transformers", + "safetensors", + + # training bloat + "h5py", + "prodigyopt @ git+https://github.com/konstmish/prodigy", + + # practically the reason to use python + "numpy", "torch>=1.13.0", - 
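Two small notes on the tail of this commit. The `--listen` argument is parsed as an optional `host:port/path` triple; a quick standalone check of that regex with an illustrative address (the same pattern `webui.py` uses):

```python
import re

match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", "0.0.0.0:7860/ui")[0]
print(match)  # ('0.0.0.0', '7860', '/ui')
```

And `setup.py` now prepends `deepspeed>=0.7.7` only on non-Windows platforms, presumably sidestepping DeepSpeed's rough Windows support, while keeping the rest of the requirement list unconditional.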
"torchmetrics", + "torchmetrics", + + "simple_http_server", + "pillow" ], url="https://git.ecker.tech/mrq/resnet-classifier", )