updated to use the saner training framework that mrq/vall-e uses these days

This commit is contained in:
mrq 2024-09-04 15:48:29 -05:00
parent 5cb28a210e
commit 5610bb3bb3
28 changed files with 3438 additions and 736 deletions

.gitignore (vendored): Executable file → Normal file, 0 changes
LICENSE: Executable file → Normal file, 0 changes
README.md: Executable file → Normal file, 40 changes

@@ -1,10 +1,10 @@
 # Tentative Title For A ResNet-Based Image Classifier
-This is a simple ResNet based image classifier for """specific images""", using a similar training framework I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/).
+This is a simple ResNet based image classifier for images, using a training framework similar to the one I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/).
 ## Premise
-This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
+This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I faced.
 This is by no means state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
@@ -16,44 +16,14 @@
 3. Install using `pip3 install -e ./image_classifier/`.
-4. Train using `python3 -m image_classifier.train yaml='./data/config.yaml'`.
+4. Train using `python3 -m image_classifier.train --yaml='./data/config.yaml'`.
 5. Wait.
 ## Inferencing
-Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
+Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" --yaml="./data/config.yaml"`
 ### Continuous Usage
-If you're looking to continuously classifier trained images, use `python3 -m image_classifier --listen --port=7860 yaml="./data/config.yaml" --temp=1.0` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
+If you're looking to continuously classify images, use `python3 -m image_classifier --listen --port=7860 --yaml="./data/config.yaml"` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
-## Known Issues
-* Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
-* Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
-* Web server doesn't emit `content-type: application/json`, nor accepts JSON `POST`s at the moment.
-## Strawmen
->\> UGH... Why *another* training framework!!! Just subjugate [DLAS](https://git.ecker.tech/mrq/DL-Art-School) even more!!!
-I want my own code to own. The original VALL-E implementation had a rather nice and clean setup that *mostly* just made sense. DLAS was a nightmare to comb through for the gorillion amounts of models it attests.
->\> OK. But how do I use it for `[thing that isn't the specific usecase only I know/care about]`
-Simply provide your own symmapping under `./image_classifier/data.py`, and, be sure to set the delimiter (where exactly is an exercise left to the reader).
-Because this is for a ***very specific*** use-case. I don't really care right now to make this a *little* more generalized, despite most of the bits and bobs for it to generalize being there.
->\> ur `[a slur]` for using a ResNet... why not use `[CRNN / some other meme arch]`??
-I don't care, I'd rather keep the copypasting from other people's code to a minimum. Lazily adapting my phoneme tokenizer from my VALL-E implementation into something practically fixed length by introducing start/stop tokens should be grounds for me to use a CRNN, or anything recurrent at the very least, but again, I don't care, it just works for my use case at the moment.
->\> UGH!!! What are you talking about """specific images"""???
-[ひみつ](https://files.catbox.moe/csuh49.webm)
->\> NOOOO!!!! WHY AREN'T YOU USING `[cuck license]`???
-:)
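
For the continuous-usage endpoint described above, a minimal client sketch (standard library only; it assumes the server was started locally on port 7860 with the `--listen` invocation from the README, and that the response matches the `{"answer": ...}` JSON the handler in `__main__.py` returns):

```python
# Query the --listen webserver described above (sketch; assumes it is running
# locally on port 7860 and was pointed at a trained config via --yaml).
import base64
import json
import urllib.parse
import urllib.request

with open("./data/path-to-your-image.png", "rb") as f:  # placeholder image path
    b64 = base64.b64encode(f.read()).decode("utf-8")

# the image travels as a URL-encoded base64 string in the ?b64= query parameter
url = "http://127.0.0.1:7860/?b64=" + urllib.parse.quote(b64)
with urllib.request.urlopen(url) as response:
    print(json.loads(response.read()))  # e.g. {"answer": "<classified label>"}
```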

data/config.yaml: Executable file → Normal file, 131 changes

@ -1,85 +1,84 @@
dataset: weights_format: sft
training: [
"./data/images/"
]
validation: []
use_hdf5: False
workers: 0
cache: True
models: models:
_models: - name: "classifier"
- name: "classifier" tokens: 0
tokens: 0 len: 6
len: 6 dim: 512
resnet: 34
#loras:
#- name : "lora"
# rank: 128
# alpha: 128
# training: True
# rvq_levels: []
hyperparameters: hyperparameters:
batch_size: 256 batch_size: 256
gradient_accumulation_steps: 64 gradient_accumulation_steps: 1
gradient_clipping: 100 gradient_clipping: 1.0
warmup_steps: 10
optimizer: Prodigy
learning_rate: 1.0
torch_optimizer: True
optimizer: Adamw scheduler: "" # ScheduleFree
learning_rate: 1.0e-3 torch_scheduler: True
scheduler_type: ""
#scheduler_type: OneCycle
#scheduler_params:
# cycle_first_step_size: 10_000
# cycle_first_stair_count: 10_000
# cycle_second_step_size: 15_000
# cycle_second_stair_count: 15_000
# decay_step_size: 5_000
# cycle_min_lr: 2.5e-4 # 1.0e-5
# cycle_max_lr: 2.5e-4 # 1.0e-4
# decay_lr_rate: 0.0
# cycle_min_mom: 0.90
# cycle_max_mom: 0.99
# decay_mom_rate: 0.0
evaluation: evaluation:
batch_size: 32 batch_size: 64
frequency: 250 frequency: 100
size: 32 size: 64
steps: 300 steps: 450
temperature: 1.0 temperature: 0.0
trainer: trainer:
iterations: 100_000 iterations: 1_000_000
save_tag: step
save_on_oom: True
save_on_quit: True
save_frequency: 100 save_frequency: 100
keep_last_checkpoints: 32
aggressive_optimizations: False
check_for_oom: False
#load_tag: "9500"
#load_state_dict: True
#load_states: False
#strict_loading: False
#restart_step_count: True
gc_mode: None # "global_step" check_for_oom: False
gradient_checkpointing: True
weight_dtype: float32 weight_dtype: bfloat16
amp: True
backend: local backend: deepspeed
deepspeed: deepspeed:
zero_optimization_level: 0 inferencing: False
use_compression_training: True amp: False
inference: inference:
use_vocos: True backend: local
bitsandbytes: weight_dtype: bfloat16
enabled: false amp: True
optimizations:
injects: False
replace: True
linear: False
embedding: False
optimizers: True
bitsandbytes: False
dadaptation: False
bitnet: False
fp8: False
dataset:
use_hdf5: True
hdf5_flag: r
workers: 1
cache: True
training: [
"./data/images/"
]
validation: [
"./data/validation/"
]
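
How the new-style config above gets picked up (a sketch, not part of the commit; the module path and the printed attributes are assumptions based on the `config.py` changes later in this diff, and the YAML path is a placeholder):

```python
# Sketch: the config module builds a global `cfg` at import time; the YAML can be
# supplied either as --yaml on the command line or via the VALLE_YAML environment
# variable (both handled by Config.from_cli() shown further down in this commit).
import os
os.environ["VALLE_YAML"] = "./data/config.yaml"  # placeholder path

from image_classifier.config import cfg  # module path assumed

print(cfg.model.name, cfg.hyperparameters.batch_size, cfg.trainer.backend)
```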

image_classifier/__main__.py

@ -12,35 +12,55 @@ def main():
parser = argparse.ArgumentParser(allow_abbrev=False) parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--listen", action='store_true') parser.add_argument("--listen", action='store_true')
parser.add_argument("--port", type=int, default=9090) parser.add_argument("--port", type=int, default=9090)
parser.add_argument("--yaml", type=Path, default=None) parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--ckpt", type=Path, default=None) parser.add_argument("--device", type=str, default=None)
parser.add_argument("--temp", type=float, default=1.0) parser.add_argument("--amp", action="store_true")
parser.add_argument("--device", default="cuda") parser.add_argument("--dtype", type=str, default=None)
parser.add_argument("--temp", type=float, default=0.0)
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
classifier = Classifier( config=args.yaml, ckpt=args.ckpt, device=args.device ) classifier = Classifier( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
if args.listen: if args.listen:
@route("/") @route("/")
def inference( b64, temperature=1.0 ): def inference( b64, temperature=args.temp ):
image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB") image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
return { "answer": classifier.inference( image=image, temperature=args.temp ) } return { "answer": classifier.inference( image=image, temperature=temperature ) }
server.start(port=args.port) server.start(port=args.port)
else: else:
parser = argparse.ArgumentParser(allow_abbrev=False) parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--path", type=Path) parser.add_argument("--path", type=Path)
parser.add_argument("--base64", type=str) parser.add_argument("--base64", type=str)
parser.add_argument("--write", type=Path)
parser.add_argument("--temp", type=float, default=1.0) parser.add_argument("--temp", type=float, default=1.0)
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
images = []
if args.path: if args.path:
image = Image.open(args.path).convert('RGB') if args.path.is_dir():
for p in args.path.rglob("./*.jpg"):
image = Image.open(p).convert('RGB')
images.append(image)
for p in args.path.rglob("./*.png"):
image = Image.open(p).convert('RGB')
images.append(image)
else:
image = Image.open(args.path).convert('RGB')
images.append(image)
elif args.base64: elif args.base64:
image = Image.open(BytesIO(base64.b64decode(args.base64))).convert("RGB") image = Image.open(BytesIO(base64.b64decode(args.base64))).convert("RGB")
images.append(image)
else: else:
raise "Specify a --path or --base64." raise "Specify a --path or --base64."
answer = classifier.inference( image=image, temperature=args.temp ) for image in images:
print("Answer:", answer) answer = classifier.inference( image=image, temperature=args.temp )
print("Answer:", answer)
if args.write:
args.write.mkdir(exist_ok=True)
image.save( args.write / f"{answer}.jpg")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
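
The inferencer above can also be driven programmatically; a rough sketch mirroring what `main()` does (the `Classifier` import path is an assumption, since only its usage is visible in this diff, and the image path is a placeholder):

```python
# Sketch: calling the classifier the same way __main__.py does, without argparse.
from pathlib import Path
from PIL import Image
from image_classifier import Classifier  # import path assumed; only the usage appears above

classifier = Classifier(config=Path("./data/config.yaml"), device=None, dtype=None, amp=False)
image = Image.open("./data/path-to-your-image.png").convert("RGB")
print(classifier.inference(image=image, temperature=1.0))
```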

image_classifier/config.py

@ -6,31 +6,61 @@ import os
import subprocess import subprocess
import sys import sys
import time import time
import argparse
from dataclasses import asdict, dataclass import yaml
from dataclasses import dataclass, field import random
import logging
from functools import cached_property, cache
from pathlib import Path
from omegaconf import OmegaConf
import torch import torch
import numpy as np
from dataclasses import asdict, dataclass, field
from functools import cached_property
from pathlib import Path
from .utils.distributed import world_size
def set_seed(seed=None):
if not seed:
seed = time.time()
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
@dataclass() @dataclass()
class _Config: class BaseConfig:
cfg_path: str | None = None yaml_path: str | None = None # path passed in through --yaml
@property @property
def relpath(self): def cfg_path(self):
return Path(self.yaml_path.parent) if self.yaml_path is not None else None
@property
def rel_path(self):
return Path(self.cfg_path) return Path(self.cfg_path)
@property
def cache_dir(self):
return self.rel_path / ".cache"
@property
def data_dir(self):
return self.rel_path / "data"
@property
def metadata_dir(self):
return self.rel_path / "metadata"
@property @property
def ckpt_dir(self): def ckpt_dir(self):
return self.relpath / "ckpt" return self.rel_path / "ckpt"
@property @property
def log_dir(self): def log_dir(self):
return self.relpath / "logs" / str(self.start_time) return self.rel_path / "logs" / str(self.start_time)
@cached_property @cached_property
def start_time(self): def start_time(self):
@ -64,39 +94,28 @@ class _Config:
with open(path, "w") as f: with open(path, "w") as f:
f.write(self.dumps()) f.write(self.dumps())
@staticmethod
def _is_cfg_argv(s):
return "=" in s and "--" not in s
@classmethod @classmethod
def from_yaml( cls, yaml_path ): def from_yaml( cls, yaml_path ):
return cls.from_cli( [f'yaml="{yaml_path}"'] ) state = {}
state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
state.setdefault("yaml_path", yaml_path)
return cls(**state)
@classmethod @classmethod
def from_cli(cls, args=sys.argv): def from_cli(cls, args=sys.argv):
cli_cfg = OmegaConf.from_cli([s for s in args if cls._is_cfg_argv(s)]) # legacy support for yaml=`` format
for i, arg in enumerate(args):
if arg.startswith("yaml"):
args[i] = f'--{arg}'
# Replace argv to ensure there are no omegaconf options, for compatibility with argparse. parser = argparse.ArgumentParser(allow_abbrev=False)
sys.argv = [s for s in sys.argv if not cls._is_cfg_argv(s)] parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
args, unknown = parser.parse_known_args(args=args)
if cli_cfg.get("help"): if args.yaml:
print(f"Configurable hyperparameters with their default values:") return cls.from_yaml( args.yaml )
print(json.dumps(asdict(cls()), indent=2, default=str))
exit()
if "yaml" in cli_cfg: return cls(**{})
yaml_cfg = OmegaConf.load(cli_cfg.yaml)
yaml_path = Path(cli_cfg.yaml).absolute()
cfg_path = Path(*yaml_path.relative_to(Path.cwd()).parts[:-1])
cfg_path = cfg_path.with_suffix("")
cfg_path = f'./{cfg_path}'
yaml_cfg.setdefault("cfg_path", cfg_path)
cli_cfg.pop("yaml")
else:
yaml_cfg = {}
merged = OmegaConf.merge(yaml_cfg, cli_cfg)
return cls(**dict(merged))
def __repr__(self): def __repr__(self):
return str(self) return str(self)
@ -106,104 +125,195 @@ class _Config:
@dataclass() @dataclass()
class Dataset: class Dataset:
training: list[Path] = field(default_factory=lambda: []) training: list[Path] = field(default_factory=lambda: []) # paths to load into the training dataset
validation: list[Path] = field(default_factory=lambda: []) validation: list[Path] = field(default_factory=lambda: []) # paths to load into the validation dataset
temp: list[Path] = field(default_factory=lambda: [])
# de-implemented, because the data isn't that large to facilitate HDF5
hdf5_name: str = "data.h5"
use_hdf5: bool = False
workers: int = 8 hdf5_name: str = "data.h5" # file name to load the HDF5 dataset
cache: bool = True use_hdf5: bool = False # whether to load from an HDF5 dataset
hdf5_flag: str = "a" # flag to load the HDF5 file, automatically adjusted anyways
validate: bool = True # validate each utterance on whether it can be included based on duration range caps
workers: int = 8 # number of dataloader workers to spawn
cache: bool = True # use diskcache to cache the dataset
# I really need to clean this up
@dataclass() @dataclass()
class Model: class Model:
name: str = "" name: str = "classifier"
tokens: int = 0 # number of token types tokens: int = 0 # number of token types
len: int = 1 # how long a sequence can be len: int = 1 # how long a sequence can be
dim: int = 512 dim: int = 512
resnet: int = 18
width: int = 300
height: int = 80
version: int = 1
training: bool = True
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
@property @property
def full_name(self): def full_name(self):
return self.name return self.name
@dataclass()
class Models:
_models: list[Model] = field(default_factory=lambda: [
Model(name="captcha"),
])
def get(self, name=None): def get(self, name=None):
if not name: return [ self ] if not name or self.name == name else []
return [ Model(**model) for model in self._models ]
def loss_factor(self, k):
return self.loss_factors[k] if k in self.loss_factors else 1.0
for model in self._models: @property
if model.name == name: # required for fp8 as the lengths needs to be divisible by 8
return model def input_alignment(self):
return 8 if cfg.optimizations.fp8 else 0
raise ValueError @property
def activation_checkpointing(self):
return cfg.trainer.activation_checkpointing
@property
def gradient_checkpointing(self):
return cfg.trainer.gradient_checkpointing
@property
def lora_policy(self):
include = ["model"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
exclude = []
if self.arch_type == "llama":
include = ["self_attn", "mlp"] # target only the attention + mlp
exclude = ["self_attn.k_proj"] # common literature says to ignore it
if self.arch_type == "retnet":
include = ["layers."] # target the core layers of the RetNet and ignore the auxiliary stuff
exclude = ["retention.k_proj"] # attention-based transformers ignore the K, so might as well ignore it for the retnet
return dict(include=include, exclude=exclude)
# should be renamed to Adapters
@dataclass()
class LoRA:
name: str = "lora" # vanity name
# to-do: find sane default values
rank: int = 128 # rank for the LoRA
alpha: int = 128 # alpha for the LoRA
training: bool = True #
embeddings: bool = False # train the embedding too
parametrize: bool = False # whether to use the parameterized pathway for LoRAs or not
rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
@property
def full_name(self):
name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
return "-".join(name)
# actually not needed anymore
def active_level( self, level ):
if not self.rvq_levels:
return True
return level in self.rvq_levels
@dataclass() @dataclass()
class Hyperparameters: class Hyperparameters:
batch_size: int = 8 batch_size: int = 8 # number of samples per training batch
gradient_accumulation_steps: int = 32 gradient_accumulation_steps: int = 32 # number of steps to accumulate gradients before updating
gradient_clipping: int = 100 # to be implemented in the local backend gradient_clipping: int | float = 10 # largest size a gradient norm can be
optimizer: str = "Adamw" optimizer: str = "Adamw" # optimizer to use, should be 'Prodigyopt" now
learning_rate: float = 3.25e-4 optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
learning_rate: float = 3.25e-4 # should be 1.0 for ProdigyOpt
warmup_steps: int = 0 # number of steps to warm up the optimizer before performing updates, I think, this is just passed to deepspeed
scheduler_type: str = "" # to be implemented in the local backend scheduler: str = "" # scheduler to use, currently don't ever use one so this doesn't really matter
scheduler_params: dict = field(default_factory=lambda: {}) scheduler_type: str = "" # deprecated
scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
autotune: bool = False # to do deepspeed's autotuning
autotune_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed supplied
torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied
@dataclass() @dataclass()
class Evaluation: class Evaluation:
batch_size: int = 64 batch_size: int = 64 # number of samples per batch during eval / val
frequency: int = 250 frequency: int = 250 # do eval / val every X iterations
size: int = 64 size: int = 64 # number of samples to generate during eval / val
steps: int = 500 steps: int = 500
temperature: float = 1.0 temperature: float = 1.0 # AR temp for inferencing
load_disabled_engines: bool = True # see the other load_disabled_engines
@dataclass() @dataclass()
class DeepSpeed: class DeepSpeed:
zero_optimization_level: int = 0 zero_optimization_level: int = 0 # doesn't seem to work
use_compression_training: bool = False use_compression_training: bool = False # cope
compression_bits: int = 8 # cope
inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)
def get_ds_cfg(self, model): config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
weights = [ name[0] for name in model.named_parameters() ]
bits = 8
scheduler_params = {} @cached_property
for k in cfg.hyperparameters.scheduler_params: def ds_cfg(self):
scheduler_params[k] = cfg.hyperparameters.scheduler_params[k] optimizer_params = cfg.hyperparameters.optimizer_params
if 'lr' not in optimizer_params:
optimizer_params["lr"] = cfg.hyperparameters.learning_rate,
if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params: scheduler_params = cfg.hyperparameters.scheduler_params
if 'warmup_num_steps' not in scheduler_params:
scheduler_params['warmup_num_steps'] = cfg.hyperparameters.warmup_steps
if 'total_num_steps' not in scheduler_params:
scheduler_params['total_num_steps'] = cfg.trainer.iterations scheduler_params['total_num_steps'] = cfg.trainer.iterations
autotune_params = cfg.hyperparameters.autotune_params
if "enabled" not in autotune_params:
autotune_params['enabled'] = True
if "results_dir" not in autotune_params:
autotune_params['results_dir'] = str( cfg.rel_path / "autotune" / "results" )
if "exps_dir" not in autotune_params:
autotune_params['exps_dir'] = str( cfg.rel_path / "autotune" / "exps_" )
# DeepSpeed fp16 is incompatible with its AMP
if cfg.trainer.weight_dtype.lower() == "float16":
self.amp = False
# disable local AMP
if self.amp:
cfg.trainer.amp = False
ds_cfg = { ds_cfg = {
"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size, "train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps, "gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
"optimizer": { "optimizer": {
"type": cfg.hyperparameters.optimizer, "type": cfg.hyperparameters.optimizer,
"params": { "params": optimizer_params,
"lr": cfg.hyperparameters.learning_rate, } if not cfg.hyperparameters.torch_optimizer else None,
}
},
"scheduler": { "scheduler": {
"type": cfg.hyperparameters.scheduler_type, "type": cfg.hyperparameters.scheduler,
"params": scheduler_params, "params": scheduler_params,
} if cfg.hyperparameters.scheduler_type != "" else None, } if not cfg.hyperparameters.torch_scheduler else None,
"gradient_clipping": cfg.hyperparameters.gradient_clipping, "gradient_clipping": cfg.hyperparameters.gradient_clipping,
"fp16": { "fp16": {
"enabled": True, "enabled": cfg.trainer.weight_dtype.lower() == "float16",
"auto_cast": True, "auto_cast": True, # ???
} if cfg.trainer.weight_dtype.lower() == "float16" else None, "loss_scale": 0.0 if cfg.trainer.scale_loss else 1.0,
"bf16": {
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16"
}, },
"bf16": {
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
},
"amp": {
"enabled": self.amp,
},
"autotuning": autotune_params if cfg.hyperparameters.autotune else None,
"compression_training": { "compression_training": {
"weight_quantization": { "weight_quantization": {
"shared_parameters":{ "shared_parameters":{
@ -214,7 +324,7 @@ class DeepSpeed:
"quantize_verbose": True, "quantize_verbose": True,
"quantization_type": "symmetric", "quantization_type": "symmetric",
"rounding": "nearest", "rounding": "nearest",
"quantize_weight_in_forward": True, "quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
"fp16_mixed_quantize":{ "fp16_mixed_quantize":{
"enabled": False, "enabled": False,
"quantize_change_ratio": 1 "quantize_change_ratio": 1
@ -223,30 +333,38 @@ class DeepSpeed:
"different_groups": { "different_groups": {
"wq1": { "wq1": {
"params": { "params": {
"start_bits": bits, "start_bits": self.compression_bits,
"target_bits": bits, "target_bits": self.compression_bits,
"quantization_period": 0 "quantization_period": 0
}, },
"modules": weights "modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
} }
} }
}, },
"activation_quantization": { "activation_quantization": {
"shared_parameters":{ "shared_parameters":{
"enabled": True, "enabled": True,
"quantizer_kernel": True,
"schedule_offset": 0,
"quantize_groups": 64,
"quantize_verbose": True,
"quantization_type": "symmetric", "quantization_type": "symmetric",
"range_calibration": "dynamic", "rounding": "nearest",
"schedule_offset": 0 "quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
"fp16_mixed_quantize":{
"enabled": False,
"quantize_change_ratio": 1
}
}, },
"different_groups": { "different_groups": {
"aq1": { "aq1": {
"params": { "params": {
"bits": bits "bits": self.compression_bits,
}, },
"modules": weights "modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
} }
} }
} },
} if self.use_compression_training else None, } if self.use_compression_training else None,
"zero_optimization": { "zero_optimization": {
"stage": self.zero_optimization_level, "stage": self.zero_optimization_level,
@ -264,7 +382,10 @@ class DeepSpeed:
"offload_param": { "offload_param": {
"device": "cpu", "device": "cpu",
"pin_memory": True "pin_memory": True
} },
"zero_quantized_weights": self.use_compression_training,
"zero_hpz_partition_size": world_size(),
"zero_quantized_gradients": self.use_compression_training,
} if self.zero_optimization_level > 0 else None, } if self.zero_optimization_level > 0 else None,
"comms_logger": { "comms_logger": {
"enabled": False "enabled": False
@ -275,113 +396,314 @@ class DeepSpeed:
for k in null_keys: for k in null_keys:
del ds_cfg[k] del ds_cfg[k]
if os.path.exists("./config/ds_config.json"): if os.path.exists("./data/ds_config.json"):
ds_cfg.update(json.load(open("./config/ds_config.json", "r", encoding="utf-8"))) ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
else:
ds_cfg.update(self.config)
return ds_cfg return ds_cfg
@dataclass() @dataclass()
class Trainer: class Trainer:
iterations: int = 100_000 iterations: int = 1_000_000 # maximum iterations to train
save_tag: str = "step" save_tag: str = "step" # name to save checkpoints under, "step" will save as current step count
load_tag: str | None = None load_tag: str | None = None # tag to load checkpoint from; if None: will check against contents of `./ckpt/{model-name}/latest` for the checkpoint name
save_on_oom: bool = True save_on_oom: bool = True # save if an OOM error is raised
save_on_quit: bool = True save_on_quit: bool = True # save when quitting training
save_frequency: int = 100
export_on_save: bool = False # export weights to local `fp32.pth` state_dict on saving a checkpoint
export_on_quit: bool = False # export weights to local `fp32.pth` state_dict on quitting training
save_frequency: int = 100 # frequency to save every X iterations
load_state_dict: bool = False keep_last_checkpoints: int = 0 # number of checkpoints to keep, prunes oldest ones
load_states: bool = True
strict_loading: bool = True
restart_step_count: bool = False
aggressive_optimizations: bool = False load_state_dict: bool = False # loads `fp32.pth` state_dict, will automatically be done if a checkpoint is not found but `fp32.pth` exists
check_for_oom: bool = True load_states: bool = True #
strict_loading: bool = False # sets strict_loading=True when loading the state dict
load_module_only: bool = False #
restart_step_count: bool = False # clears the training stats when loading a checkpoint
resize_modules: bool = False # automatically resizes
gc_mode: str | None = None activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing
gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
weight_dtype: str = "float16" aggressive_optimizations: bool = False # deprecated
check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
gc_mode: str | None = None # deprecated, but marks when to do GC
load_disabled_engines: bool = False # deprecated, but signals to load engines not used for training for, for example, evaluation/validation
backend: str = "deepspeed" weight_dtype: str = "float16" # dtype to have the model under
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) amp: bool = False # automatic mixed precision
ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
#scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload)
load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when logger broke because of BitNet
backend: str = "local" # training backend to use. currently supports "local" | "deepspeed"
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings
@cached_property @cached_property
def dtype(self): def dtype(self):
if self.weight_dtype == "float16": if self.weight_dtype == "float16":
return torch.float16 return torch.float16
if cfg.trainer.weight_dtype == "bfloat16": if self.weight_dtype == "bfloat16":
return torch.bfloat16 return torch.bfloat16
if self.weight_dtype == "float8_e5m2":
return torch.float8_e5m2
if self.weight_dtype == "float8_e4m3fn":
return torch.float8_e4m3fn
return torch.float32 return torch.float32
@cached_property
def scale_loss(self):
# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
return self.dtype == torch.float16
@dataclass() @dataclass()
class Inference: class Inference:
use_vocos: bool = True # artifact from the VALL-E trainer backend: str = "local" # backend to use when inferencing
weight_dtype: str = "float32" # dtype to load the model under
amp: bool = False # automatic mixed precision during inferencing
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
@property
def dtype(self):
if self.weight_dtype == "float16":
return torch.float16
if self.weight_dtype == "bfloat16":
return torch.bfloat16
if self.weight_dtype == "int8":
return torch.int8
if self.weight_dtype == "float8_e5m2":
return torch.float8_e5m2
if self.weight_dtype == "float8_e4m3fn":
return torch.float8_e4m3fn
return torch.float32
@dataclass() @dataclass()
class BitsAndBytes: class Optimizations:
enabled: bool = False injects: bool = False # overwrites default torch classes (not recommended)
injects: bool = False replace: bool = False # replaces modules in place with the optimized version (recommended)
compile: bool | str = False # runs torch.compile on the model
linear: bool = False linear: bool = True # inject/replace linear for BnB
embedding: bool = False embedding: bool = True # inject/replace embedding for BnB
optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
bitsandbytes: bool = False # use bitsandbytes
dadaptation: bool = False # use dadaptation optimizer
bitnet: bool = False # use bitnet
fp8: bool = False # use fp8
model_offloading: dict | None = None # automatically splits the model over a list of devices
# example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} will have the GPU capped to 6GiB, and offload the remaining layers to CPU
# example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
# | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2
tensorrt: bool = False
@dataclass() @dataclass()
class Config(_Config): class Config(BaseConfig):
device: str = "cuda" device: str = "cuda" # target device
mode: str = "training" # "inferencing"
experimental: bool = False # Debug flag, unused now
dataset: Dataset = field(default_factory=lambda: Dataset) dataset: Dataset = field(default_factory=lambda: Dataset)
models: Models = field(default_factory=lambda: Models) models: dict | list | None = field(default_factory=lambda: [])
loras: dict | list | None = field(default_factory=lambda: [])
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters) hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
evaluation: Evaluation = field(default_factory=lambda: Evaluation) evaluation: Evaluation = field(default_factory=lambda: Evaluation)
trainer: Trainer = field(default_factory=lambda: Trainer) trainer: Trainer = field(default_factory=lambda: Trainer)
inference: Inference = field(default_factory=lambda: Inference) inference: Inference = field(default_factory=lambda: Inference)
bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes) bitsandbytes: dict | list | None = None # deprecated
optimizations: Optimizations = field(default_factory=lambda: Optimizations)
def get_device(self): tokenizer: str | None = None # tokenizer class
return torch.cuda.current_device() if self.device == "cuda" else self.device tokenizer_path: str = "./tokenizer.json" # tokenizer path
weights_format: str = "pth" # "pth" | "sft"
supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
@property @property
def cache_dir(self): def model(self):
return ".cache" / self.relpath for i, model in enumerate(self.models):
if model.training:
return model
return self.models[0] if len(self.models) > 0 else None
# should be renamed to adapters
@property
def lora(self):
for i, lora in enumerate(self.loras):
if lora.training:
return lora
return self.loras[0] if len(self.loras) > 0 else None
@property
def distributed(self):
return world_size() > 1
@cached_property @cached_property
def diskcache(self): def diskcache(self):
if self.dataset.cache: if self.yaml_path is not None and self.dataset.cache:
return diskcache.Cache(self.cache_dir).memoize return diskcache.Cache(self.cache_dir).memoize
return lambda: lambda x: x return lambda: lambda x: x
# I don't remember why this is needed
def load_yaml( self, config_path ): def load_yaml( self, config_path ):
tmp = Config.from_yaml( config_path ) tmp = Config.from_yaml( config_path )
self.__dict__.update(tmp.__dict__) self.__dict__.update(tmp.__dict__)
def load_hdf5( self, write=False ):
if hasattr(self, 'hdf5'):
self.hdf5.close()
if self.distributed:
self.dataset.hdf5_flag = "r"
try:
self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
except Exception as e:
_logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}")
self.dataset.use_hdf5 = False
# to-do: prune unused keys
def format( self, training=True ):
if isinstance(self.dataset, type):
self.dataset = dict()
if isinstance(self.models, type):
self.models = dict()
if isinstance(self.loras, type):
self.loras = dict()
if isinstance(self.hyperparameters, type):
self.hyperparameters = dict()
if isinstance(self.evaluation, type):
self.evaluation = dict()
if isinstance(self.trainer, type):
self.trainer = dict()
if isinstance(self.inference, type):
self.inference = dict()
if isinstance(self.optimizations, type):
self.optimizations = dict()
self.dataset = Dataset(**self.dataset)
self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
self.models = [ Model(**model) for model in self.models ]
self.loras = [ LoRA(**lora) for lora in self.loras ]
if not self.models:
self.models = [ Model() ]
self.hyperparameters = Hyperparameters(**self.hyperparameters)
self.evaluation = Evaluation(**self.evaluation)
self.trainer = Trainer(**self.trainer)
if not isinstance(self.trainer.deepspeed, type):
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
self.inference = Inference(**self.inference)
if self.bitsandbytes is not None:
self.optimizations = Optimizations(**self.bitsandbytes)
else:
self.optimizations = Optimizations(**self.optimizations)
if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
self.hyperparameters.scheduler_type = ""
# do not combine the two
if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation:
self.hyperparameters.scheduler = ""
if self.hyperparameters.scheduler == "":
self.hyperparameters.torch_scheduler = True
if self.trainer.backend == "local" and self.distributed:
self.trainer.ddp = True
if self.trainer.activation_checkpointing is not None:
self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing
if not training:
self.dataset.use_hdf5 = False
# load our HDF5 file if requested here
if self.dataset.use_hdf5:
self.load_hdf5()
# load tokenizer
if cfg.tokenizer == "naive":
cfg.tokenizer = NaiveTokenizer()
else:
try:
from transformers import PreTrainedTokenizerFast
tokenizer_path = cfg.rel_path / cfg.tokenizer_path if cfg.yaml_path is not None else None
if tokenizer_path and not tokenizer_path.exists():
tokenizer_path = Path("./data/") / cfg.tokenizer_path
if tokenizer_path and tokenizer_path.exists():
cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
else:
cfg.tokenizer = NaiveTokenizer()
except Exception as e:
cfg.tokenizer = NaiveTokenizer()
_logger.warning(f"Error while parsing tokenizer: {str(e)}")
pass
# Preserves the old behavior
class NaiveTokenizer:
def get_vocab( self ):
"""
if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
return json.loads( cfg.hdf5['symmap'].asstr()[()] )
"""
return { " ": 0, "<s>": 1, "</s>": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 }
@cached_property
def _bos_token( self ):
return self.get_vocab()["<s>"]
@cached_property
def _eos_token( self ):
return self.get_vocab()["</s>"]
def encode( self, s ):
symmap = self.get_vocab()
s = s.replace("O", "0")
s = [f"<s>"] + [ p if p in symmap else " " for p in s ] + [f"</s>"]
return [*map(symmap.get, s)]
_logger = logging.getLogger(__name__)
cfg = Config.from_cli() cfg = Config.from_cli()
# OmegaConf doesn't actually coerce the dicts into the @dataclass decorated classes, for some god forsaken reason, so we coerce them ourselves # some safety for remapping deprecated formats and re-coercing uninitialized properties into actual types
cfg.dataset = Dataset(**cfg.dataset) try:
cfg.models = Models(**cfg.models) cfg.format()
cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters) except Exception as e:
cfg.evaluation = Evaluation(**cfg.evaluation) _logger.error(f"Error while parsing config YAML: {str(e)}")
cfg.trainer = Trainer(**cfg.trainer) raise e # throw an error because I'm tired of silent errors messing things up for me
cfg.inference = Inference(**cfg.inference)
cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
# cached_property stopped working...
if cfg.dataset.use_hdf5:
try:
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
except Exception as e:
print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
cfg.dataset.use_hdf5 = False
if not cfg.dataset.use_hdf5:
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
if __name__ == "__main__": if __name__ == "__main__":
print(cfg) print(cfg)
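
The `NaiveTokenizer` above preserves the old fixed symbol map; a quick sketch of what it produces for a label pulled from a filename stem (the label itself is made up, and the import path is an assumption):

```python
# Sketch: encoding a label with the NaiveTokenizer defined above.
from image_classifier.config import NaiveTokenizer  # import path assumed

tokenizer = NaiveTokenizer()
print(tokenizer.encode("W8XT0"))
# -> [1, 20, 6, 21, 18, 3, 2]   i.e. <s>, W, 8, X, T, 0, </s> under the vocab above
```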

image_classifier/data.py

@ -1,16 +1,19 @@
# todo: clean this mess up # todo: clean this mess up
import copy import copy
# import h5py import h5py
import json import json
import logging import logging
#import numpy as np import numpy as np
import os import os
import random import random
import torch import torch
import math import itertools
from .config import cfg from .config import cfg
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
from .utils.distributed import global_rank, local_rank, world_size
from .utils.io import torch_save, torch_load
from collections import defaultdict from collections import defaultdict
from functools import cache, cached_property from functools import cache, cached_property
@ -20,23 +23,57 @@ from typing import Any
from torch import Tensor from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset from torch.utils.data import DataLoader, Dataset as _Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.utils.rnn import pad_sequence
from PIL import Image, ImageDraw
import torchvision.transforms as transforms import torchvision.transforms as transforms
from tqdm.auto import tqdm from tqdm.auto import tqdm
from PIL import Image
# torch.multiprocessing.set_sharing_strategy("file_system") # torch.multiprocessing.set_sharing_strategy("file_system")
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@cache # to-do: clean up this symmap mess
def get_symmap(): def get_symmap():
return { " ": 0, "<s>": 1, "</s>": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 } return cfg.tokenizer.get_vocab()
@cache def tokenize( s ):
def _get_symbols( content ): if isinstance( s, list ):
content = content.replace("O", "0") s = "".join( s )
return [f"<s>"] + [ p for p in content ] + [f"</s>"] return cfg.tokenizer.encode( s )
"""
def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
def _get_hdf5_path(path):
# to-do: better validation
return str(path)
def _get_hdf5_paths( data_dir, type="training", validate=False ):
data_dir = str(data_dir)
key = f"/{type}/{_get_hdf5_path(data_dir)}"
return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items()] if key in cfg.hdf5 else []
def _get_paths_of_extensions( path, validate=False ):
if isinstance(path, str):
path = Path(path)
return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
def _interleaved_reorder(l, fn):
groups = defaultdict(list)
for e in l:
groups[fn(e)].append(e)
groups = {k: groups[k] for k in sorted(groups)}
for interleaved in zip_longest(*groups.values()):
for value in interleaved:
if value is not None:
yield value
"""
class Dataset(_Dataset): class Dataset(_Dataset):
def __init__( def __init__(
@ -44,43 +81,90 @@ class Dataset(_Dataset):
paths, paths,
width=300, width=300,
height=80, height=80,
stacks=0,
symmap=get_symmap(), symmap=get_symmap(),
training=False, training=False,
): ):
super().__init__() super().__init__()
self._head = None self._head = None
self.sampler = None
self.paths = paths
self.width = width self.width = width
self.height = height self.height = height
self.stacks = stacks
self.paths = paths
self.image_dtype = cfg.trainer.dtype
self.symmap = symmap self.symmap = symmap
self.training = training self.training = training
self.dataset_type = "training" if self.training else "validation"
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
self.transform = transforms.Compose([ self.transform = transforms.Compose([
#transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]) ])
@cached_property # to-do: do not do validation if there's nothing in the validation
def symbols(self): # this just makes it be happy
return sorted(set().union(*[_get_symbols(path.stem) for path in self.paths])) if len(self.dataset) == 0:
self.dataset = cfg.dataset.training
# split dataset accordingly per GPU
if cfg.distributed and self.training:
self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]
if len(self.paths) == 0:
raise ValueError(f"No valid path is found for {self.dataset_type}")
@cached_property
def sampler_state_dict_path(self):
return cfg.rel_path / f"sampler.rank{global_rank()}.pt"
def save_state_dict(self, path = None):
"""
if path is None:
path = self.sampler_state_dict_path
if self.sampler is not None:
state_dict = self.sampler.get_state()
elif self.samplers is not None:
state_dict = {
"samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
}
torch_save(state_dict, path)
"""
return
def load_state_dict(self, path = None):
"""
if path is None:
path = self.sampler_state_dict_path
if not path.exists():
return
state_dict = torch_load(path)
if self.sampler is not None:
state_dict = self.sampler.set_state(state_dict)
else:
for name, sampler in state_dict["samplers"].items():
if name not in self.samplers:
continue
self.samplers[name].set_state( sampler )
"""
return
def __getitem__(self, index): def __getitem__(self, index):
path = self.paths[index] path = self.paths[index]
tokens = tokenize( path.stem.upper() )
text = torch.tensor( tokens ).to(dtype=torch.uint8)
# stupid try/except when the original VALL-E training framework was able to insert foreign symbols into the symmap, but that functionality isn't really necessary here image = Image.open(path).convert('RGB')
try: width, height = image.size
text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
except Exception as e:
print("Invalid symbol:", _get_symbols(path.stem), [*map(self.symmap.get, _get_symbols(path.stem))], path.stem)
raise e
image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype) # resnet has to be RGB image = self.transform(image).to(dtype=self.image_dtype) # resnet has to be RGB
return dict( return dict(
index=index, index=index,
@ -98,11 +182,6 @@ class Dataset(_Dataset):
def __len__(self): def __len__(self):
return min(len(self.paths), self._head or len(self.paths)) return min(len(self.paths), self._head or len(self.paths))
def pin_memory(self):
self.text = self.text.pin_memory()
self.image = self.image.pin_memory()
return self
def collate_fn(samples: list[dict]): def collate_fn(samples: list[dict]):
batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]} batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
@ -111,21 +190,28 @@ def collate_fn(samples: list[dict]):
def _seed_worker(worker_id): def _seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32 worker_seed = torch.initial_seed() % 2**32
#np.random.seed(worker_seed) np.random.seed(worker_seed)
random.seed(worker_seed) random.seed(worker_seed)
def _create_dataloader(dataset, training): def _create_dataloader(dataset, training):
kwargs = dict(
shuffle=True,
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
drop_last=training,
sampler=dataset.sampler if training else None,
) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
batch_sampler=dataset.sampler,
)
return DataLoader( return DataLoader(
dataset=dataset, dataset=dataset,
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
shuffle=True, # training
drop_last=training,
num_workers=cfg.dataset.workers, num_workers=cfg.dataset.workers,
collate_fn=collate_fn, collate_fn=collate_fn,
persistent_workers=cfg.dataset.workers > 0, persistent_workers=cfg.dataset.workers > 1,
pin_memory=False, # True, pin_memory=False,
worker_init_fn=_seed_worker, worker_init_fn=_seed_worker,
**kwargs,
) )
def _load_train_val_paths( val_ratio=0.1 ): def _load_train_val_paths( val_ratio=0.1 ):
@ -133,8 +219,8 @@ def _load_train_val_paths( val_ratio=0.1 ):
train_paths = [] train_paths = []
val_paths = [] val_paths = []
print(cfg.dataset.training)
for data_dir in cfg.dataset.training: for data_dir in cfg.dataset.training:
paths.extend(data_dir.rglob("*.jpg"))
paths.extend(data_dir.rglob("*.png")) paths.extend(data_dir.rglob("*.png"))
if len(paths) > 0: if len(paths) > 0:
@ -146,12 +232,13 @@ def _load_train_val_paths( val_ratio=0.1 ):
val_len = math.floor(len(train_paths) * val_ratio) val_len = math.floor(len(train_paths) * val_ratio)
train_len = math.floor(len(train_paths) * (1 - val_ratio)) train_len = math.floor(len(train_paths) * (1 - val_ratio))
print(val_len, train_len)
val_paths = train_paths[:-val_len] val_paths = train_paths[:-val_len]
train_paths = train_paths[:train_len] train_paths = train_paths[:train_len]
else: else:
paths = []
for data_dir in cfg.dataset.validation: for data_dir in cfg.dataset.validation:
paths.extend(data_dir.rglob("*.jpg"))
paths.extend(data_dir.rglob("*.png")) paths.extend(data_dir.rglob("*.png"))
if len(paths) > 0: if len(paths) > 0:
@ -169,7 +256,6 @@ def _load_train_val_paths( val_ratio=0.1 ):
return train_paths, val_paths return train_paths, val_paths
@cfg.diskcache()
def create_datasets(): def create_datasets():
train_paths, val_paths = _load_train_val_paths() train_paths, val_paths = _load_train_val_paths()
@ -187,10 +273,10 @@ def create_datasets():
return train_dataset, val_dataset return train_dataset, val_dataset
def create_train_val_dataloader(): def create_train_val_dataloader():
train_dataset, val_dataset = create_datasets() train_dataset, val_dataset = create_datasets()
# deepcopy is slow
subtrain_dataset = copy.deepcopy(train_dataset) subtrain_dataset = copy.deepcopy(train_dataset)
subtrain_dataset.head_(cfg.evaluation.size) subtrain_dataset.head_(cfg.evaluation.size)
subtrain_dataset.training_(False) subtrain_dataset.training_(False)
@ -200,8 +286,6 @@ def create_train_val_dataloader():
subtrain_dl = _create_dataloader(subtrain_dataset, training=False) subtrain_dl = _create_dataloader(subtrain_dataset, training=False)
_logger.info(str(train_dataset.symmap)) _logger.info(str(train_dataset.symmap))
_logger.info(f"#samples (train): {len(train_dataset)}.") _logger.info(f"#samples (train): {len(train_dataset)}.")
_logger.info(f"#samples (val): {len(val_dataset)}.") _logger.info(f"#samples (val): {len(val_dataset)}.")
_logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.") _logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")
@ -210,11 +294,305 @@ def create_train_val_dataloader():
return train_dl, subtrain_dl, val_dl return train_dl, subtrain_dl, val_dl
# parse dataset into better to sample metadata
""" """
if __name__ == "__main__": def create_dataset_metadata( skip_existing=True ):
create_dataset_hdf5() symmap = get_symmap()
root = str(cfg.data_dir)
metadata_root = str(cfg.metadata_dir)
train_dl, subtrain_dl, val_dl = create_train_val_dataloader() cfg.metadata_dir.mkdir(parents=True, exist_ok=True)
sample = train_dl.dataset[0]
print(sample) def add( dir, type="training", audios=True, texts=True ):
""" name = str(dir)
name = name.replace(root, "")
speaker_name = name
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
try:
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
except Exception as e:
metadata = {}
if not os.path.isdir(f'{root}/{name}/'):
return
# tqdm.write(f'{root}/{name}')
files = os.listdir(f'{root}/{name}/')
# grab IDs for every file
ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
wrote = False
for id in tqdm(ids, desc=f"Processing {name}"):
try:
quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}')
if audios and not quant_path.exists():
continue
key = f'{type}/{speaker_name}/{id}'
if skip_existing and id in metadata:
continue
wrote = True
if id not in metadata:
metadata[id] = {}
utterance_metadata = {}
if audios:
# ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
dac = np.load(quant_path, allow_pickle=True)[()]
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
if "text" in dac["metadata"]:
utterance_metadata["text"] = dac["metadata"]["text"]
if "phonemes" in dac["metadata"]:
utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
if "language" in dac["metadata"]:
utterance_metadata["language"] = dac["metadata"]["language"]
if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
for k, v in utterance_metadata.items():
metadata[id][k] = v
except Exception as e:
tqdm.write(f'Error while processing {id}: {e}')
if wrote:
with open(str(metadata_path), "w", encoding="utf-8") as f:
f.write( json.dumps( metadata ) )
# training
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
add( data_dir, type="training" )
# validation
for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
add( data_dir, type="validation" )
# noise
for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
add( data_dir, type="noise", texts=False )
# parse yaml to create an hdf5 file
def create_dataset_hdf5( skip_existing=True ):
cfg.dataset.use_hdf5 = True
cfg.load_hdf5(write=True)
hf = cfg.hdf5
symmap = get_symmap()
root = str(cfg.data_dir)
metadata_root = str(cfg.metadata_dir)
def add( dir, type="training", audios=True, texts=True ):
name = str(dir)
name = name.replace(root, "")
# yucky
speaker_name = name
if "LibriTTS-R" in speaker_name:
speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
if not os.path.isdir(f'{root}/{name}/'):
return
files = os.listdir(f'{root}/{name}/')
# grab IDs for every file
ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
for id in tqdm(ids, desc=f"Processing {name}"):
try:
quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
if not quant_exists:
continue
key = f'{type}/{speaker_name}/{id}'
if skip_existing and key in hf:
continue
group = hf.create_group(key) if key not in hf else hf[key]
if id not in metadata:
metadata[id] = {}
utterance_metadata = {}
# audio
if audios:
dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
if "text" in dac["metadata"]:
utterance_metadata["text"] = dac["metadata"]["text"]
if "phonemes" in dac["metadata"]:
utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
if "language" in dac["metadata"]:
utterance_metadata["language"] = dac["metadata"]["language"]
if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
if "audio" not in group:
group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')
# text
if texts:
if not utterance_metadata and text_exists:
utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
phn = "".join(utterance_metadata["phonemes"])
phn = cfg.tokenizer.encode(phn)
phn = np.array(phn).astype(np.uint8)
if "text" not in group:
group.create_dataset('text', data=phn, compression='lzf')
for k, v in utterance_metadata.items():
group.attrs[k] = v
metadata[id][k] = v
except Exception as e:
tqdm.write(f'Error while processing {id}: {e}')
with open(str(metadata_path), "w", encoding="utf-8") as f:
f.write( json.dumps( metadata ) )
# training
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
add( data_dir, type="training" )
# validation
for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
add( data_dir, type="validation" )
# noise
for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
add( data_dir, type="noise", texts=False )
# write symmap
if "symmap" in hf:
del hf['symmap']
hf.create_dataset('symmap', data=json.dumps(symmap))
hf.close()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("--action", type=str)
parser.add_argument("--tasks", type=str)
args, unknown = parser.parse_known_args()
task = args.action
cfg.dataset.workers = 1
if args.action == "hdf5":
create_dataset_hdf5()
elif args.action == "list-dataset":
dataset = []
for group in os.listdir(cfg.data_dir):
for name in os.listdir(cfg.data_dir / group):
if len(os.listdir(cfg.data_dir / group / name)) == 0:
continue
dataset.append(f'{group}/{name}')
_logger.info(json.dumps(dataset))
elif args.action == "metadata":
create_dataset_metadata()
elif args.action == "sample":
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
samples = {
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
#"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
}
Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
for k, v in samples.items():
for i in range(len(v)):
for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
try:
decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
except Exception as e:
_logger.info(f"Error while decoding prom {k}.{i}.{j}.wav: {str(e)}")
try:
decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
except Exception as e:
_logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}")
v[i]['proms'][j] = v[i]['proms'][j].shape
v[i]['resps'][j] = v[i]['resps'][j].shape
for k, v in samples.items():
for i in range(len(v)):
_logger.info(f'{k}[{i}]: {v[i]}')
elif args.action == "validate":
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
missing = set()
for i in range(len( train_dl.dataset )):
batch = train_dl.dataset[i]
text = batch['text']
phonemes = batch['metadata']['phonemes']
decoded = [ cfg.tokenizer.decode(token) for token in text[1:-1] ]
for i, token in enumerate(decoded):
if token != "<unk>":
continue
phone = phonemes[i]
_logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )
missing |= set([phone])
_logger.info( f"Missing tokens: {missing}" )
elif args.action == "tasks":
index = 0
cfg.dataset.tasks_list = args.tasks.split(",")
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
batch = next(iter(train_dl))
for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
if task not in cfg.dataset.tasks_list:
continue
_logger.info( f'{text} {task} {cfg.model.resp_levels}')
_logger.info( f'{proms.shape} {resps.shape}' )
tokens = 0
tokens += sum([ text.shape[0] for text in batch["text"] ])
tokens += sum([ resps.shape[0] for resps in batch["resps"] ])
_logger.info( f'{tokens}' )
decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
break
"""
@ -1,6 +1,6 @@
from ..config import cfg
from ..utils.distributed import fix_unset_envs, ddp_model
fix_unset_envs()
if cfg.trainer.backend == "deepspeed":
@ -8,4 +8,211 @@ if cfg.trainer.backend == "deepspeed":
elif cfg.trainer.backend == "local":
from .base import Engine
from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
from ..models import get_models, get_model
from ..utils import wrapper as ml
from ..utils.io import torch_save, torch_load, pick_path
from ..models.lora import apply_lora, lora_load_state_dict
import torch
import re
import logging
_logger = logging.getLogger(__name__)
deepspeed_available = False
try:
import deepspeed
deepspeed_available = True
except Exception as e:
pass
from functools import cache
@cache
def load_engines(training=True, **model_kwargs):
models = get_models(cfg.models, training=training, **model_kwargs)
engines = dict()
for name, model in models.items():
state = None
stats = None
lora = None
inferencing = cfg.mode == "inferencing" or not model.config.training or not training
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
loads_state_dict = cfg.trainer.load_state_dict # or inferencing
checkpoint_path = cfg.ckpt_dir / name / "latest"
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
load_path = pick_path( cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
# actually use the lora-specific checkpoint if available
if cfg.lora is not None:
checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest"
# to handle the issue of training with deepspeed, but inferencing with local
if checkpoint_path.exists() and backend == "local":
tag = open(checkpoint_path).read()
checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
_logger.warning(f"Checkpoint missing, but weights found: {load_path}")
loads_state_dict = True
# load state early
if loads_state_dict:
state = torch_load(load_path, device=cfg.device)
# check if config is defined in state, and re-initialize the model
if "config" in state and False:
_logger.warning("Model config definition in weights, re-loading...")
config_state = state["config"]
model = get_model( config=cfg.model.__class__( *config_state ), training=training )
hyper_config = model.config
optimizer = None
lr_scheduler = None
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
ddp = cfg.trainer.ddp
engine_class = LocalEngine if backend == "local" else Engine
# apply model replacers
if cfg.optimizations.replace and cfg.optimizations.linear:
model.model = ml.replace_linear( model.model )
if cfg.optimizations.replace and cfg.optimizations.embedding:
model.model = ml.replace_embedding( model.model )
for lora in cfg.loras:
model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
if inferencing:
model.config.training = False
if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
optimizer_class = None
scheduler_class = None
params = {
"lr": cfg.hyperparameters.learning_rate,
}
if cfg.hyperparameters.optimizer.lower() == "adamw":
params["betas"] = (0.9, 0.96)
params["eps"] = 1e-07
params["weight_decay"] = 0.01
# for dadaptation since it has Adam only
if ml.AdamW == ml.Adam:
params["decouple"] = True
optimizer_class = ml.AdamW
elif cfg.hyperparameters.optimizer.lower() == "sgd":
optimizer_class = ml.SGD
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
optimizer_class = ml.Prodigy
params['d_coef'] = params['lr']
params['lr'] = 1.0
elif cfg.hyperparameters.optimizer.lower() == "adagrad":
optimizer_class = ml.Adagrad
else:
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
params.update(cfg.hyperparameters.optimizer_params)
optimizer = optimizer_class(
[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
**params,
)
if cfg.hyperparameters.scheduler.lower() == "schedulefree":
if cfg.hyperparameters.optimizer.lower() == "adamw":
scheduler_class = ml.schedulefree.AdamWScheduleFree
elif cfg.hyperparameters.optimizer.lower() == "sgd":
scheduler_class = ml.schedulefree.SGDScheduleFree
else:
raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
optimizer = scheduler_class(
[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
lr = params['lr'],
warmup_steps = cfg.hyperparameters.warmup_steps
)
"""
# set up our LR scheduler here
"""
if inferencing:
optimizer = None
lr_scheduler = None
# load state dict if requested / required
if loads_state_dict:
# state dict is not just the module, extract the extra trainer details
if "stats" in state:
stats = state["stats"]
# do not load stats if we're training a LoRA
if cfg.lora is not None or cfg.trainer.restart_step_count:
stats = None
if "module" in state:
state = state["module"]
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
# load lora weights if exists
if cfg.lora is not None:
lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
if lora_path.exists():
_logger.info( f"Loaded LoRA state dict: {lora_path}" )
state = torch_load(lora_path, device=cfg.device)
state = state['lora' if 'lora' in state else 'module']
lora_load_state_dict( model, state )
# wrap if DDP is requested
if ddp:
model = ddp_model(model)
# wrap optimization class
elif cfg.optimizations.compile:
model = ml.compile_model(model, backend=cfg.optimizations.compile)
# deepspeed inferencing
elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
engine_class = LocalEngine
model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
# use base engine if requested
engines[name] = engine_class(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
hyper_config=hyper_config,
stats=stats
)
engines = Engines(engines)
engines.setup()
# this might bite me in the ass since technically this doesn't handle one engine loading fine but another engine not
if not cfg.trainer.load_state_dict:
engines.load_checkpoint(training=not inferencing)
# freeze requested params
for name, engine in engines.items():
engine.freeze(freeze_all=False)
# split models over requested devices
if cfg.optimizations.model_offloading:
engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )
return engines
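# A minimal sketch (not from the repo) of pulling the bare modules out of load_engines()
# for inference; the import paths and "./data/config.yaml" are assumptions based on the
# package layout, not guaranteed fixtures.
from image_classifier.config import cfg
from image_classifier.engines import load_engines

cfg.load_yaml("./data/config.yaml")
cfg.mode = "inferencing"

engines = load_engines(training=False)   # cached via functools.cache
for name, engine in engines.items():
	model = engine.module.eval()          # the underlying torch module
	print(name, sum(p.numel() for p in model.parameters()))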
@ -28,7 +28,9 @@ def default_feeder(engine, batch):
from ..config import cfg from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
from ..utils.distributed import init_distributed, distributed_initialized from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size, cleanup_distributed
from ..utils.io import torch_save, torch_load
from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
import logging import logging
import time import time
@ -39,40 +41,65 @@ import os
from torch import Tensor from torch import Tensor
from torch.distributed import all_reduce from torch.distributed import all_reduce
from typing import Any, Protocol from typing import Any, Protocol
from functools import cached_property
from .base import TrainFeeder from .base import TrainFeeder
from ..utils import wrapper as ml
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
if not distributed_initialized() and cfg.trainer.backend == "local": if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1:
def _nop(): init_distributed(torch.distributed.init_process_group)
...
fn = _nop if cfg.device == "cpu" else torch.distributed.init_process_group
init_distributed(fn)
# A very naive engine implementation using barebones PyTorch
# to-do: implement lr_scheduling
class Engine(): class Engine():
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype) if 'hyper_config' in kwargs:
self.hyper_config = kwargs['hyper_config']
kwargs.pop("hyper_config")
self.module = kwargs['model'].to(cfg.device).to(torch.float32 if cfg.trainer.amp else cfg.trainer.dtype)
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
self.global_steps = 0 self.global_steps = kwargs.pop("global_steps", 0)
self.micro_steps = 0 self.micro_steps = kwargs.pop("micro_steps", 0)
self.gradient_accumulation_steps = cfg.hyperparameters.gradient_accumulation_steps self.global_samples = kwargs.pop("global_samples", 0)
self.tokens_processed = kwargs.pop("tokens_processed", 0)
def freeze(self): self._frozen_params = set()
for p in self.module.parameters():
if p.requires_grad: self.max_nan_losses = 8
p.requires_grad_(False) self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
self._frozen_params.add(p)
self.current_batch_size = 0
self._global_grad_norm = None
def freeze(self, freeze_all=True):
# set to freeze
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
# freeze non-LoRA params if requested
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
return freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
for name, param in self.module.named_parameters():
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
param.requires_grad_(False)
self._frozen_params.add(param)
def unfreeze(self): def unfreeze(self):
for p in self._frozen_params: for p in self._frozen_params:
p.requires_grad_(True) p.requires_grad_(True)
self._frozen_params.clear() self._frozen_params.clear()
@property
def _training(self):
if not hasattr(self, "hyper_config"):
return True
return self.hyper_config.training
@property @property
def global_step(self): def global_step(self):
return self.global_steps return self.global_steps
@ -81,8 +108,17 @@ class Engine():
def micro_step(self): def micro_step(self):
return self.micro_steps return self.micro_steps
def train_batch_size(self): @property
return cfg.hyperparameters.batch_size def batch_size(self):
return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
@property
def gradient_accumulation_steps(self):
return cfg.hyperparameters.gradient_accumulation_steps
@property
def gradient_clipping(self):
return cfg.hyperparameters.gradient_clipping
def gather_attribute(self, *args, **kwargs): def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs) return gather_attribute(self.module, *args, **kwargs)
@ -91,42 +127,74 @@ class Engine():
return dispatch_attribute(self.module, *args, **kwargs) return dispatch_attribute(self.module, *args, **kwargs)
def save_checkpoint(self, save_dir, tag ): def save_checkpoint(self, save_dir, tag ):
save_path = save_dir / tag / "state.pth" if is_global_leader():
save_path.parent.mkdir(parents=True, exist_ok=True) module = self.module.state_dict()
torch.save({
"global_step": self.global_step,
"micro_step": self.micro_step,
"module": self.module.state_dict(),
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
}, save_path)
open(save_dir / "latest", 'w').write( tag ) # if training lora
# this is a separate path to override saving the weights
lora = None
if cfg.lora is not None:
lora, module = lora_get_state_dict( module, split = True )
save_dir = cfg.ckpt_dir / cfg.lora.full_name
save_path = save_dir / tag / f"state.{cfg.weights_format}"
save_path.parent.mkdir(parents=True, exist_ok=True)
torch_save({
"module": module,
"lora": lora,
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
"stats": {
"global_step": self.global_step,
"micro_step": self.micro_step,
"global_samples": self.global_samples,
"tokens_processed": self.tokens_processed,
}
}, save_path)
open(save_dir / "latest", 'w').write( tag )
torch.distributed.barrier()
def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
# override to load the lora instead
if cfg.lora is not None:
load_dir = cfg.ckpt_dir / cfg.lora.full_name
def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
if tag is None: if tag is None:
tag_path = load_dir / "latest" tag_path = load_dir / "latest"
if not tag_path.exists(): if not tag_path.exists():
return return
tag = open(tag_path).read() tag = open(tag_path).read()
load_path = load_dir / tag / "state.pth" load_path = load_dir / tag / f"state.{cfg.weights_format}"
if not load_path.exists(): if not load_path.exists():
return return
state = torch.load(load_path) state = torch_load(load_path, device=cfg.device)
self.global_steps = state['global_step']
self.micro_steps = state['micro_step'] self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
self.module.load_state_dict(state['module']) self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading)
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
if load_optimizer_states: if load_optimizer_states:
self.optimizer.load_state_dict(state['optimizer']) self.optimizer.load_state_dict(state['optimizer']) #, device=cfg.device)
if load_lr_scheduler_states: if load_lr_scheduler_states:
self.lr_scheduler.load_state_dict(state['lr_scheduler']) self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, device=cfg.device)
if 'lora' in state:
lora_load_state_dict( self.module, state['lora'] )
def eval(self): def eval(self):
return self.module.eval() return self.module.eval()
@ -136,46 +204,80 @@ class Engine():
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
self.module = self.module.to(*args, **kwargs) self.module = self.module.to(*args, **kwargs)
return self.module if self.optimizer:
self.optimizer = self.optimizer.to(*args, **kwargs)
return self
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs) return self.forward(*args, **kwargs)
@cached_property
def device(self):
return next(self.module.parameters()).device
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return self.module.forward(*args, **kwargs) return self.module.forward(*args, **kwargs)
def backward(self, loss): def backward(self, loss):
if self.loss_scaler is not None:
return self.loss_scaler.scale(loss / self.gradient_accumulation_steps).backward()
return (loss / self.gradient_accumulation_steps).backward() return (loss / self.gradient_accumulation_steps).backward()
def step(self): def step(self):
with torch.set_grad_enabled(self.gradient_accumulation_steps > 1): with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
self.micro_steps += 1 self.micro_steps += 1
self.global_samples += self.batch_size
if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0: if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)
self.global_steps += 1 self.global_steps += 1
self.optimizer.step() if self.loss_scaler is not None:
self.loss_scaler.step(self.optimizer)
self.loss_scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
self._get_grad_norm()
def _get_grad_norm(self):
t = [ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]
self._global_grad_norm = torch.cat(t).norm().item() if len(t) else None
def get_lr(self): def get_lr(self):
lrs = [] lrs = []
for param_group in self.optimizer.param_groups: for param_group in self.optimizer.param_groups:
if 'lr' in param_group: if 'd_coeff' in param_group:
lrs.append(param_group['d_coeff'])
elif 'lr' in param_group:
lrs.append(param_group['lr']) lrs.append(param_group['lr'])
return lrs return lrs
def set_lr(self, lr): def set_lr(self, lr):
for param_group in self.optimizer.param_groups: for param_group in self.optimizer.param_groups:
if 'lr' in param_group: if 'd_coeff' in param_group:
param_group['d_coeff'] = lr
elif 'lr' in param_group:
param_group['lr'] = lr param_group['lr'] = lr
def get_global_grad_norm(self): def get_global_grad_norm(self):
return 0.0 return self._global_grad_norm
def traverse(self, *args, **kwargs): def traverse(self, *args, **kwargs):
self.forward(*args, **kwargs) with ml.autocast():
self.forward(*args, **kwargs)
losses = self.gather_attribute("loss") losses = self.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum() loss = torch.stack([*losses.values()]).sum()
if torch.isnan(loss).any():
self.max_nan_losses = self.max_nan_losses - 1
if self.max_nan_losses < 0:
raise RuntimeError("Too many NaN losses detected.")
stats = {} stats = {}
stats |= {k: v.item() for k, v in losses.items()} stats |= {k: v.item() for k, v in losses.items()}
stats |= self.gather_attribute("scalar") stats |= self.gather_attribute("scalar")
@ -194,6 +296,8 @@ class Engines(dict[str, Engine]):
def setup(self): def setup(self):
self._global_step = 0 self._global_step = 0
self._micro_step = 0 self._micro_step = 0
self._batch_size = 0
self._global_samples = 0
@property @property
def global_step(self): def global_step(self):
@ -203,6 +307,14 @@ class Engines(dict[str, Engine]):
def micro_step(self): def micro_step(self):
return self._micro_step return self._micro_step
@property
def batch_size(self):
return self._batch_size
@property
def global_samples(self):
return self._global_samples
def gather_attribute(self, *args, **kwargs): def gather_attribute(self, *args, **kwargs):
ret = {} ret = {}
for engine in self.values(): for engine in self.values():
@ -213,6 +325,50 @@ class Engines(dict[str, Engine]):
for engine in self.values(): for engine in self.values():
engine.dispatch_attribute(*args, **kwargs) engine.dispatch_attribute(*args, **kwargs)
def export(self, userdata={}, callback=None, dtype=None, format=None):
if not format:
format = cfg.weights_format
format = format.lower()
if dtype is None:
dtype = cfg.trainer.dtype
for name, engine in self.items():
module = engine.module.state_dict()
lora = None
save_path = cfg.ckpt_dir / name / f"fp32.{format}"
config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
# safety
for k, v in module.items():
module[k] = v.to(dtype)
if cfg.lora is not None:
lora, module = lora_get_state_dict( module, split = True )
save_path = cfg.ckpt_dir / cfg.lora.full_name / f"fp32.{format}"
state_dict = {
'module': module,
'lora': lora,
"stats": {
"global_step": engine.global_step,
"micro_step": engine.micro_step,
"global_samples": engine.global_samples,
"tokens_processed": engine.tokens_processed,
},
"userdata": userdata,
"config": config.__dict__
}
if lora is None:
del state_dict['lora']
if callback:
state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
torch_save(state_dict, save_path)
_logger.info(f"Exported {name} to {save_path}")
def save_checkpoint(self, tag=None): def save_checkpoint(self, tag=None):
if not tag: if not tag:
tag = cfg.trainer.save_tag tag = cfg.trainer.save_tag
@ -222,47 +378,67 @@ class Engines(dict[str, Engine]):
cfg.ckpt_dir.mkdir(parents=True, exist_ok=True) cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
for name, engine in self.items(): for name, engine in self.items():
engine.save_checkpoint(cfg.ckpt_dir / name, tag=tag) if not engine._training:
continue
def load_checkpoint(self, tag=None): save_dir = cfg.ckpt_dir / name
try:
engine.save_checkpoint(save_dir, tag=tag)
except Exception as e:
_logger.warning(f'Failed to save checkpoint for engine {name}: {str(e)}')
# might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
checkpoints = [ d for d in list(save_dir.glob("*")) if d.is_dir() ]
checkpoints.sort(key=lambda x: x.stat().st_mtime)
checkpoints = checkpoints[:-cfg.trainer.keep_last_checkpoints]
for d in checkpoints:
if not d.is_dir() or not d.exists():
continue
_logger.info(f"Removing {d}")
for p in d.iterdir():
p.unlink()
d.rmdir()
def load_checkpoint(self, tag=None, training=True):
if not tag: if not tag:
tag = cfg.trainer.load_tag tag = cfg.trainer.load_tag
for name, engine in self.items(): for name, engine in self.items():
load_dir = cfg.ckpt_dir / name load_dir = cfg.ckpt_dir / name
engine.load_checkpoint( engine.load_checkpoint(
tag=tag, tag=tag,
load_dir=load_dir, load_dir=load_dir,
load_module_strict=cfg.trainer.strict_loading, load_module_strict=cfg.trainer.strict_loading,
load_optimizer_states=cfg.trainer.load_states, load_optimizer_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
load_lr_scheduler_states=cfg.trainer.load_states, load_lr_scheduler_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
load_module_only=cfg.trainer.load_module_only,
) )
if cfg.trainer.restart_step_count: if cfg.trainer.restart_step_count:
engine.global_steps = 0 engine.global_steps = 0
engine.micro_steps = 0
engine.global_samples = 0
engine.tokens_processed = 0
# update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler # update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
if cfg.hyperparameters.scheduler_type == "": if cfg.hyperparameters.scheduler_type == "":
self.set_lr(cfg.hyperparameters.learning_rate) self.set_lr(cfg.hyperparameters.learning_rate)
self._update_global_step() self._update()
self._update_micro_step()
def set_lr(self, lr): def set_lr(self, lr):
for engine in self.values(): for engine in self.values():
if not engine._training:
continue
engine.set_lr(lr) engine.set_lr(lr)
def _update_global_step(self): def _update(self):
for engine in self.values(): for engine in self.values():
self._global_step = max(self._global_step, engine.global_step) self._global_step = max(self._global_step, engine.global_step)
def _update_micro_step(self):
for engine in self.values():
self._micro_step = max(self._micro_step, engine.micro_step) self._micro_step = max(self._micro_step, engine.micro_step)
self._batch_size = max(self._batch_size, engine.batch_size)
def train_batch_size(self): self._global_samples = max(self._global_samples, engine.global_samples)
batch_size = 0
for engine in self.values():
batch_size = max(batch_size, engine.train_batch_size())
def eval(self): def eval(self):
for engine in self.values(): for engine in self.values():
@ -279,7 +455,10 @@ class Engines(dict[str, Engine]):
stats.update(flatten_dict({ name.split("-")[0]: stat })) stats.update(flatten_dict({ name.split("-")[0]: stat }))
return stats return stats
def step(self, batch, feeder: TrainFeeder = default_feeder, device=cfg.get_device()): def quit(self):
cleanup_distributed()
def step(self, batch, feeder: TrainFeeder = default_feeder):
total_elapsed_time = 0 total_elapsed_time = 0
stats: Any = dict() stats: Any = dict()
@ -287,10 +466,11 @@ class Engines(dict[str, Engine]):
if cfg.trainer.gc_mode == 'step': if cfg.trainer.gc_mode == 'step':
do_gc() do_gc()
batch = to_device(batch, device)
for name, engine in self.items(): for name, engine in self.items():
#torch.cuda.synchronize() if not engine._training:
continue
device = engine.device
if cfg.trainer.gc_mode == 'substep': if cfg.trainer.gc_mode == 'substep':
do_gc() do_gc()
@ -298,10 +478,9 @@ class Engines(dict[str, Engine]):
start_time = time.time() start_time = time.time()
tries = 4 tries = 4
n_ooms = torch.zeros([], device=cfg.device) n_ooms = torch.zeros([], device=device)
if cfg.trainer.aggressive_optimizations: batch = to_device(batch, device)
batch = to_device(batch, device)
if not cfg.trainer.check_for_oom: if not cfg.trainer.check_for_oom:
res = feeder( engine=engine, batch=batch ) res = feeder( engine=engine, batch=batch )
@ -311,7 +490,7 @@ class Engines(dict[str, Engine]):
res = feeder( engine=engine, batch=batch ) res = feeder( engine=engine, batch=batch )
break break
except RuntimeError as e: except RuntimeError as e:
print("Forward", str(e)) _logger.error(f"Forward: {str(e)}")
if "out of memory" not in str(e): if "out of memory" not in str(e):
self.save_checkpoint() self.save_checkpoint()
@ -329,7 +508,8 @@ class Engines(dict[str, Engine]):
do_gc() do_gc()
continue continue
all_reduce(n_ooms) if world_size() > 1:
all_reduce(n_ooms)
if n_ooms.item() > 0: if n_ooms.item() > 0:
self.save_checkpoint() self.save_checkpoint()
raise RuntimeError("Out of memory during forward pass!") raise RuntimeError("Out of memory during forward pass!")
@ -340,7 +520,7 @@ class Engines(dict[str, Engine]):
loss, engine_stats = res loss, engine_stats = res
engine_stats |= self.gather_attribute("scalar") engine_stats |= self.gather_attribute("scalar")
n_ooms = torch.zeros([], device=cfg.device) n_ooms = torch.zeros([], device=device)
if cfg.trainer.aggressive_optimizations: if cfg.trainer.aggressive_optimizations:
batch = to_device(batch, 'cpu') batch = to_device(batch, 'cpu')
@ -348,10 +528,11 @@ class Engines(dict[str, Engine]):
if not cfg.trainer.check_for_oom: if not cfg.trainer.check_for_oom:
engine.backward(loss) engine.backward(loss)
else: else:
# to-do: properly handle when one GPU throws an OOM because it just halts
try: try:
engine.backward(loss) engine.backward(loss)
except RuntimeError as e: except RuntimeError as e:
print("Backwards:", str(e)) _logger.error(f"Backwards: {str(e)}")
if "out of memory" not in str(e): if "out of memory" not in str(e):
self.save_checkpoint() self.save_checkpoint()
@ -359,10 +540,13 @@ class Engines(dict[str, Engine]):
n_ooms += 1 n_ooms += 1
all_reduce(n_ooms) if world_size() > 1:
all_reduce(n_ooms)
if n_ooms.item() > 0: if n_ooms.item() > 0:
self.save_checkpoint() self.save_checkpoint()
raise RuntimeError("Out of memory during backwards pass!")
raise RuntimeError("Out of memory during backwards pass!")
engine.step() engine.step()
@ -370,27 +554,36 @@ class Engines(dict[str, Engine]):
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
total_elapsed_time += elapsed_time total_elapsed_time += elapsed_time
grad_norm = engine.get_global_grad_norm()
loss_scale = 1
if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
loss_scale = engine.optimizer.loss_scale
if grad_norm is not None:
grad_norm /= loss_scale
stats.update( stats.update(
flatten_dict( flatten_dict(
{ {
name.split("-")[0]: dict( name.split("-")[0]: dict(
loss=loss.item(), **engine_stats,
lr=engine.get_lr()[0], lr=engine.get_lr()[0],
grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
loss_scale=loss_scale if loss_scale != 1 else None,
elapsed_time=elapsed_time, elapsed_time=elapsed_time,
engine_step=engine.global_step, engine_step=engine.global_step,
**engine_stats, samples_processed=engine.global_samples,
tokens_processed=engine.tokens_processed,
) )
} }
), ),
) )
self._update_global_step() self._update()
self._update_micro_step()
stats["batch_size"] = self.train_batch_size() # len(batch["text"]) if len(self.keys()) > 1:
stats["elapsed_time"] = total_elapsed_time stats["elapsed_time"] = total_elapsed_time
stats["wall_time"] = time.time()
stats["global_step"] = self.global_step stats["it"] = self.global_step
return stats return stats
@ -25,28 +25,71 @@ from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distr
from deepspeed.accelerator import get_accelerator from deepspeed.accelerator import get_accelerator
from ..utils.distributed import init_distributed, distributed_initialized from ..utils.distributed import init_distributed, distributed_initialized
from ..utils import wrapper as ml
from ..models.lora import freeze_non_lora_weights
if not distributed_initialized() and cfg.trainer.backend == "deepspeed": if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
init_distributed(init_deepspeed_dist) init_distributed(init_deepspeed_dist)
class Engine(DeepSpeedEngine): class Engine(DeepSpeedEngine):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs['config'] = cfg.trainer.deepspeed.get_ds_cfg(model=kwargs['model']) self.hyper_config = None
if 'hyper_config' in kwargs:
self.hyper_config = kwargs['hyper_config']
kwargs.pop("hyper_config")
kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
kwargs['config_class'] = DeepSpeedConfig(kwargs['config']) kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
stats = {
"global_step": 0,
"micro_step": 0,
"global_samples": 0,
"tokens_processed": 0,
}
# kwargs['stats'] = None will return None when popped
maybe_stats = kwargs.pop('stats', stats)
if maybe_stats is not None:
stats = maybe_stats
super().__init__(None, *args, **kwargs) super().__init__(None, *args, **kwargs)
self._frozen_params = set() self._frozen_params = set()
def freeze(self): self.global_steps = stats["global_step"]
for p in self.module.parameters(): self.micro_steps = stats["micro_step"]
if p.requires_grad: self.global_samples = stats["global_samples"]
p.requires_grad_(False) self.tokens_processed = stats["tokens_processed"]
self._frozen_params.add(p)
self.max_nan_losses = 8
self.current_batch_size = 0
def freeze(self, freeze_all=True):
# freeze non-LoRA params if requested
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
frozen_params = freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
for param in frozen_params:
self._frozen_params.add( param )
return
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
for name, param in self.module.named_parameters():
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
param.requires_grad_(False)
self._frozen_params.add(param)
def unfreeze(self): def unfreeze(self):
for p in self._frozen_params: for param in self._frozen_params:
p.requires_grad_(True) param.requires_grad_(True)
self._frozen_params.clear() self._frozen_params.clear()
@property
def _training(self):
return self.hyper_config.training
@property @property
def global_step(self): def global_step(self):
@ -54,7 +97,11 @@ class Engine(DeepSpeedEngine):
@property @property
def micro_step(self): def micro_step(self):
return self.micro_steps return self.micro_steps
@property
def batch_size(self):
return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
def gather_attribute(self, *args, **kwargs): def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs) return gather_attribute(self.module, *args, **kwargs)
@ -66,17 +113,40 @@ class Engine(DeepSpeedEngine):
try: try:
if hasattr(self.optimizer, 'param_groups'): if hasattr(self.optimizer, 'param_groups'):
for param_group in self.optimizer.param_groups: for param_group in self.optimizer.param_groups:
param_group['lr'] = lr param_group["d_coeff" if "d_coeff" in param_group else "lr"] = lr
else: else:
self.optimizer.set_lr(lr) self.optimizer.set_lr(lr)
except Exception as e: except Exception as e:
print(str(e)) _logger.warning(str(e))
# we'll just have to live with the LoRA weights living within our main weights
# they're easy to extract anyways
def load_checkpoint(self, load_dir, **kwargs ):
# override to load the lora instead
if cfg.lora is not None:
load_dir = cfg.ckpt_dir / cfg.lora.full_name
return super().load_checkpoint( load_dir, **kwargs )
def save_checkpoint(self, save_dir, **kwargs ):
# override to save the lora instead
if cfg.lora is not None:
save_dir = cfg.ckpt_dir / cfg.lora.full_name
return super().save_checkpoint( save_dir, **kwargs )
def traverse(self, *args, **kwargs): def traverse(self, *args, **kwargs):
self.forward(*args, **kwargs) with ml.autocast():
self.forward(*args, **kwargs)
losses = self.gather_attribute("loss") losses = self.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum() loss = torch.stack([*losses.values()]).sum()
if torch.isnan(loss).any():
self.max_nan_losses = self.max_nan_losses - 1
if self.max_nan_losses < 0:
raise RuntimeError("Too many NaN losses detected.")
stats = {} stats = {}
stats |= {k: v.item() for k, v in losses.items()} stats |= {k: v.item() for k, v in losses.items()}
stats |= self.gather_attribute("scalar") stats |= self.gather_attribute("scalar")
@ -1,31 +1,67 @@
import argparse import argparse
import torch import torch
import torch.nn
from .data import get_symmap from .data import get_symmap
from .train import load_engines from .engines import load_engines
from .config import cfg
from .models.lora import lora_get_state_dict
from .utils.io import torch_save, torch_load
def load_models(): # yanks a LoRA from the training checkpoint
models = {} def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
engines = load_engines() if dtype is None:
for name in engines: dtype = cfg.inference.dtype
model = engines[name].module.cpu()
models[name] = model
return models format = save_path.stem[1:]
lora = state_dict["lora"] if "lora" in state_dict else None
# should always be included, but just in case
if lora is None and "module" in state_dict:
lora, module = lora_get_state_dict( state_dict["module"], split = True )
state_dict["module"] = module
if "lora" in state_dict:
state_dict["lora"] = None
# should raise an exception since there's nothing to extract, or at least a warning
if not lora:
return state_dict
# save lora specifically
# should probably export other attributes, similar to what SD LoRAs do
save_path = save_path.parent / f"lora.{format}"
torch_save( {
"module": lora,
"config": cfg.lora.__dict__ if cfg.lora is not None else None,
}, save_path )
return state_dict
def main(): def main():
parser = argparse.ArgumentParser("Save trained model to path.") parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("path") parser.add_argument("--module-only", action='store_true')
args = parser.parse_args() parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
parser.add_argument("--format", type=str, default="pth") # set target format to export weights under
args, unknown = parser.parse_known_args()
models = load_models() if args.format.lower() not in ["sft", "safetensors", "pt", "pth"]:
for name in models: raise Exception(f"Unknown requested format: {args.format}")
model = models[name]
outpath = f'{args.path}/{name}.pt' if args.module_only:
torch.save(model, outpath) cfg.trainer.load_module_only = True
print(f"Exported {name} to {outpath}")
if args.dtype != "auto":
cfg.trainer.weight_dtype = args.dtype
# necessary to ensure we are actually exporting the weights right
cfg.inference.backend = cfg.trainer.backend
engines = load_engines(training=False) # to ignore loading optimizer state
callback = None
engines.export(userdata={"symmap": get_symmap()}, callback=callback, format=args.format)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
@ -1,53 +1,103 @@
import torch import torch
import torchaudio
import soundfile
import time
import logging
_logger = logging.getLogger(__name__)
from torch import Tensor
from einops import rearrange
from pathlib import Path
from .utils import to_device, set_seed, wrapper as ml
from PIL import Image, ImageDraw
import torchvision.transforms as transforms import torchvision.transforms as transforms
from .config import cfg from .config import cfg, Config
from .export import load_models from .models import get_models
from .data import get_symmap, _get_symbols from .engines import load_engines, deepspeed_available
from .data import get_symmap, tokenize
if deepspeed_available:
import deepspeed
class Classifier(): class Classifier():
def __init__( self, width=300, height=80, config=None, ckpt=None, device=cfg.get_device(), dtype="float32" ): def __init__( self, config=None, device=None, amp=None, dtype=None, attention=None ):
self.loading = True
# yes I can just grab **kwargs and forward them here
self.load_config( config=config, device=device, amp=amp, dtype=dtype, attention=attention )
self.load_model()
self.loading = False
def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
if config: if config:
_logger.info(f"Loading YAML: {config}")
cfg.load_yaml( config ) cfg.load_yaml( config )
self.loading = True try:
cfg.format( training=False )
cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
except Exception as e:
raise e # throw an error because I'm tired of silent errors messing things up for me
if amp is None:
amp = cfg.inference.amp
if dtype is None or dtype == "auto":
dtype = cfg.inference.weight_dtype
if device is None:
device = cfg.device
cfg.device = device
cfg.mode = "inferencing"
cfg.trainer.backend = cfg.inference.backend
cfg.trainer.weight_dtype = dtype
cfg.inference.weight_dtype = dtype
self.device = device self.device = device
self.dtype = cfg.inference.dtype
if ckpt: self.amp = amp
self.load_model_from_ckpt( ckpt )
else:
self.load_model_from_cfg( config )
self.model.eval() self.model_kwargs = {}
self.width = width def load_model( self ):
self.height = height load_engines.cache_clear()
self.engines = load_engines(training=False, **self.model_kwargs)
for name, engine in self.engines.items():
if self.dtype != torch.int8:
engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.engines.eval()
self.symmap = get_symmap()
self.width = 300
self.height = 80
self.transform = transforms.Compose([ self.transform = transforms.Compose([
transforms.Resize((self.height, self.width)), transforms.Resize((self.height, self.width)),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]) ])
_logger.info("Loaded model")
self.loading = False @torch.inference_mode()
def inference( self, image, temperature=1.0 ):
model = None
def load_model_from_ckpt( self, ckpt ): for name, engine in self.engines.items():
self.ckpt = ckpt model = engine.module
self.model = torch.load(self.ckpt).to(self.device)
def load_model_from_cfg( self, config_path ):
models = load_models()
for name in models:
model = models[name]
self.model = model.to(self.device)
break break
def inference( self, image, temperature=1.0 ): image = self.transform(image).to(self.device).to(self.dtype)
image = self.transform(image).to(self.device)
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
answer = model( image=[image], sampling_temperature=temperature )
answer = [ "".join(answer) ]
answer = self.model( image=[image], sampling_temperature=temperature )
answer = answer[0].replace('<s>', "").replace("</s>", "") # it would be better to just slice between these, but I can't be assed answer = answer[0].replace('<s>', "").replace("</s>", "") # it would be better to just slice between these, but I can't be assed
return answer
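# A minimal usage sketch of the Classifier wrapper above; the module path, image path and
# YAML path are placeholders rather than files shipped with the repo.
from PIL import Image
from image_classifier.inference import Classifier   # assumed module path

classifier = Classifier( config="./data/config.yaml" )
image = Image.open("./data/path-to-your-image.png").convert("RGB")
answer = classifier.inference( image, temperature=1.0 )
print( answer )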
9
image_classifier/models/__init__.py Executable file → Normal file
@ -1,18 +1,19 @@
from .base import Model
def get_model(cfg, training=False):
name = cfg.name
model = Model(
n_tokens=cfg.tokens,
n_len=cfg.len,
d_model=cfg.dim,
d_resnet=cfg.resnet,
)
model.config = cfg
print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
return model
def get_models(models, training=False):
return { model.full_name: get_model(model, training=training) for model in models }
@ -12,7 +12,7 @@ from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
from torchvision.models import resnet18 from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
from ..data import get_symmap from ..data import get_symmap
@ -20,12 +20,12 @@ class Model(nn.Module):
def __init__( def __init__(
self, self,
n_tokens: int = 0, # number of token types n_tokens: int = 0, # number of token types
n_len: int = 6, # how long a sequence can be n_len: int = 12, # how long a sequence can be
d_model: int = 512, d_model: int = 512,
d_resnet: int = 18,
): ):
super().__init__() super().__init__()
_symmap = get_symmap() _symmap = get_symmap()
self.symmap = { f'{v}': k for k, v in _symmap.items() } self.symmap = { f'{v}': k for k, v in _symmap.items() }
self.symmap['0'] = "" self.symmap['0'] = ""
@ -36,8 +36,26 @@ class Model(nn.Module):
self.n_tokens = n_tokens self.n_tokens = n_tokens
self.n_len = n_len + 2 # start/stop tokens self.n_len = n_len + 2 # start/stop tokens
self.d_model = d_model self.d_model = d_model
self.d_resnet = d_resnet
self.resnet = resnet18(pretrained=False) ResNet = resnet18
if d_resnet == 18:
print("Using resnet18")
ResNet = resnet18
elif d_resnet == 34:
print("Using resnet34")
ResNet = resnet34
elif d_resnet == 50:
print("Using resnet50")
ResNet = resnet50
elif d_resnet == 101:
print("Using resnet101")
ResNet = resnet101
elif d_resnet == 152:
print("Using resnet152")
ResNet = resnet152
self.resnet = ResNet(pretrained=False)
self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len ) self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len )
self.accuracy_metric = MulticlassAccuracy( self.accuracy_metric = MulticlassAccuracy(
@ -61,33 +79,29 @@ class Model(nn.Module):
sampling_temperature: float = 1.0, sampling_temperature: float = 1.0,
): ):
x_list = torch.stack( image, dim=0 ) logits = self.resnet( torch.stack( image, dim=0 ) )
logits = logits.view(logits.size(0), self.n_len, self.n_tokens).permute(1, 0, 2)
x = self.resnet( x_list )
y = x.view(x.size(0), self.n_len, self.n_tokens)
# either of these should do, but my VALL-E forward pass uses this, so might as well keep to it pred = logits.argmax(dim=2)
# pred = y.argmax(dim=2)
pred = Categorical(logits=y / sampling_temperature).sample()
answer = [ "".join([ self.symmap[f'{x.item()}'] for x in t ]) for t in pred ]
if text is not None: if text is not None:
y_list = rearrange(pad_sequence(text), "t b -> b t") labels = rearrange(pad_sequence(text), "t b -> b t").permute(1, 0)
loss = []
loss = 0
for i in range(self.n_len): for i in range(self.n_len):
if i >= y_list.shape[1]: if i >= labels.shape[0]:
break break
loss += F.cross_entropy( y[:, i], y_list[:, i] ) loss.append( F.cross_entropy(logits[i], labels[i]) )
self.loss = dict( self.loss = dict(
nll=loss nll = sum( loss ) / len( loss ),
) )
self.stats = dict( self.stats = dict(
acc = self.accuracy_metric( pred, y_list ), acc = self.accuracy_metric( pred, labels ),
precision = self.precision_metric( pred, y_list ), precision = self.precision_metric( pred, labels ),
) )
answer = [ "".join([ self.symmap[f'{x.item()}'] for x in t ]) for t in pred ]
return answer
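# A small shape walkthrough of the updated forward pass above, with an illustrative batch size:
# resnet.fc emits (B, n_len * n_tokens); the view + permute yields (n_len, B, n_tokens), and
# argmax(dim=2) leaves one token id per sequence slot, i.e. (n_len, B). With a batch of one
# image the per-slot strings are re-joined by the inference wrapper's "".join(answer).
import torch

B, n_len, n_tokens = 4, 14, 37   # illustrative sizes only
fc_out = torch.randn(B, n_len * n_tokens)
logits = fc_out.view(B, n_len, n_tokens).permute(1, 0, 2)
pred = logits.argmax(dim=2)
assert logits.shape == (n_len, B, n_tokens)
assert pred.shape == (n_len, B)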
@ -0,0 +1,214 @@
# Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
from functools import partial
import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from transformers.pytorch_utils import Conv1D
from torch import Tensor, nn
import math
from typing import Optional, List
from ..utils import passes_policy
# LoRA Linear for replacement
# Pros: simple, just needs to reuse the replace_linear and copy weights
# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
class LoRALinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
rank: int = 4,
alpha: int = 1,
dropout: float = 0.1,
merge_weights: bool = False,
**kwargs,
):
super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs)
self.rank = rank
self.alpha = alpha
self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
self.merge_weights = merge_weights
self.merged = False
self.enabled = True
self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
self.scaling = self.alpha / self.rank
self.weight.requires_grad = False
self.reset_parameters()
def reset_parameters(self):
super().reset_parameters()
# super silly but necessary because nn.Linear's constructor calls this
if hasattr(self, 'lora_A'):
nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
nn.init.zeros_( self.lora_B )
def train(self, mode: bool = True):
super().train(mode)
# training, separate lora from base weights
if mode and self.merge_weights and self.merged:
self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
self.merged = False
# not training, merge lora to base weights
if not mode and self.merge_weights and not self.merged:
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
self.merged = True
def forward(self, x: torch.Tensor):
if not self.merged and self.enabled:
result = F.linear(x, self.weight, bias=self.bias)
result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
return result
return F.linear(x, self.weight, bias=self.bias)
@classmethod
def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
if device is None:
device = layer.weight.device
if dtype is None:
dtype = layer.weight.dtype
return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
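# A minimal sketch of swapping a plain Linear for the LoRALinear above; from_linear() builds
# a freshly-initialized layer, so the pretrained weights are copied over by hand here, and
# the layer sizes are illustrative only.
import torch
from torch import nn

base = nn.Linear(512, 512)
lora = LoRALinear.from_linear( base, rank=4, alpha=1 )
lora.weight.data.copy_( base.weight.data )
if base.bias is not None:
	lora.bias.data.copy_( base.bias.data )

trainable = [ name for name, p in lora.named_parameters() if p.requires_grad ]
# expected: ['bias', 'lora_B', 'lora_A'] — the frozen base weight is excluded
y = lora( torch.randn(2, 512) )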
# Uses parametrization to inject LoRA weights
# Pros: should work with any Linears
# Cons: TBD
class ParameterizedLoRA(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
rank: int = 4,
alpha: int = 1,
dropout: float = 0.1,
device = None,
dtype = None
):
super().__init__()
self.rank = rank
self.alpha = alpha
self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
self.lora_B = nn.Parameter( torch.zeros( (out_features, rank) ) ).to( device=device, dtype=dtype )
self.lora_A = nn.Parameter( torch.zeros( (rank, in_features) ) ).to( device=device, dtype=dtype )
self.scaling = self.alpha / self.rank
self.enabled = True
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
nn.init.zeros_( self.lora_B )
def forward(self, x: torch.Tensor):
if self.enabled:
return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
return x
@classmethod
def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
if device is None:
device = layer.weight.device
if dtype is None:
dtype = layer.weight.dtype
# swap because we're feeding the output as our input
# M$'s LoRA class arranges things to where this isn't necessary
return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
@classmethod
def from_conv1d( cls, layer, device = None, dtype = None, **kwargs ):
if device is None:
device = layer.weight.device
if dtype is None:
dtype = layer.weight.dtype
in_channels, out_channels = layer.weight.shape
# swap because we're feeding the output as our input
# M$'s LoRA class arranges things to where this isn't necessary
return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [ k.split('.') for k, m in model.named_modules() if passes_policy( policy, k ) ]
for *parent, k in modules:
name = '.'.join(parent)
layer = getattr( model.get_submodule(name), k )
if isinstance( layer, nn.Linear ):
target = nn.Linear
klass = ParameterizedLoRA if use_parametrize else LoRALinear
replacer = klass.from_linear
elif isinstance( layer, nn.Conv1d ):
target = nn.Conv1d
klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
replacer = klass.from_conv1d
elif isinstance( layer, Conv1D ):
target = Conv1D
klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
replacer = klass.from_conv1d
else:
continue
replacement = replacer( layer, device=device, dtype=dtype, **kwargs )
if use_parametrize:
parametrize.register_parametrization( layer, "weight", replacement )
else:
setattr( model.get_submodule(name), k, replacement )
return enable_lora( model )
def enable_lora( model, mode = True ):
for name, module in model.named_modules():
if not isinstance( module, ParameterizedLoRA ) and not isinstance( module, LoRALinear ):
continue
module.enabled = mode
return model
def disable_lora( model ):
return enable_lora( model, False )
def freeze_non_lora_weights( model, embeddings = False ):
frozen_params = []
for name, param in model.named_parameters():
should = 'lora_' in name or (embeddings and "_emb" in name)
param.requires_grad_(should)
if not should:
frozen_params.append( param )
return frozen_params
def lora_get_state_dict( state_dict, split = True ):
lora = { name: param for name, param in state_dict.items() if "lora_" in name }
if not split:
return lora
return lora, { name: param for name, param in state_dict.items() if "lora_" not in name }
def lora_load_state_dict( model, state_dict ):
return model.load_state_dict( state_dict, strict = False )
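# A rough sketch of the helpers above on a tiny hand-built model; a real model would go
# through apply_lora() with its configured policy instead of instantiating LoRALinear
# directly, and the layer sizes here are illustrative only.
import torch
from torch import nn

model = nn.Sequential(
	LoRALinear( 64, 64, rank=4, alpha=1 ),
	nn.ReLU(),
	nn.Linear( 64, 8 ),
)
frozen = freeze_non_lora_weights( model )   # freezes everything except the lora_A / lora_B factors

lora_sd, base_sd = lora_get_state_dict( model.state_dict(), split=True )
# lora_sd holds only the "lora_" tensors; lora_load_state_dict( model, lora_sd ) would push
# them back onto a freshly LoRA-injected copy of the same model (it loads with strict=False).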
120
image_classifier/plot.py Normal file
@ -0,0 +1,120 @@
#!/usr/bin/env python3
import argparse
import json
import re
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
from .config import cfg
def plot(paths, args):
dfs = []
for path in paths:
with open(path, "r") as f:
text = f.read()
rows = []
pattern = r"(\{.+?\})\.\n"
for row in re.findall(pattern, text, re.DOTALL):
try:
row = json.loads(row)
except Exception as e:
continue
for model in args.models:
if f'{model.name}.{args.xs}' not in row:
continue
rows.append(row)
break
df = pd.DataFrame(rows)
if "name" in df:
df["name"] = df["name"].fillna("train")
else:
df["name"] = "train"
df["group"] = str(path.parents[args.group_level])
df["group"] = df["group"] + "/" + df["name"]
dfs.append(df)
df = pd.concat(dfs)
if args.min_x is not None:
for model in args.models:
df = df[args.min_x < df[f'{model.name}.{args.xs}']]
if args.max_x is not None:
for model in args.models:
df = df[df[f'{model.name}.{args.xs}'] < args.max_x]
for gtag, gdf in sorted(
df.groupby("group"),
key=lambda p: (p[0].split("/")[-1], p[0]),
):
for model in args.models:
x = f'{model.name}.{args.xs}'
for ys in args.ys:
y = f'{model.name}.{ys}'
if gdf[y].isna().all():
continue
if args.min_y is not None:
gdf = gdf[args.min_y < gdf[y]]
if args.max_y is not None:
gdf = gdf[gdf[y] < args.max_y]
gdf[y] = gdf[y].ewm(10).mean()
gdf.plot(
x=x,
y=y,
label=f"{y}",
ax=plt.gca(),
marker="x" if len(gdf) < 100 else None,
alpha=0.7,
)
plt.gca().legend(
loc="center left",
fancybox=True,
shadow=True,
bbox_to_anchor=(1.04, 0.5),
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--xs", default="engine_step")
parser.add_argument("--ys", nargs="+", default="")
parser.add_argument("--model", nargs="+", default="*")
parser.add_argument("--min-x", type=float, default=-float("inf"))
parser.add_argument("--min-y", type=float, default=-float("inf"))
parser.add_argument("--max-x", type=float, default=float("inf"))
parser.add_argument("--max-y", type=float, default=float("inf"))
parser.add_argument("--filename", default="log.txt")
parser.add_argument("--group-level", default=1)
args, unknown = parser.parse_known_args()
path = cfg.rel_path / "logs"
paths = path.rglob(f"./*/{args.filename}")
args.models = [ model for model in cfg.model.get() if model.training and (args.model == "*" or model.name in args.model) ]
if args.ys == "":
args.ys = ["loss"]
plot(paths, args)
out_path = cfg.rel_path / "metrics.png"
plt.savefig(out_path, bbox_inches="tight")
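For reference, invoking this is expected to look something like `python3 -m image_classifier.plot --yaml='./data/config.yaml' --ys loss` (the `--yaml` flag is handled by `.config`, so its exact spelling is assumed here); it scans `<rel_path>/logs/` for the trainer's `log.txt` files, pulls out the JSON metric lines, and writes `metrics.png` under the configured `rel_path`.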

View File

@ -0,0 +1,204 @@
import math
import torch
import torch.nn.functional as F
import numpy as np
from torch import Tensor, einsum, nn
# Simple filter to modify a token's probability if it shows up in the past
# `one_time` will only apply the penalty once per unique token (at its most recent occurrence)
# `decay` is the exponent applied to a token's distance when scaling the penalty
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ):
if factor == 1.0 or previous is None:
return logits
unique = set()
priors = reversed(previous)
for distance, token in enumerate(priors):
# skip if we're only applying the decay once
if one_time and token in unique:
continue
distance += 1
logits[:, token] /= factor * (distance ** decay)
# add to set if we care about it
if one_time:
unique.add(token)
return logits
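As a hedged usage sketch (the samplers module's filename isn't shown in this hunk, so the import path below is a guess), a sampling loop would thread its emitted history through the filter before drawing the next token:

import torch
from image_classifier.samplers import reptition_penalize  # assumed module path

vocab_size = 32
logits = torch.randn(1, vocab_size)   # [batch, vocab], as indexed above
history = [3, 7, 7, 12]               # previously emitted token ids, oldest first
logits = reptition_penalize(logits, previous=history, factor=1.2, decay=0.5)
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)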
# Simple "filter" that modifies the logit for the stop token, based on the sequence length
# `length` is the length of the sequence currently
# `factor` is the power the length is raised to, so values > 0 will yield longer sequences, values < 0 will yield shorter sequences
# `token` is the stop token.
def length_penalize( logits, length, factor=0.0, token=-1 ):
if factor == 0.0:
return logits
logits[:, token] /= (length ** factor)
return logits
# Simple way to ban tokens
def ban_tokens( logits, tokens ):
for token in tokens:
# skip tokens that fall outside the vocabulary
if token >= logits.shape[-1]:
continue
logits[:, token] = -float("inf")
return logits
# Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens per batch example in the output
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens > 1:
# Keep at least min_tokens (set to min_tokens-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
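A short sketch of the usual filter-then-sample flow (again assuming the module path); note the function masks `logits` in place, hence the `clone()`:

import torch
import torch.nn.functional as F
from image_classifier.samplers import top_k_top_p_filtering  # assumed module path

logits = torch.randn(1, 1000)   # [batch, vocab]
filtered = top_k_top_p_filtering(logits.clone(), top_k=50, top_p=0.95)
probs = F.softmax(filtered, dim=-1)   # masked entries are -inf, so they get probability 0
next_token = torch.multinomial(probs, num_samples=1)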
# credit to https://github.com/LostRuins/koboldcpp/pull/464 // https://github.com/kalomaze/koboldcpp/tree/dynamic-temp
def dynamic_temperature( logits, temperature=1.0, min_temperature = 0.0, k = 10, sigmoidCenterPoint = 0.5 ):
# loop over logits[:], as the NAR will have logits.shape[0] > 1
for i in range(logits.shape[0]):
sum_exp = 0.0
maximum = torch.max( logits[i] )
for logit in logits[i]:
sum_exp += math.exp( logit - maximum )
prob_max_token_before_temp = 1.0 / sum_exp
dynamic_temperature = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (prob_max_token_before_temp - sigmoidCenterPoint)))
logits[i] /= dynamic_temperature
return logits
# picks the top K tokens amongst a batch of logits
# logits: [Tensor] list of logits
# candidates: [(batch, token)] list, where batch indicates the index of the logits the given token is from
def top_k_logits_list( logits_list, k ):
# ( batch, tokens ) => ( batch x tokens )
logits = torch.cat( logits_list )
candidates = list(torch.topk(logits.flatten(), k).indices.tolist()) # perform top-k across all logits
for i, index in enumerate(candidates):
t = []
N = np.prod(logits.size())
for n in logits.size():
N //= n
t.append(index // N)
index %= N
candidates[i] = tuple(t)
return candidates
# Credit to: https://github.com/basusourya/mirostat/
# performs mirostat-based sampling
# logits: Tensor of logits; only the final row is sampled from
# state: the mirostat state
def mirostat_sample( logits, state = None ):
def compute_k(prob, n, tau):
num = 0
den = 0
for i in range(100):
b = prob[i]/prob[i+1]
t = (i+2)/(i+1)
num += math.log(b)*math.log(t)
den += math.log(t)**2
s = num/den
eps = s-1
k = ((eps*(2**(tau)))/(1-n**(-eps)))**(1/s)
k = round(k)
return k
if "max_surprise" not in state:
state["max_surprise"] = state["tau"] * 2
if "error_surprise" not in state:
state["error_surprise"] = 0
if "running_total_surprise" not in state:
state["running_total_surprise"] = 0
sorted_logits, sorted_indices = torch.sort( logits[-1, :], descending=True )
prob_original = torch.softmax( sorted_logits, dim=-1 ).tolist()
k = compute_k(prob_original, state["n"], state["max_surprise"]) + 1
sorted_logits = sorted_logits[0:k]
sorted_indices = sorted_indices[0:k]
prob_topk = torch.softmax(sorted_logits, dim = 0)
prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)
state["index_surprise"] = math.log2(1/prob_original[prev_i])
state["running_total_surprise"] += state["index_surprise"]
state["error_surprise"] = state["index_surprise"] - state["tau"]
state["max_surprise"] -= state["eta"] * state["error_surprise"]
state["token"] = sorted_indices[prev_i]
return state
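The caller owns the state dict; a minimal sketch of driving it across decoding steps (the `tau`/`eta` values here are just illustrative, and the module path is assumed):

import torch
from image_classifier.samplers import mirostat_sample  # assumed module path

vocab_size = 1024   # compute_k() peeks 100 entries deep, so the vocab must be larger than that
state = {"n": vocab_size, "tau": 5.0, "eta": 0.1}
tokens = []
for _ in range(8):
	logits = torch.randn(1, vocab_size)   # stand-in for the model's output
	state = mirostat_sample(logits, state=state)
	tokens.append(state["token"].item())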
# Credits to: https://github.com/oobabooga/text-generation-webui/pull/5677
# performs DRY sampling
# * (honestly it looks close to rep pen anyways but what do I know)
# `logits` are the scores used to sample against
# `previous` are the prior tokens to penalize with
# `factor` is the scalar multiplier
# `base` is the base number to raise to the (length - allowed_length)th power
# `allowed_length` is how long a repeated sequence may get before the penalty kicks in
def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2 ):
if factor == 0.0 or previous is None:
return logits
lengths = {}
for i, token in enumerate( previous ):
length = 1
while length < max(allowed_length, 50):
j = i - length
# Start of input reached.
if j < 0:
break
# Start of match reached.
if previous[j] != previous[-length-1]:
break
length += 1
lengths[token] = max(length, lengths[token]) if token in lengths else length
for token, length in lengths.items():
# skip matches shorter than the allowed length instead of aborting the whole pass
if length < allowed_length:
continue
logits[:, token] -= factor * base ** (length - allowed_length)
return logits
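And a quick sketch of DRY over a running history (module path assumed as above):

import torch
from image_classifier.samplers import dry_sampling  # assumed module path

logits = torch.randn(1, 256)
history = [5, 9, 2, 5, 9, 2, 5, 9]   # a repeating trigram the penalty should latch onto
logits = dry_sampling(logits, previous=history, factor=0.8, base=1.75, allowed_length=2)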

View File

@ -4,7 +4,7 @@ from .config import cfg
from .data import create_train_val_dataloader from .data import create_train_val_dataloader
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
from .utils.trainer import load_engines from .utils.distributed import is_global_leader
import json import json
import logging import logging
@ -12,53 +12,43 @@ import random
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import traceback import traceback
import shutil
from collections import defaultdict from collections import defaultdict
from PIL import Image
from tqdm import tqdm from tqdm import tqdm
import argparse
from PIL import Image, ImageDraw
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
def train_feeder(engine, batch): def train_feeder(engine, batch):
engine( image=batch["image"], text=batch["text"] ) with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
batch_size = len(batch["text"])
engine.current_batch_size = batch_size
losses = engine.gather_attribute("loss") engine( image=batch["image"], text=batch["text"] )
stat = engine.gather_attribute("stats")
loss = torch.stack([*losses.values()]).sum() losses = engine.gather_attribute("loss")
stat = engine.gather_attribute("stats")
loss = torch.stack([*losses.values()]).sum()
stats = {} stats = {}
stats |= {k: v.item() for k, v in losses.items()} stats |= {k: v.item() for k, v in losses.items()}
stats |= {k: v.item() for k, v in stat.items()} stats |= {k: v.item() for k, v in stat.items()}
engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
return loss, stats return loss, stats
@torch.inference_mode() @torch.inference_mode()
def run_eval(engines, eval_name, dl): def run_eval(engines, eval_name, dl):
engines_stats = {
'eval': eval_name
}
model = None
names = []
for name, engine in engines.items():
names.append(name)
model = engine
break
stats = defaultdict(list) stats = defaultdict(list)
stats['loss'] = [] stats['loss'] = []
def process( name, batch, resps_list ): def process( name, batch, res, loss ):
for path, ref, hyp in zip(batch["path"], batch["text"], hyp):
continue
for batch in tqdm(dl):
batch: dict = to_device(batch, cfg.device)
res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
for path, ref, hyp in zip(batch["path"], batch["text"], res): for path, ref, hyp in zip(batch["path"], batch["text"], res):
hyp = hyp.replace('<s>', "").replace("</s>", "") hyp = hyp.replace('<s>', "").replace("</s>", "")
hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / hyp).with_suffix(".png") hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / hyp).with_suffix(".png")
@ -67,36 +57,74 @@ def run_eval(engines, eval_name, dl):
image = Image.open(path).convert('RGB') image = Image.open(path).convert('RGB')
image.save(hyp_path) image.save(hyp_path)
losses = engine.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum().item()
stats['loss'].append(loss) stats['loss'].append(loss)
processed = 0
while processed < cfg.evaluation.size:
batch = to_device(next(iter(dl)), cfg.device)
# limit to eval batch size in the event we somehow have a weird dataloader
for key in batch.keys():
batch[key] = batch[key][:cfg.evaluation.batch_size]
processed += len(batch["text"])
for name in engines:
engine = engines[name]
res = engine( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
losses = engine.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum().item()
process( name, batch, res, loss )
stats = {k: sum(v) / len(v) for k, v in stats.items()} stats = {k: sum(v) / len(v) for k, v in stats.items()}
engines_stats.update(flatten_dict({ name: stats })) engines_stats = {
f'{name}.{eval_name}': stats,
iteration = engines.global_step "it": engines.global_step,
engines_stats['it'] = iteration }
engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl) #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.") _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
def main(): def train():
parser = argparse.ArgumentParser("ResNet Image Classifier")
parser.add_argument("--eval", action="store_true", default=None)
args, unknown = parser.parse_known_args()
# create log folder
setup_logging(cfg.log_dir) setup_logging(cfg.log_dir)
# copy config yaml to backup
if cfg.yaml_path is not None and is_global_leader():
shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
train_dl, subtrain_dl, val_dl = create_train_val_dataloader() train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
def eval_fn(engines): def eval_fn(engines):
do_gc()
engines.eval()
# wrapped in a try block because it's sometimes prone to breaking
try: try:
run_eval(engines, "subtrain", subtrain_dl) run_eval(engines, "subtrain", subtrain_dl)
run_eval(engines, "val", val_dl) run_eval(engines, "val", val_dl)
except Exception as e: except Exception as e:
print("Error occurred while performing eval:", str(e)) _logger.warning(f"Error occurred while performing eval: {str(e)}")
print(traceback.format_exc()) _logger.warning(traceback.format_exc())
engines.train()
do_gc() do_gc()
if args.eval:
return eval_fn(engines=trainer.load_engines())
"""
if cfg.trainer.load_webui:
from .webui import start
start(lock=False)
"""
trainer.train( trainer.train(
train_dl=train_dl, train_dl=train_dl,
train_feeder=train_feeder, train_feeder=train_feeder,
@ -104,4 +132,5 @@ def main():
) )
if __name__ == "__main__": if __name__ == "__main__":
main() # to-do: for DDP, spawn multiprocess instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"`
train()

View File

@ -7,4 +7,7 @@ from .utils import (
to_device, to_device,
tree_map, tree_map,
do_gc, do_gc,
set_seed,
passes_policy,
get_devices
) )

View File

@ -8,6 +8,10 @@ import socket
from functools import cache, wraps from functools import cache, wraps
from typing import Callable from typing import Callable
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def get_free_port(): def get_free_port():
sock = socket.socket() sock = socket.socket()
sock.bind(("", 0)) sock.bind(("", 0))
@ -15,13 +19,18 @@ def get_free_port():
_distributed_initialized = False _distributed_initialized = False
def init_distributed( fn ): def init_distributed( fn, *args, **kwargs ):
fn() torch.cuda.set_device(local_rank())
fn(*args, **kwargs)
_distributed_initialized = True _distributed_initialized = True
def distributed_initialized(): def distributed_initialized():
return _distributed_initialized return _distributed_initialized
def cleanup_distributed():
dist.barrier()
dist.destroy_process_group()
@cache @cache
def fix_unset_envs(): def fix_unset_envs():
envs = dict( envs = dict(
@ -44,10 +53,12 @@ def fix_unset_envs():
def local_rank(): def local_rank():
return int(os.getenv("LOCAL_RANK", 0)) return int(os.getenv("LOCAL_RANK", 0))
def global_rank(): def global_rank():
return int(os.getenv("RANK", 0)) return int(os.getenv("RANK", 0))
def world_size():
return int(os.getenv("WORLD_SIZE", 1))
def is_local_leader(): def is_local_leader():
return local_rank() == 0 return local_rank() == 0
@ -86,4 +97,7 @@ def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable:
if fn is None: if fn is None:
return wrapper return wrapper
return wrapper(fn) return wrapper(fn)
def ddp_model(model):
return DDP(model.to(device='cuda'), [local_rank()], find_unused_parameters=True)

View File

@ -0,0 +1,88 @@
import torch
import json
from pathlib import Path
from safetensors import safe_open as sft_load
from safetensors.torch import save_file as sft_save
def coerce_path( path ):
return path if isinstance( path, Path ) else Path(path)
def pick_path( path, *suffixes ):
suffixes = [*suffixes]
for suffix in suffixes:
p = path.with_suffix( suffix )
if p.exists():
return p
return path
def is_dict_of( d, t ):
if not isinstance( d, dict ):
return False
return all([ isinstance(v, torch.Tensor) for v in d.values() ])
# handles converting the usual pth state_dict into just the dict with the tensors + a dict of JSON strings, for safetensors
def state_dict_to_tensor_metadata( data: dict, module_key=None ):
metadata = None
# is a state_dict, no need to coerce
if is_dict_of( data, torch.Tensor ):
return data, metadata
# is maybe a dict with a state dict + metadata, coerce it
metadata = {}
target = module_key
if not target:
for k, v in data.items():
# is a dict of tensors, our target
if is_dict_of( v, torch.Tensor ):
target = k
continue # continue to iterate to grab other metadata
# not a dict of tensors, put it as metadata
try:
metadata[k] = json.dumps(v)
except Exception as e:
pass
if not target:
raise Exception('Requesting to save safetensors of a state dict, but the state dict contains no dict of torch.Tensor values to save')
return data[target], metadata
def torch_save( data, path, module_key=None ):
path = coerce_path(path)
ext = path.suffix
if ext in [".safetensor", ".sft"]:
data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key )
return sft_save( data, path, metadata )
return torch.save( data, path )
def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=True, module_key="module" ):
path = coerce_path(path)
ext = path.suffix
if ext in [".safetensor", ".sft"]:
state_dict = {}
with sft_load(path, framework=framework, device=device) as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if load_metadata:
metadata = f.metadata()
for k, v in metadata.items():
try:
metadata[k] = json.loads( v )
except Exception as e:
pass
state_dict = { module_key: state_dict } | metadata
return state_dict
return torch.load( path, map_location=torch.device(device), weights_only=not unsafe )
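A minimal round-trip sketch, assuming these helpers live somewhere like `image_classifier.utils.io` (the filename isn't shown in this hunk); non-tensor entries ride along as JSON metadata:

import torch
from image_classifier.utils.io import torch_save, torch_load  # assumed module path

ckpt = {
	"module": {"weight": torch.randn(4, 4)},   # the tensor payload safetensors stores
	"symmap": {"a": 0, "b": 1},                # anything else is serialized into metadata
}
torch_save(ckpt, "model.sft")
restored = torch_load("model.sft")   # {"module": {...tensors...}, "symmap": {"a": 0, "b": 1}}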

212
image_classifier/utils/sampler.py Executable file → Normal file
View File

@ -1,48 +1,164 @@
""" from dataclasses import dataclass
A sampler that balances data by key_fns. from typing import Any
import random
MIT License
import torch
Copyright (c) 2023 Zhe Niu from torch.utils.data import Sampler
niuzhe.nz@outlook.com from .distributed import global_rank, local_rank, world_size
"""
# Randomly picks an index from an array of indices
import random class PoolSampler():
def __init__( self, pool = [], keep_all = False, shuffle = False ):
self.length = len(pool)
class Sampler: self.shuffle = shuffle
def __init__(self, l, key_fns): self.global_pool = pool if keep_all else None
self.tree = self._build(l, key_fns) self.global_indices = [ i for i in range(self.length) ]
self.reset()
def _build(self, l, key_fns) -> dict[dict, list]:
if not key_fns: def reset(self):
return l self.current_pool = [ i for i in self.global_indices ]
if self.shuffle:
tree = {} random.shuffle(self.current_pool)
key_fn, *key_fns = key_fns def sample(self, pool = None):
if pool is None:
for x in l: pool = self.global_pool
k = key_fn(x) # check if we need to reset
index = random.choice( self.current_pool )
if k in tree: # remove from pool
tree[k].append(x) self.current_pool.remove(index)
else: # reset if needed
tree[k] = [x] if len(self.current_pool) == 0:
self.reset()
for k in tree: # map indices to our real values
tree[k] = self._build(tree[k], key_fns) return pool[index] if pool is not None else index
return tree def __len__(self):
return self.length # len(self.current_pool)
def _sample(self, tree: dict | list):
if isinstance(tree, list): def __iter__(self):
ret = random.choice(tree) while len(self.current_pool) > 0:
else: yield self.sample()
key = random.choice([*tree.keys()])
ret = self._sample(tree[key]) def __call__(self, *args, **kwargs):
return ret return self.sample(*args, **kwargs)
def sample(self): def get_state(self):
return self._sample(self.tree) return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
def set_state(self, state):
self.length = state["length"]
self.global_pool = state["global_pool"]
self.global_indices = state["global_indices"]
self.current_pool = state["current_pool"]
# "Samples" through a fixed sequence from 0 to length
# Necessary for our "shuffle+sort by duration+interleave" sampling method
# Allows saving and loading state
class OrderedSampler(Sampler):
def __init__( self, length ):
self.position = 0
self.length = length
def __len__(self):
return self.length
def __iter__(self):
if self.position >= self.length:
self.position = 0
while self.position < self.length:
yield self.position
self.position += 1
def get_state(self):
return { "position": self.position, "length": self.length }
def set_state(self, state):
self.position = state["position"]
self.length = state["length"]
# Like the above, but will batch based on token count
class BatchedOrderedSampler(Sampler):
def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
self.position = 0
self.batches = []
self.shuffle = shuffle
assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
current_batch = []
current_size = 0
current_index = 0
for key, bucket in buckets.items():
for path, duration in bucket:
# flush
should_flush = False
if max_duration > 0 and current_size + duration > max_duration:
should_flush = True
elif max_batch_size > 0 and len(current_batch) >= max_batch_size:
should_flush = True
if should_flush and len(current_batch) > 0:
self.batches.append( current_batch )
current_batch = []
current_size = 0
current_batch.append( current_index )
current_index += 1
current_size += duration
if self.shuffle:
random.shuffle(self.batches)
def __len__(self):
return len(self.batches)
def __iter__(self):
if self.position >= len(self.batches):
self.position = 0
if self.shuffle:
random.shuffle(self.batches)
while self.position < len(self.batches):
yield self.batches[self.position]
self.position += 1
def get_state(self):
return { "position": self.position, "batches": self.batches }
def set_state(self, state):
self.position = state["position"]
self.batches = state["batches"]
# Randomly samples indices from a given sequence from 0 to length
# Allows saving and loading state
class RandomSampler(Sampler):
def __init__( self, length ):
self.position = 0
self.length = length
self.generator = torch.Generator()
self.perm = torch.randperm(self.length, generator=self.generator)
def __len__(self):
return self.length
def __iter__(self):
if self.position >= self.length:
self.position = 0
self.perm = torch.randperm(self.length, generator=self.generator)
while self.position < self.length:
yield self.perm[self.position]
self.position += 1
def get_state(self):
return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }
def set_state(self, state):
self.position = state["position"]
self.length = state["length"]
self.perm = state["perm"]
self.generator.set_state(state["generator"])
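For reference, a sketch of how one of these resumable samplers plugs into a `DataLoader` and survives a checkpoint (the file path above maps to `image_classifier.utils.sampler`):

import torch
from torch.utils.data import DataLoader, TensorDataset
from image_classifier.utils.sampler import RandomSampler

dataset = TensorDataset(torch.arange(10).float())
sampler = RandomSampler(len(dataset))
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for batch in loader:
	pass   # training step would go here

state = sampler.get_state()   # {"position", "length", "perm", "generator"}
sampler.set_state(state)      # restore the mid-epoch position after loading a checkpoint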

View File

@ -4,12 +4,13 @@
import humanize import humanize
import json import json
import os
import logging import logging
import numpy as np
import random import random
import selectors import selectors
import sys import sys
import torch import torch
import os
from functools import cache from functools import cache
from torch.distributed import broadcast_object_list from torch.distributed import broadcast_object_list
@ -18,9 +19,10 @@ from tqdm import tqdm
from typing import Protocol from typing import Protocol
from ..config import cfg from ..config import cfg
from .distributed import init_distributed, distributed_initialized
from .distributed import ( from .distributed import (
fix_unset_envs, init_distributed,
distributed_initialized,
world_size,
global_leader_only, global_leader_only,
global_rank, global_rank,
is_global_leader, is_global_leader,
@ -28,73 +30,15 @@ from .distributed import (
local_leader_only, local_leader_only,
) )
from ..engines import Engine, Engines, TrainFeeder, default_feeder from ..engines import Engine, Engines, TrainFeeder, default_feeder, load_engines
from ..models import get_models
from .utils import to_device, do_gc from .utils import to_device, do_gc, truncate_json
from ..utils import wrapper as ml from ..utils import wrapper as ml
from ..data import get_symmap # should decouple from this trainer script
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
_engines: Engines
_command: str _command: str
def get_global_step():
try:
return _engines.global_step
except:
return None
def get_micro_step():
try:
return _engines.micro_step
except:
return None
def get_cmd():
try:
return _command
except:
raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
get_iteration = get_global_step
def load_engines():
models = get_models(cfg.models.get())
engines = dict()
for name in models:
model = models[name]
optimizer = None
lr_scheduler = None
if cfg.hyperparameters.optimizer.lower() == "adamw":
optimizer = ml.AdamW(
model.parameters(),
lr=cfg.hyperparameters.learning_rate,
betas=(0.9, 0.96),
eps=1e-07,
weight_decay=0.01,
)
if cfg.trainer.load_state_dict:
load_path = cfg.ckpt_dir / name / "fp32.pth"
model.load_state_dict(torch.load(load_path))
engines[name] = Engine(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
engines = Engines(engines)
engines.setup()
if not cfg.trainer.load_state_dict:
engines.load_checkpoint()
return engines
class EvalFn(Protocol): class EvalFn(Protocol):
def __call__(self, *, engines: Engines): def __call__(self, *, engines: Engines):
@ -151,17 +95,16 @@ def _non_blocking_input():
l[0] = s l[0] = s
if distributed_initialized(): if world_size() > 1:
broadcast_object_list(l, src=0) broadcast_object_list(l, src=0)
_command = l[0] _command = l[0]
return _command return _command
def _make_infinite_epochs(dl): def _make_infinite_epochs(dl):
while True: while True:
_logger.info("New epoch starts.") #_logger.info("New epoch starts.")
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True) yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
@local_leader_only(default=None) @local_leader_only(default=None)
@ -172,30 +115,32 @@ def logger(data):
def seed(seed): def seed(seed):
# Set up random seeds, after fork() # Set up random seeds, after fork()
random.seed(seed + global_rank()) random.seed(seed + global_rank())
#np.random.seed(seed + global_rank()) np.random.seed(seed + global_rank())
torch.manual_seed(seed + global_rank()) torch.manual_seed(seed + global_rank())
def train( def train(
train_dl: DataLoader, train_dl: DataLoader,
train_feeder: TrainFeeder = default_feeder, train_feeder: TrainFeeder = default_feeder,
eval_fn: EvalFn = lambda x: ..., eval_fn: EvalFn = lambda x: ...,
logger: Logger = logger, logger: Logger = logger,
): ):
fix_unset_envs()
engines = load_engines() engines = load_engines()
# validate if there's at least one model to train
found = False
for name, engine in engines.items():
if engine.training:
found = True
break
if not found:
raise Exception('Training, but no model loaded set to train...')
""" """
if is_local_leader(): if is_local_leader():
cfg.dump() cfg.dump()
_logger.info(cfg) _logger.info(cfg)
""" """
# Setup global engines
global _engines
_engines = engines
events = [] events = []
eval_fn = global_leader_only(eval_fn) eval_fn = global_leader_only(eval_fn)
@ -203,15 +148,20 @@ def train(
# Pre-loop command # Pre-loop command
command = _non_blocking_input() command = _non_blocking_input()
if command in ["eval", "eval_quit"]: if command in ["eval", "eval_quit"]:
engines.eval()
eval_fn(engines=engines) eval_fn(engines=engines)
engines.train()
if command in ["quit", "eval_quit"]: if command in ["quit", "eval_quit"]:
engines.quit()
return return
last_save_step = engines.global_step last_save_step = engines.global_step
last_eval_step = 0 last_eval_step = 0
"""
if cfg.distributed:
train_dl.sampler.set_epoch(int(engines.global_samples / len(train_dl.dataset.paths)))
"""
# Training loop # Training loop
for batch in _make_infinite_epochs(train_dl): for batch in _make_infinite_epochs(train_dl):
if engines.global_step >= cfg.trainer.iterations: if engines.global_step >= cfg.trainer.iterations:
@ -219,17 +169,15 @@ def train(
#batch = to_device(batch, torch.cuda.current_device()) #batch = to_device(batch, torch.cuda.current_device())
stats = engines.step(batch=batch, feeder=train_feeder) stats = engines.step(batch=batch, feeder=train_feeder)
stats['epoch'] = engines.global_samples / (len(train_dl.dataset.paths) * world_size())
iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
stats['it'] = iteration
stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
del stats['batch_size']
del stats['wall_time']
del stats['global_step']
elapsed_time = stats.get("elapsed_time", 0) elapsed_time = stats.get("elapsed_time", 0)
_logger.info(f"Training Metrics: {json.dumps(stats)}.") try:
metrics = json.dumps(stats)
except Exception as e:
metrics = str(stats)
_logger.info(f"Training Metrics: {truncate_json(metrics)}.")
command = _non_blocking_input() command = _non_blocking_input()
@ -267,29 +215,48 @@ def train(
if "lr" in command: if "lr" in command:
rate = float(command.split(" ")[-1]) rate = float(command.split(" ")[-1])
engines.set_lr(rate) try:
print("Updating LR to:", rate) engines.set_lr(rate)
_logger.info(f"Updating LR to: {rate}")
except Exception as e:
_logger.warning(f"Failed to set LR rate to: {rate}, {str(e)}")
if "export" in command:
train_dl.dataset.save_state_dict()
engines.save_checkpoint()
last_save_step = engines.global_step
if is_global_leader():
engines.export(userdata={"symmap": get_symmap()})
save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
saving_commands = ["save"] saving_commands = ["save"]
export_commands = ["export"]
if cfg.trainer.save_on_quit: if cfg.trainer.save_on_quit:
saving_commands.append("quit") saving_commands.append("quit")
if cfg.trainer.export_on_quit:
export_commands.append("quit")
if cfg.trainer.export_on_save:
export_commands.append("save")
if engines.global_step != last_save_step: if engines.global_step != last_save_step:
if engines.global_step % save_ckpt_every == 0 or command in saving_commands: if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
train_dl.dataset.save_state_dict()
engines.save_checkpoint() engines.save_checkpoint()
last_save_step = engines.global_step last_save_step = engines.global_step
if command in export_commands and is_global_leader():
engines.export(userdata={"symmap": get_symmap()})
if engines.global_step != last_eval_step: if engines.global_step != last_eval_step:
if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]: if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
do_gc()
engines.eval()
eval_fn(engines=engines)
engines.train()
last_eval_step = engines.global_step last_eval_step = engines.global_step
eval_fn(engines=engines)
if command in ["quit"]: if command in ["quit"]:
return engines.quit()
return

View File

@ -7,8 +7,16 @@ from .distributed import global_rank, local_rank, global_leader_only
import gc import gc
import logging import logging
import pandas as pd import pandas as pd
import numpy as np
import re import re
import torch import torch
import random
import time
import psutil
import math
import logging
_logger = logging.getLogger(__name__)
from coloredlogs import ColoredFormatter from coloredlogs import ColoredFormatter
from logging import StreamHandler from logging import StreamHandler
@ -16,9 +24,16 @@ from pathlib import Path
from torch import Tensor, nn from torch import Tensor, nn
from tqdm.auto import tqdm from tqdm.auto import tqdm
from typing import Callable, TypeVar, overload from typing import Callable, TypeVar, overload
from contextlib import contextmanager
T = TypeVar("T") T = TypeVar("T")
def truncate_json( str ):
def fun( match ):
return "{:.4f}".format(float(match.group()))
return re.sub(r"\d+\.\d{8,}", fun, str)
def do_gc(): def do_gc():
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -28,6 +43,14 @@ def flatten_dict(d):
return records[0] if records else {} return records[0] if records else {}
def set_seed(seed=None):
if not seed:
seed = int(time.time())
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def _get_named_modules(module, attrname): def _get_named_modules(module, attrname):
for name, module in module.named_modules(): for name, module in module.named_modules():
if hasattr(module, attrname): if hasattr(module, attrname):
@ -155,5 +178,363 @@ def tree_map(fn: Callable, x):
return x return x
def to_device(x: T, device) -> T: def to_device(x: T | None, *args, **kwargs) -> T:
return tree_map(lambda t: t.to(device), x) if x is None:
return
return tree_map(lambda t: t.to(*args, **kwargs), x)
def coalese( *arg, return_last=True ):
return [ x for x in arg if x is not None ][-1 if return_last else 0]
# checks if a module name is within a given whitelist/blacklist policy dict
def passes_policy( policy, name ):
if policy is None:
return True
if "exclude" in policy:
for term in policy["exclude"]:
if term in name:
return False
if "include" in policy:
for term in policy["include"]:
if term in name:
return True
return False
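For example (exclude terms are checked before include terms, and an absent policy allows everything):

# `passes_policy` is re-exported via the package `__init__` change above
policy = {"include": ["attn", "mlp"], "exclude": ["lm_head"]}
passes_policy(policy, "model.layers.0.attn.q_proj")   # True: matches an include term
passes_policy(policy, "model.lm_head")                # False: matches an exclude term
passes_policy(policy, "model.embed_tokens")           # False: matches neither list
passes_policy(None, "anything")                       # True: no policy means allow all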
# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
@contextmanager
def autocast(input, from_dtype, to_dtype):
if input.dtype == from_dtype:
input = input.to(to_dtype)
yield input
input = input.to(from_dtype)
else:
yield input
@contextmanager
def autocasts(input, from_dtype, to_dtype):
if input.dtype in from_dtype:
from_dtype = input.dtype
input = input.to(to_dtype)
yield input
input = input.to(from_dtype)
else:
yield input
# handles temporarily upcasting 'index tensors' so torch will stop bitching
def autocast_forward( func ):
def wrapper( self, input, *args, **kwargs ):
with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
return func( self, k, *args, **kwargs )
return wrapper
# handles migrating an input tensor to a given device
def auto_align_inputs_forward( module, device=None, name = None ):
func = module.forward
if device is None:
if hasattr( module, 'device' ):
device = module.device
else:
try:
device = next(module.parameters() if [*module.parameters()] else module.buffers()).device
except Exception as e:
return func
def wrapper( *args, **kwargs ):
args = [*args]
# search through args and kwargs for any Tensor arguments
for i, arg in enumerate(args):
if not isinstance( arg, torch.Tensor ):
continue
args[i] = arg.to( device=device )
for k, v in kwargs.items():
if not isinstance( v, torch.Tensor ):
continue
kwargs[k] = v.to( device=device )
# disgusting patch
if "position_embeddings" in kwargs:
kwargs["position_embeddings"] = tuple([ t.to(device=device) for t in kwargs["position_embeddings"] ])
return func( *args, **kwargs )
return wrapper
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
# generalizing this would be super sugoi but there's no catch-all for arguments
def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ):
bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in modules:
name = '.'.join(parent)
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
kwargs = dict(
in_features = m.in_features,
out_features = m.out_features,
bias = m.bias is not None,
) if not bnb else dict(
input_features=m.in_features,
output_features=m.out_features,
bias=m.bias is not None,
)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
_logger.info(f"Replacing {name}.{k} to: {klass}")
return model
def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in modules:
name = '.'.join(parent)
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
kwargs = dict(
num_embeddings=m.num_embeddings,
embedding_dim=m.embedding_dim,
padding_idx=m.padding_idx,
max_norm=m.max_norm,
norm_type=m.norm_type,
scale_grad_by_freq=m.scale_grad_by_freq,
sparse=m.sparse,
)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
_logger.info(f"Replacing {name}.{k} to: {klass}")
return model
# cannot feasibly do default arguments here sad
def replace_attention( model, klass, target, mode="math", verbose=False ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in modules:
name = '.'.join(parent)
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
kwargs = dict(
config = m.config,
layer_idx = m.layer_idx,
mode = mode,
)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
_logger.info(f"Replacing {name}.{k} to: {klass}")
return model
# trim/expand a tensor (for example, in a state dict)
def resize_weight( weight, target, dim=0, random=True ):
# trim
if target < weight.shape[dim]:
return weight[:target]
# expand
if target > weight.shape[dim]:
fn = torch.rand if random else torch.zeros
return torch.stack(
[ x for x in weight ] +
[ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ]
)
return weight
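A quick sketch of what that looks like for a classifier head (assuming `resize_weight` above is importable from this module):

import torch

weight = torch.randn(10, 512)        # e.g. an nn.Linear(512, 10).weight from a state dict
grown = resize_weight(weight, 12)    # rows 10..11 are freshly (randomly) initialized
assert grown.shape == (12, 512)
trimmed = resize_weight(weight, 8)   # keeps only the first 8 rows
assert trimmed.shape == (8, 512)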
def get_devices():
return [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu']
# grabs the memory properties of a given device
def get_device_properties( device ):
if 'cuda' in device:
props = torch.cuda.get_device_properties(device)
free, total = torch.cuda.mem_get_info(device)
else:
props = psutil.virtual_memory()
free, total = props.available, props.total
return {"name": device, "total": total, "free": free, "props": props}
# gets the rough size for a given module's parameters
def get_module_size( module ):
param_size = sum([p.nelement() * p.element_size() for p in module.parameters()])
buffer_size = sum([b.nelement() * b.element_size() for b in module.buffers()])
return param_size + buffer_size
# to-do: rewrite all this shit, I don't know what I was thinking when implementing it this way
# it'd be better to just attach to layers itself rather than every single module
# assigns modules to requested devices for a given policy
def get_model_offload_policy(module, policy=None):
# handle any other weird values this is set to
if not isinstance(policy, dict):
policy = {}
# default to only include the core model, and not the other modules (embeddings) in the splitting policy
if "include" not in policy:
policy["include"] = ["model"]
if "limits" not in policy:
policy["limits"] = []
if "assign" not in policy:
policy["assign"] = []
if "devices" not in policy:
policy["devices"] = get_devices() # + cpu to spill the remainder on CPU if overbudget
# create initial device info
devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ]
modules = [ (name, get_module_size(module)) for name, module in module.named_modules() if not [*module.named_children()] and passes_policy( policy, name ) ]
# filter
modules = [ (name, size) for name, size in modules if name and size ]
total_size = sum([size for name, size in modules])
# set caps if requested in the policy
for i, cap in enumerate(policy["limits"]):
# no limit, skip
if cap <= 0:
continue
# is fractional, scale to total size
if cap < 1:
cap = math.floor(total_size * cap)
# available space is below cap, don't set
if devices[i]["free"] < cap:
continue
# cap to requested size
devices[i]["free"] = cap
# assign if specific parts of the model are requested for assignment
if policy["assign"]:
discarded = []
# yuck, there has to be a better way
for device_index, includes in enumerate( policy["assign"] ):
device = devices[device_index]
buffered_modules = []
buffered_size = device["free"]
# iterate through list of modules to compare against includes
for name, size in modules:
# doesn't pass policy
if not passes_policy( {"include": includes}, name ):
continue
# check if within budget
if buffered_size - size >= 0:
# add to buffer
buffered_modules.append( (name, size) )
buffered_size -= size
# budget exceeded, flush buffer
else:
discarded += buffered_modules
buffered_modules = []
buffered_size = 0
break
if buffered_modules and buffered_size:
device["modules"] += [ name for name, size in buffered_modules ]
device["free"] = buffered_size
modules = discarded
device_index = 0
module_index = 0
# assign modules to each device
while module_index < len(modules) and device_index < len(devices):
device = devices[device_index]
name, size = modules[module_index]
# fits within budget
if device["free"] - size >= 0:
device["modules"].append( name )
device["free"] -= size
module_index += 1
# does not fit in budget, increase device index
else:
device_index += 1
_logger.info(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")
# to-do: check that all modules are exhausted
assert module_index >= len(modules)
# only return devices with modules assigned
return [ device for device in devices if device["modules"] ]
# handles naively splitting a model's layers across multiple devices
# this apparently works for training too? the test trainer seemed fine with it split between GPU and CPU
def offload_model( model, policy=None ):
policy = get_model_offload_policy(model, policy=policy)
# move modules to respective devices
for i, device in enumerate( policy ):
# nothing assigned, skip
if not device["modules"]:
continue
for name in device["modules"]:
module = model.get_submodule(name)
module = module.to( device["name"] )
module.device = device['name']
# wrap modules with forward to ensure all inputs are matched to its device
for name, module in model.named_modules():
if not hasattr( module, 'forward' ):
continue
module.forward = auto_align_inputs_forward(module)
"""
# Validate that the layers are all in the right spot
for name, module in model.named_modules():
if not not [*module.named_children()]:
continue
try:
_logger.info( name, next(module.parameters()).device )
except Exception as e:
_logger.info( name, "?" )
pass
"""
return model
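A hedged sketch of what calling this looks like, assuming at least one CUDA device is visible so the remainder can spill over to the CPU entry that `get_devices()` appends; the default policy only considers leaf modules whose names contain `model`:

import torch
from torch import nn

class Wrapped(nn.Module):
	def __init__(self):
		super().__init__()
		self.model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)])
		self.head = nn.Linear(1024, 10)
	def forward(self, x):
		return self.head(self.model(x))

net = Wrapped()
# cap the first device at half of the (tiny) model's size; the rest spills to the next device
net = offload_model(net, policy={"limits": [0.5]})
# inputs are re-homed per-module by the forward wrapper, so a CPU tensor is fine here
y = net(torch.randn(1, 1024))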

View File

@ -1,20 +1,39 @@
from contextlib import contextmanager from contextlib import contextmanager
import math
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import logging
from ..config import cfg from ..config import cfg
_logger = logging.getLogger(__name__)
Embedding = torch.nn.Embedding Embedding = torch.nn.Embedding
Linear = torch.nn.Linear Linear = torch.nn.Linear
if cfg.bitsandbytes.enabled: Adam = torch.optim.Adam
import bitsandbytes as bnb AdamW = torch.optim.AdamW
SGD = torch.optim.SGD
if cfg.bitsandbytes.linear: Adagrad = torch.optim.Adagrad
Linear = bnb.nn.Linear8bitLt
if cfg.bitsandbytes.embedding: # https://github.com/kyegomez/BitNet
Embedding = bnb.nn.StableEmbedding if cfg.optimizations.bitnet:
from bitnet import BitLinear
if cfg.optimizations.bitsandbytes:
import bitsandbytes as bnb
if cfg.optimizations.linear:
if cfg.optimizations.bitnet:
Linear = BitLinear
else:
Linear = bnb.nn.Linear8bitLt
if cfg.optimizations.embedding:
Embedding = bnb.nn.modules.Embedding
"""
Embedding.forward = lambda self, input: ( self.norm(F.embedding( Embedding.forward = lambda self, input: ( self.norm(F.embedding(
input, input,
self.weight, self.weight,
@ -24,52 +43,101 @@ if cfg.bitsandbytes.enabled:
self.scale_grad_by_freq, self.scale_grad_by_freq,
self.sparse, self.sparse,
)).to(self.weight.dtype) ) )).to(self.weight.dtype) )
Adam = torch.optim.Adam
AdamW = torch.optim.AdamW
if cfg.bitsandbytes.enabled:
import bitsandbytes as bnb
Adam = bnb.optim.Adam
AdamW = bnb.optim.AdamW
# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
@contextmanager
def autocast(input, from_dtype, to_dtype):
if input.dtype == from_dtype:
input = input.to(to_dtype)
yield input
input = input.to(from_dtype)
else:
yield input
@contextmanager
def autocasts(input, from_dtype, to_dtype):
if input.dtype in from_dtype:
from_dtype = input.dtype
input = input.to(to_dtype)
yield input
input = input.to(from_dtype)
else:
yield input
# handles temporarily upcasting 'index tensors' so torch will stop bitching
def autocast_forward( func ):
def wrapper( self, input, *args, **kwargs ):
with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k:
return func( self, k, *args, **kwargs )
""" """
if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8:
return func( self, input.to(torch.int32), *args, **kwargs )
return func( self, input, *args, **kwargs )
"""
return wrapper
Embedding.forward = autocast_forward(Embedding.forward)
if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled: if cfg.optimizations.optimizers:
torch.nn.Linear = Linear Adam = bnb.optim.Adam8bit
torch.nn.Embedding = Embedding AdamW = bnb.optim.AdamW8bit
SGD = bnb.optim.SGD8bit
Adagrad = bnb.optim.Adagrad8bit
torch.optim.Adam = Adam elif cfg.optimizations.dadaptation:
torch.optim.AdamW = AdamW import dadaptation
if cfg.optimizations.optimizers:
Adam = dadaptation.DAdaptAdam
AdamW = dadaptation.DAdaptAdam
SGD = dadaptation.DAdaptSGD
AdaGrad = dadaptation.DAdaptAdaGrad
if cfg.optimizations.fp8:
import transformer_engine.pytorch as te
Linear = te.Linear
@contextmanager
def autocast():
yield te.fp8_autocast(enabled=True)
else:
@contextmanager
def autocast():
yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)
if cfg.optimizations.injects:
if cfg.optimizations.linear:
torch.nn.Linear = Linear
if cfg.optimizations.embedding:
torch.nn.Embedding = Embedding
if cfg.optimizations.optimizers:
torch.optim.Adam = Adam
torch.optim.AdamW = AdamW
torch.optim.SGD = SGD
AVAILABLE_COMPILE_BACKENDS = []
try:
AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
except Exception as e:
pass
if cfg.optimizations.tensorrt:
try:
import torch_tensorrt
AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
except Exception as e:
_logger.warning(f'Error while importing TensorRT: {str(e)}')
pass
def compile_model(model, backend="auto"):
if not backend or backend == "auto":
backend = AVAILABLE_COMPILE_BACKENDS[0]
if backend not in AVAILABLE_COMPILE_BACKENDS:
return torch.compile(model)
return torch.compile(model, backend=backend)
# https://github.com/konstmish/prodigy
try:
from prodigyopt import Prodigy
except Exception as e:
_logger.warning(f'Error while importing Prodigyopt: {str(e)}')
pass
# https://github.com/facebookresearch/schedule_free/
try:
import schedulefree
except Exception as e:
_logger.warning(f'Error while importing Schedule_Free: {str(e)}')
pass
# backwards compat
from .utils import (
autocast_forward,
replace_linear as replace_linear_old,
replace_embedding as replace_embedding_old,
replace_attention,
resize_weight,
offload_model,
)
# wrapped here so we can maintain default args
def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
return replace_linear_old( model, klass, target, verbose )
def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
return replace_embedding_old( model, klass, target, verbose )
Embedding.forward = autocast_forward(Embedding.forward)

220
image_classifier/webui.py Normal file
View File

@ -0,0 +1,220 @@
import os
import re
import argparse
import random
import tempfile
import functools
from datetime import datetime
import gradio as gr
from time import perf_counter
from pathlib import Path
from PIL import Image
from .inference import Classifier, cfg
from .train import train
from .utils import get_devices
classifier = None
layout = {}
layout["inference"] = {}
layout["training"] = {}
layout["settings"] = {}
for k in layout.keys():
layout[k]["inputs"] = { "progress": None }
layout[k]["outputs"] = {}
layout[k]["buttons"] = {}
# there's got to be a better way to go about this
def gradio_wrapper(inputs):
def decorated(fun):
@functools.wraps(fun)
def wrapped_function(*args, **kwargs):
for i, key in enumerate(inputs):
kwargs[key] = args[i]
try:
return fun(**kwargs)
except Exception as e:
raise gr.Error(str(e))
return wrapped_function
return decorated
class timer:
def __init__(self, msg="Elapsed time:"):
self.msg = msg
def __enter__(self):
self.start = perf_counter()
return self
def __exit__(self, type, value, traceback):
msg = f'{self.msg} {(perf_counter() - self.start):.3f}s'
gr.Info(msg)
print(f'[{datetime.now().isoformat()}] {msg}')
# returns a list of models, assuming the models are placed under ./training/ or ./models/
def get_model_paths( paths=[Path("./data/"), Path("./training/"), Path("./models/")] ):
yamls = []
for path in paths:
if not path.exists():
continue
for yaml in path.glob("**/*.yaml"):
if "/logs/" in str(yaml):
continue
yamls.append( yaml )
return yamls
def get_dtypes():
return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
#@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
def load_model( yaml, device, dtype ):
gr.Info(f"Loading: {yaml}")
try:
init_classifier( yaml=Path(yaml), restart=True, device=device, dtype=dtype )
except Exception as e:
raise gr.Error(e)
gr.Info(f"Loaded model")
def init_classifier(yaml=None, restart=False, device="cuda", dtype="auto"):
global classifier
if classifier is not None:
if not restart:
return classifier
del classifier
classifier = None
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--device", type=str, default=device)
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default=dtype)
args, unknown = parser.parse_known_args()
classifier = Classifier( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
return classifier
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
if not cfg.yaml_path:
raise Exception("No YAML loaded.")
parser = argparse.ArgumentParser(allow_abbrev=False)
# I'm very sure I can procedurally generate this list
parser.add_argument("--image", type=str, default=kwargs["image"])
parser.add_argument("--temp", type=float, default=kwargs["temp"])
args, unknown = parser.parse_known_args()
classifier = init_classifier()
args.image = Image.open(args.image).convert('RGB')
gr.Info("Inferencing...")
with timer("Inferenced in") as t:
answer = classifier.inference(
image=args.image,
temperature=args.temp,
)
return answer
# setup args
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
parser.add_argument("--share", action="store_true")
parser.add_argument("--render_markdown", action="store_true", default="CLASSIFIER_YAML" in os.environ)
args, unknown = parser.parse_known_args()
args.listen_host = None
args.listen_port = None
args.listen_path = None
if args.listen:
try:
match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]
args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
args.listen_port = match[1] if match[1] != "" else None
args.listen_path = match[2] if match[2] != "" else "/"
except Exception as e:
pass
if args.listen_port is not None:
args.listen_port = int(args.listen_port)
if args.listen_port == 0:
args.listen_port = None
# setup gradio
ui = gr.Blocks()
with ui:
with gr.Tab("Inference"):
with gr.Row():
with gr.Column(scale=4):
layout["inference"]["inputs"]["image"] = gr.Image(label="Input Image", sources=["upload"], type="filepath")
layout["inference"]["outputs"]["output"] = gr.Textbox(label="Output")
with gr.Column(scale=4):
with gr.Row():
layout["inference"]["inputs"]["temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature", info="Modifies the randomness from the samples. (0 to greedy sample)")
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
layout["inference"]["buttons"]["inference"].click(
fn=do_inference,
inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
)
"""
with gr.Tab("Training"):
with gr.Row():
with gr.Column(scale=1):
layout["training"]["outputs"]["console"] = gr.Textbox(lines=8, label="Console Log")
with gr.Row():
with gr.Column(scale=1):
layout["training"]["buttons"]["train"] = gr.Button(value="Train")
layout["training"]["buttons"]["train"].click(
fn=do_training,
outputs=[ x for x in layout["training"]["outputs"].values() if x is not None],
)
"""
with gr.Tab("Settings"):
with gr.Row():
with gr.Column(scale=7):
with gr.Row():
layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
with gr.Column(scale=1):
layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
layout["settings"]["buttons"]["load"].click(
fn=load_model,
inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
)
if os.path.exists("README.md") and args.render_markdown:
md = open("README.md", "r", encoding="utf-8").read()
# remove HF's metadata
if md.startswith("---\n"):
md = "".join(md.split("---")[2:])
gr.Markdown(md)
def start( lock=True ):
ui.queue(max_size=8)
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)
if __name__ == "__main__":
start()
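For reference, since the file lives at `image_classifier/webui.py`, launching it directly is just `python3 -m image_classifier.webui --yaml='./data/config.yaml' --listen 127.0.0.1:7860`; the `--listen` string is split into host/port/path by the regex above, `--share` opens a Gradio share link, and `--render_markdown` additionally inlines the repository README into the page.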

0
scripts/run.sh Executable file → Normal file
View File

View File

@ -1,5 +1,5 @@
import subprocess import subprocess
import sys
from pathlib import Path from pathlib import Path
from datetime import datetime from datetime import datetime
from setuptools import setup, find_packages from setuptools import setup, find_packages
@ -8,7 +8,6 @@ def shell(*args):
out = subprocess.check_output(args) out = subprocess.check_output(args)
return out.decode("ascii").strip() return out.decode("ascii").strip()
def write_version(version_core, pre_release=True): def write_version(version_core, pre_release=True):
if pre_release: if pre_release:
time = shell("git", "log", "-1", "--format=%cd", "--date=iso") time = shell("git", "log", "-1", "--format=%cd", "--date=iso")
@ -23,8 +22,7 @@ def write_version(version_core, pre_release=True):
return version return version
with open("README.md", "r") as f:
with open("README.md", "r", encoding="utf-8") as f:
long_description = f.read() long_description = f.read()
setup( setup(
@ -37,17 +35,37 @@ setup(
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
packages=find_packages(), packages=find_packages(),
install_requires=[ install_requires=(
# training backends
["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else [])
+ [
# logging niceties
"coloredlogs>=15.0.1", "coloredlogs>=15.0.1",
"humanize>=4.4.0",
"matplotlib>=3.6.0",
"pandas>=1.5.0",
# boiler plate niceties
"diskcache>=5.4.0", "diskcache>=5.4.0",
"einops>=0.6.0", "einops>=0.6.0",
"omegaconf==2.0.6", "tqdm",
"tqdm>=4.64.1",
"humanize>=4.4.0",
"pandas>=1.5.0", # HF bloat
"tokenizers",
"transformers",
"safetensors",
# training bloat
"h5py",
"prodigyopt @ git+https://github.com/konstmish/prodigy",
# practically the reason to use python
"numpy",
"torch>=1.13.0", "torch>=1.13.0",
"torchmetrics", "torchmetrics",
"simple_http_server",
"pillow"
], ],
url="https://git.ecker.tech/mrq/resnet-classifier", url="https://git.ecker.tech/mrq/resnet-classifier",
) )