updated framework to use the saner framework that mrq/vall-e uses these days
parent 5cb28a210e
commit 5610bb3bb3
0  .gitignore (vendored)  Executable file → Normal file
40  README.md  Executable file → Normal file
|
@@ -1,10 +1,10 @@
|
|||
# Tentative Title For A ResNet-Based Image Classifier
|
||||
|
||||
This is a simple ResNet based image classifier for """specific images""", using a similar training framework I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/).
|
||||
This is a simple ResNet-based image classifier for images, using a training framework similar to the one I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/).
|
||||
|
||||
## Premise
|
||||
|
||||
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
|
||||
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I faced.
|
||||
|
||||
This is by no means state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
|
||||
|
||||
|
@@ -16,44 +16,14 @@ This is by no means state of the art, as it just leverages an existing ResNet arc
|
|||
|
||||
3. Install using `pip3 install -e ./image_classifier/`.
|
||||
|
||||
4. Train using `python3 -m image_classifier.train yaml='./data/config.yaml'`.
|
||||
4. Train using `python3 -m image_classifier.train --yaml='./data/config.yaml'`.
|
||||
|
||||
5. Wait.
|
||||
|
||||
## Inferencing
|
||||
|
||||
Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
|
||||
Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" --yaml="./data/config.yaml"`
|
||||
|
||||
### Continuous Usage
|
||||
|
||||
If you're looking to continuously classifier trained images, use `python3 -m image_classifier --listen --port=7860 yaml="./data/config.yaml" --temp=1.0` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
|
||||
|
||||
## Known Issues
|
||||
|
||||
* Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
|
||||
* Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
|
||||
* Web server doesn't emit `content-type: application/json`, nor does it accept JSON `POST`s at the moment.
|
||||
|
||||
## Strawmen
|
||||
|
||||
>\> UGH... Why *another* training framework!!! Just subjugate [DLAS](https://git.ecker.tech/mrq/DL-Art-School) even more!!!
|
||||
|
||||
I want my own code to own. The original VALL-E implementation had a rather nice and clean setup that *mostly* just made sense. DLAS was a nightmare to comb through for the gorillion models it houses.
|
||||
|
||||
>\> OK. But how do I use it for `[thing that isn't the specific usecase only I know/care about]`
|
||||
|
||||
Simply provide your own symmapping under `./image_classifier/data.py`, and be sure to set the delimiter (where exactly is an exercise left to the reader).
|
||||
|
||||
Because this is for a ***very specific*** use-case. I don't really care right now to make this a *little* more generalized, despite most of the bits and bobs for it to generalize being there.
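For what it's worth, a rough sketch of the shape such a symmapping takes (hypothetical class name, characters, and indices; after this commit the default map lives in `NaiveTokenizer.get_vocab()` inside `./image_classifier/config.py`, and `./image_classifier/data.py` pulls it through `cfg.tokenizer`):

```python
# hypothetical drop-in tokenizer; the characters and indices are placeholders
# for whatever symbols your labels actually use. " " (index 0) doubles as the
# unknown symbol and <s>/</s> bracket every sequence, mirroring the
# NaiveTokenizer this commit adds to config.py.
class MyTokenizer:
	def get_vocab(self):
		return {
			" ": 0, "<s>": 1, "</s>": 2,
			"0": 3, "1": 4, "2": 5, "3": 6,
			"A": 7, "B": 8, "C": 9, "D": 10, "E": 11, "F": 12,
		}

	def encode(self, s):
		vocab = self.get_vocab()
		s = ["<s>"] + [c if c in vocab else " " for c in s] + ["</s>"]
		return [vocab[c] for c in s]

# e.g. MyTokenizer().encode("BAD1") -> [1, 8, 7, 10, 4, 2]
```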
|
||||
|
||||
>\> ur `[a slur]` for using a ResNet... why not use `[CRNN / some other meme arch]`??
|
||||
|
||||
I don't care, I'd rather keep the copypasting from other people's code to a minimum. Lazily adapting my phoneme tokenizer from my VALL-E implementation into something practically fixed length by introducing start/stop tokens should be grounds for me to use a CRNN, or anything recurrent at the very least, but again, I don't care, it just works for my use case at the moment.
|
||||
|
||||
>\> UGH!!! What are you talking about """specific images"""???
|
||||
|
||||
[ひみつ](https://files.catbox.moe/csuh49.webm)
|
||||
|
||||
>\> NOOOO!!!! WHY AREN'T YOU USING `[cuck license]`???
|
||||
|
||||
:)
|
||||
If you're looking to continuously classify images, use `python3 -m image_classifier --listen --port=7860 --yaml="./data/config.yaml"` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
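For reference, a minimal client for that endpoint might look like the following sketch (hypothetical, not part of the repo); it base64-encodes an image, URL-encodes it into the `b64` query parameter, and reads the `answer` field the route handler returns:

```python
import base64, json
import urllib.parse, urllib.request

def classify(path, host="http://127.0.0.1:7860"):
	# base64-encode the image bytes for the ?b64= query parameter
	with open(path, "rb") as f:
		b64 = base64.b64encode(f.read()).decode("utf-8")
	url = f"{host}/?b64={urllib.parse.quote(b64)}"
	with urllib.request.urlopen(url) as response:
		# the route handler returns {"answer": <label>}; the server doesn't set
		# content-type: application/json (see Known Issues), so parse the body manually
		return json.loads(response.read().decode("utf-8"))["answer"]

print(classify("./data/path-to-your-image.png"))
```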
|
121  data/config.yaml  Executable file → Normal file
|
@@ -1,85 +1,84 @@
|
|||
dataset:
|
||||
training: [
|
||||
"./data/images/"
|
||||
]
|
||||
|
||||
validation: []
|
||||
|
||||
use_hdf5: False
|
||||
|
||||
workers: 0
|
||||
cache: True
|
||||
weights_format: sft
|
||||
|
||||
models:
|
||||
_models:
|
||||
- name: "classifier"
|
||||
tokens: 0
|
||||
len: 6
|
||||
dim: 512
|
||||
resnet: 34
|
||||
#loras:
|
||||
#- name : "lora"
|
||||
# rank: 128
|
||||
# alpha: 128
|
||||
# training: True
|
||||
# rvq_levels: []
|
||||
|
||||
hyperparameters:
|
||||
batch_size: 256
|
||||
gradient_accumulation_steps: 64
|
||||
gradient_clipping: 100
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_clipping: 1.0
|
||||
warmup_steps: 10
|
||||
|
||||
optimizer: Adamw
|
||||
learning_rate: 1.0e-3
|
||||
optimizer: Prodigy
|
||||
learning_rate: 1.0
|
||||
torch_optimizer: True
|
||||
|
||||
scheduler_type: ""
|
||||
#scheduler_type: OneCycle
|
||||
#scheduler_params:
|
||||
# cycle_first_step_size: 10_000
|
||||
# cycle_first_stair_count: 10_000
|
||||
|
||||
# cycle_second_step_size: 15_000
|
||||
# cycle_second_stair_count: 15_000
|
||||
|
||||
# decay_step_size: 5_000
|
||||
|
||||
# cycle_min_lr: 2.5e-4 # 1.0e-5
|
||||
# cycle_max_lr: 2.5e-4 # 1.0e-4
|
||||
# decay_lr_rate: 0.0
|
||||
|
||||
# cycle_min_mom: 0.90
|
||||
# cycle_max_mom: 0.99
|
||||
# decay_mom_rate: 0.0
|
||||
scheduler: "" # ScheduleFree
|
||||
torch_scheduler: True
|
||||
|
||||
evaluation:
|
||||
batch_size: 32
|
||||
frequency: 250
|
||||
size: 32
|
||||
batch_size: 64
|
||||
frequency: 100
|
||||
size: 64
|
||||
|
||||
steps: 300
|
||||
temperature: 1.0
|
||||
steps: 450
|
||||
temperature: 0.0
|
||||
|
||||
trainer:
|
||||
iterations: 100_000
|
||||
|
||||
save_tag: step
|
||||
save_on_oom: True
|
||||
save_on_quit: True
|
||||
iterations: 1_000_000
|
||||
save_frequency: 100
|
||||
|
||||
aggressive_optimizations: False
|
||||
keep_last_checkpoints: 32
|
||||
|
||||
check_for_oom: False
|
||||
gradient_checkpointing: True
|
||||
|
||||
#load_tag: "9500"
|
||||
#load_state_dict: True
|
||||
#load_states: False
|
||||
#strict_loading: False
|
||||
#restart_step_count: True
|
||||
weight_dtype: bfloat16
|
||||
amp: True
|
||||
|
||||
gc_mode: None # "global_step"
|
||||
|
||||
weight_dtype: float32
|
||||
|
||||
backend: local
|
||||
backend: deepspeed
|
||||
deepspeed:
|
||||
zero_optimization_level: 0
|
||||
use_compression_training: True
|
||||
inferencing: False
|
||||
amp: False
|
||||
|
||||
inference:
|
||||
use_vocos: True
|
||||
backend: local
|
||||
|
||||
bitsandbytes:
|
||||
enabled: false
|
||||
weight_dtype: bfloat16
|
||||
amp: True
|
||||
|
||||
optimizations:
|
||||
injects: False
|
||||
replace: True
|
||||
|
||||
linear: False
|
||||
embedding: False
|
||||
optimizers: True
|
||||
|
||||
bitsandbytes: False
|
||||
dadaptation: False
|
||||
bitnet: False
|
||||
fp8: False
|
||||
|
||||
dataset:
|
||||
use_hdf5: True
|
||||
hdf5_flag: r
|
||||
|
||||
workers: 1
|
||||
cache: True
|
||||
|
||||
training: [
|
||||
"./data/images/"
|
||||
]
|
||||
validation: [
|
||||
"./data/validation/"
|
||||
]
|
|
image_classifier/__main__.py
@@ -12,35 +12,55 @@ def main():
|
|||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
parser.add_argument("--listen", action='store_true')
|
||||
parser.add_argument("--port", type=int, default=9090)
|
||||
|
||||
parser.add_argument("--yaml", type=Path, default=None)
|
||||
parser.add_argument("--ckpt", type=Path, default=None)
|
||||
parser.add_argument("--temp", type=float, default=1.0)
|
||||
parser.add_argument("--device", default="cuda")
|
||||
parser.add_argument("--device", type=str, default=None)
|
||||
parser.add_argument("--amp", action="store_true")
|
||||
parser.add_argument("--dtype", type=str, default=None)
|
||||
|
||||
parser.add_argument("--temp", type=float, default=0.0)
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
classifier = Classifier( config=args.yaml, ckpt=args.ckpt, device=args.device )
|
||||
classifier = Classifier( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
|
||||
if args.listen:
|
||||
@route("/")
|
||||
def inference( b64, temperature=1.0 ):
|
||||
def inference( b64, temperature=args.temp ):
|
||||
image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
|
||||
return { "answer": classifier.inference( image=image, temperature=args.temp ) }
|
||||
return { "answer": classifier.inference( image=image, temperature=temperature ) }
|
||||
server.start(port=args.port)
|
||||
else:
|
||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
parser.add_argument("--path", type=Path)
|
||||
parser.add_argument("--base64", type=str)
|
||||
parser.add_argument("--write", type=Path)
|
||||
parser.add_argument("--temp", type=float, default=1.0)
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
images = []
|
||||
if args.path:
|
||||
if args.path.is_dir():
|
||||
for p in args.path.rglob("./*.jpg"):
|
||||
image = Image.open(p).convert('RGB')
|
||||
images.append(image)
|
||||
for p in args.path.rglob("./*.png"):
|
||||
image = Image.open(p).convert('RGB')
|
||||
images.append(image)
|
||||
else:
|
||||
image = Image.open(args.path).convert('RGB')
|
||||
images.append(image)
|
||||
elif args.base64:
|
||||
image = Image.open(BytesIO(base64.b64decode(args.base64))).convert("RGB")
|
||||
images.append(image)
|
||||
else:
|
||||
raise "Specify a --path or --base64."
|
||||
|
||||
for image in images:
|
||||
answer = classifier.inference( image=image, temperature=args.temp )
|
||||
print("Answer:", answer)
|
||||
if args.write:
|
||||
args.write.mkdir(exist_ok=True)
|
||||
image.save( args.write / f"{answer}.jpg")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
image_classifier/config.py
@@ -6,31 +6,61 @@ import os
|
|||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from functools import cached_property, cache
|
||||
from pathlib import Path
|
||||
from omegaconf import OmegaConf
|
||||
import argparse
|
||||
import yaml
|
||||
import random
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
|
||||
from .utils.distributed import world_size
|
||||
|
||||
|
||||
def set_seed(seed=None):
|
||||
if not seed:
|
||||
seed = time.time()
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
@dataclass()
|
||||
class _Config:
|
||||
cfg_path: str | None = None
|
||||
class BaseConfig:
|
||||
yaml_path: str | None = None # path passed in through --yaml
|
||||
|
||||
@property
|
||||
def relpath(self):
|
||||
def cfg_path(self):
|
||||
return Path(self.yaml_path.parent) if self.yaml_path is not None else None
|
||||
|
||||
@property
|
||||
def rel_path(self):
|
||||
return Path(self.cfg_path)
|
||||
|
||||
@property
|
||||
def cache_dir(self):
|
||||
return self.rel_path / ".cache"
|
||||
|
||||
@property
|
||||
def data_dir(self):
|
||||
return self.rel_path / "data"
|
||||
|
||||
@property
|
||||
def metadata_dir(self):
|
||||
return self.rel_path / "metadata"
|
||||
|
||||
@property
|
||||
def ckpt_dir(self):
|
||||
return self.relpath / "ckpt"
|
||||
return self.rel_path / "ckpt"
|
||||
|
||||
@property
|
||||
def log_dir(self):
|
||||
return self.relpath / "logs" / str(self.start_time)
|
||||
return self.rel_path / "logs" / str(self.start_time)
|
||||
|
||||
@cached_property
|
||||
def start_time(self):
|
||||
|
@@ -64,39 +94,28 @@ class _Config:
|
|||
with open(path, "w") as f:
|
||||
f.write(self.dumps())
|
||||
|
||||
@staticmethod
|
||||
def _is_cfg_argv(s):
|
||||
return "=" in s and "--" not in s
|
||||
|
||||
@classmethod
|
||||
def from_yaml( cls, yaml_path ):
|
||||
return cls.from_cli( [f'yaml="{yaml_path}"'] )
|
||||
state = {}
|
||||
state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
|
||||
state.setdefault("yaml_path", yaml_path)
|
||||
return cls(**state)
|
||||
|
||||
@classmethod
|
||||
def from_cli(cls, args=sys.argv):
|
||||
cli_cfg = OmegaConf.from_cli([s for s in args if cls._is_cfg_argv(s)])
|
||||
# legacy support for yaml=`` format
|
||||
for i, arg in enumerate(args):
|
||||
if arg.startswith("yaml"):
|
||||
args[i] = f'--{arg}'
|
||||
|
||||
# Replace argv to ensure there are no omegaconf options, for compatibility with argparse.
|
||||
sys.argv = [s for s in sys.argv if not cls._is_cfg_argv(s)]
|
||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
|
||||
args, unknown = parser.parse_known_args(args=args)
|
||||
|
||||
if cli_cfg.get("help"):
|
||||
print(f"Configurable hyperparameters with their default values:")
|
||||
print(json.dumps(asdict(cls()), indent=2, default=str))
|
||||
exit()
|
||||
if args.yaml:
|
||||
return cls.from_yaml( args.yaml )
|
||||
|
||||
if "yaml" in cli_cfg:
|
||||
yaml_cfg = OmegaConf.load(cli_cfg.yaml)
|
||||
yaml_path = Path(cli_cfg.yaml).absolute()
|
||||
cfg_path = Path(*yaml_path.relative_to(Path.cwd()).parts[:-1])
|
||||
cfg_path = cfg_path.with_suffix("")
|
||||
cfg_path = f'./{cfg_path}'
|
||||
|
||||
yaml_cfg.setdefault("cfg_path", cfg_path)
|
||||
cli_cfg.pop("yaml")
|
||||
else:
|
||||
yaml_cfg = {}
|
||||
merged = OmegaConf.merge(yaml_cfg, cli_cfg)
|
||||
return cls(**dict(merged))
|
||||
return cls(**{})
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
@@ -106,104 +125,195 @@ class _Config:
|
|||
|
||||
@dataclass()
|
||||
class Dataset:
|
||||
training: list[Path] = field(default_factory=lambda: [])
|
||||
validation: list[Path] = field(default_factory=lambda: [])
|
||||
training: list[Path] = field(default_factory=lambda: []) # paths to load into the training dataset
|
||||
validation: list[Path] = field(default_factory=lambda: []) # paths to load into the validation dataset
|
||||
|
||||
temp: list[Path] = field(default_factory=lambda: [])
|
||||
hdf5_name: str = "data.h5" # file name to load the HDF5 dataset
|
||||
use_hdf5: bool = False # whether to load from an HDF5 dataset
|
||||
hdf5_flag: str = "a" # flag to load the HDF5 file, automatically adjusted anyways
|
||||
|
||||
# de-implemented, since the data isn't large enough to warrant HDF5
|
||||
hdf5_name: str = "data.h5"
|
||||
use_hdf5: bool = False
|
||||
|
||||
workers: int = 8
|
||||
cache: bool = True
|
||||
validate: bool = True # validate each utterance on whether it can be included based on duration range caps
|
||||
workers: int = 8 # number of dataloader workers to spawn
|
||||
cache: bool = True # use diskcache to cache the dataset
|
||||
|
||||
# I really need to clean this up
|
||||
@dataclass()
|
||||
class Model:
|
||||
name: str = ""
|
||||
name: str = "classifier"
|
||||
|
||||
tokens: int = 0 # number of token types
|
||||
len: int = 1 # how long a sequence can be
|
||||
dim: int = 512
|
||||
resnet: int = 18
|
||||
|
||||
width: int = 300
|
||||
height: int = 80
|
||||
|
||||
version: int = 1
|
||||
training: bool = True
|
||||
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
|
||||
|
||||
@property
|
||||
def full_name(self):
|
||||
return self.name
|
||||
|
||||
@dataclass()
|
||||
class Models:
|
||||
_models: list[Model] = field(default_factory=lambda: [
|
||||
Model(name="captcha"),
|
||||
])
|
||||
|
||||
def get(self, name=None):
|
||||
if not name:
|
||||
return [ Model(**model) for model in self._models ]
|
||||
return [ self ] if not name or self.name == name else []
|
||||
|
||||
for model in self._models:
|
||||
if model.name == name:
|
||||
return model
|
||||
def loss_factor(self, k):
|
||||
return self.loss_factors[k] if k in self.loss_factors else 1.0
|
||||
|
||||
raise ValueError
|
||||
@property
|
||||
# required for fp8 as the lengths need to be divisible by 8
|
||||
def input_alignment(self):
|
||||
return 8 if cfg.optimizations.fp8 else 0
|
||||
|
||||
@property
|
||||
def activation_checkpointing(self):
|
||||
return cfg.trainer.activation_checkpointing
|
||||
|
||||
@property
|
||||
def gradient_checkpointing(self):
|
||||
return cfg.trainer.gradient_checkpointing
|
||||
|
||||
@property
|
||||
def lora_policy(self):
|
||||
include = ["model"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
|
||||
exclude = []
|
||||
|
||||
if self.arch_type == "llama":
|
||||
include = ["self_attn", "mlp"] # target only the attention + mlp
|
||||
exclude = ["self_attn.k_proj"] # common literature says to ignore it
|
||||
if self.arch_type == "retnet":
|
||||
include = ["layers."] # target the core layers of the RetNet and ignore the auxiliary stuff
|
||||
exclude = ["retention.k_proj"] # attention-based transformers ignore the K, so might as well ignore it for the retnet
|
||||
|
||||
return dict(include=include, exclude=exclude)
|
||||
|
||||
# should be renamed to Adapters
|
||||
@dataclass()
|
||||
class LoRA:
|
||||
name: str = "lora" # vanity name
|
||||
# to-do: find sane default values
|
||||
rank: int = 128 # rank for the LoRA
|
||||
alpha: int = 128 # alpha for the LoRA
|
||||
training: bool = True #
|
||||
embeddings: bool = False # train the embedding too
|
||||
parametrize: bool = False # whether to use the parameterized pathway for LoRAs or not
|
||||
rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
|
||||
|
||||
@property
|
||||
def full_name(self):
|
||||
name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
|
||||
return "-".join(name)
|
||||
|
||||
# actually not needed anymore
|
||||
def active_level( self, level ):
|
||||
if not self.rvq_levels:
|
||||
return True
|
||||
return level in self.rvq_levels
|
||||
|
||||
@dataclass()
|
||||
class Hyperparameters:
|
||||
batch_size: int = 8
|
||||
gradient_accumulation_steps: int = 32
|
||||
gradient_clipping: int = 100 # to be implemented in the local backend
|
||||
batch_size: int = 8 # number of samples per training batch
|
||||
gradient_accumulation_steps: int = 32 # number of steps to accumulate gradients before updating
|
||||
gradient_clipping: int | float = 10 # largest size a gradient norm can be
|
||||
|
||||
optimizer: str = "Adamw"
|
||||
learning_rate: float = 3.25e-4
|
||||
optimizer: str = "Adamw" # optimizer to use, should be 'Prodigyopt" now
|
||||
optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||
|
||||
scheduler_type: str = "" # to be implemented in the local backend
|
||||
scheduler_params: dict = field(default_factory=lambda: {})
|
||||
learning_rate: float = 3.25e-4 # should be 1.0 for ProdigyOpt
|
||||
warmup_steps: int = 0 # number of steps to warm up the optimizer before performing updates (I think this is just passed to deepspeed)
|
||||
|
||||
scheduler: str = "" # scheduler to use, currently don't ever use one so this doesn't really matter
|
||||
scheduler_type: str = "" # deprecated
|
||||
scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||
|
||||
autotune: bool = False # to do deepspeed's autotuning
|
||||
autotune_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||
|
||||
torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed supplied
|
||||
torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied
|
||||
|
||||
@dataclass()
|
||||
class Evaluation:
|
||||
batch_size: int = 64
|
||||
frequency: int = 250
|
||||
size: int = 64
|
||||
batch_size: int = 64 # number of samples per batch during eval / val
|
||||
frequency: int = 250 # do eval / val every X iterations
|
||||
size: int = 64 # number of samples to generate during eval / val
|
||||
|
||||
steps: int = 500
|
||||
temperature: float = 1.0
|
||||
temperature: float = 1.0 # AR temp for inferencing
|
||||
|
||||
load_disabled_engines: bool = True # see the other load_disabled_engines
|
||||
|
||||
@dataclass()
|
||||
class DeepSpeed:
|
||||
zero_optimization_level: int = 0
|
||||
use_compression_training: bool = False
|
||||
zero_optimization_level: int = 0 # doesn't seem to work
|
||||
use_compression_training: bool = False # cope
|
||||
compression_bits: int = 8 # cope
|
||||
inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
|
||||
|
||||
def get_ds_cfg(self, model):
|
||||
weights = [ name[0] for name in model.named_parameters() ]
|
||||
bits = 8
|
||||
amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)
|
||||
|
||||
scheduler_params = {}
|
||||
for k in cfg.hyperparameters.scheduler_params:
|
||||
scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
|
||||
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||
|
||||
if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
|
||||
@cached_property
|
||||
def ds_cfg(self):
|
||||
optimizer_params = cfg.hyperparameters.optimizer_params
|
||||
|
||||
if 'lr' not in optimizer_params:
|
||||
optimizer_params["lr"] = cfg.hyperparameters.learning_rate,
|
||||
|
||||
scheduler_params = cfg.hyperparameters.scheduler_params
|
||||
if 'warmup_num_steps' not in scheduler_params:
|
||||
scheduler_params['warmup_num_steps'] = cfg.hyperparameters.warmup_steps
|
||||
|
||||
if 'total_num_steps' not in scheduler_params:
|
||||
scheduler_params['total_num_steps'] = cfg.trainer.iterations
|
||||
|
||||
autotune_params = cfg.hyperparameters.autotune_params
|
||||
|
||||
if "enabled" not in autotune_params:
|
||||
autotune_params['enabled'] = True
|
||||
|
||||
if "results_dir" not in autotune_params:
|
||||
autotune_params['results_dir'] = str( cfg.rel_path / "autotune" / "results" )
|
||||
|
||||
if "exps_dir" not in autotune_params:
|
||||
autotune_params['exps_dir'] = str( cfg.rel_path / "autotune" / "exps_" )
|
||||
|
||||
# DeepSpeed fp16 is incompatible with its AMP
|
||||
if cfg.trainer.weight_dtype.lower() == "float16":
|
||||
self.amp = False
|
||||
|
||||
# disable local AMP
|
||||
if self.amp:
|
||||
cfg.trainer.amp = False
|
||||
|
||||
ds_cfg = {
|
||||
"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
|
||||
"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
|
||||
"optimizer": {
|
||||
"type": cfg.hyperparameters.optimizer,
|
||||
"params": {
|
||||
"lr": cfg.hyperparameters.learning_rate,
|
||||
}
|
||||
},
|
||||
"params": optimizer_params,
|
||||
} if not cfg.hyperparameters.torch_optimizer else None,
|
||||
"scheduler": {
|
||||
"type": cfg.hyperparameters.scheduler_type,
|
||||
"type": cfg.hyperparameters.scheduler,
|
||||
"params": scheduler_params,
|
||||
} if cfg.hyperparameters.scheduler_type != "" else None,
|
||||
} if not cfg.hyperparameters.torch_scheduler else None,
|
||||
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"auto_cast": True,
|
||||
} if cfg.trainer.weight_dtype.lower() == "float16" else None,
|
||||
"bf16": {
|
||||
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16"
|
||||
"enabled": cfg.trainer.weight_dtype.lower() == "float16",
|
||||
"auto_cast": True, # ???
|
||||
"loss_scale": 0.0 if cfg.trainer.scale_loss else 1.0,
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
|
||||
},
|
||||
"amp": {
|
||||
"enabled": self.amp,
|
||||
},
|
||||
"autotuning": autotune_params if cfg.hyperparameters.autotune else None,
|
||||
"compression_training": {
|
||||
"weight_quantization": {
|
||||
"shared_parameters":{
|
||||
|
@@ -214,7 +324,7 @@ class DeepSpeed:
|
|||
"quantize_verbose": True,
|
||||
"quantization_type": "symmetric",
|
||||
"rounding": "nearest",
|
||||
"quantize_weight_in_forward": True,
|
||||
"quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
|
||||
"fp16_mixed_quantize":{
|
||||
"enabled": False,
|
||||
"quantize_change_ratio": 1
|
||||
|
@@ -223,30 +333,38 @@ class DeepSpeed:
|
|||
"different_groups": {
|
||||
"wq1": {
|
||||
"params": {
|
||||
"start_bits": bits,
|
||||
"target_bits": bits,
|
||||
"start_bits": self.compression_bits,
|
||||
"target_bits": self.compression_bits,
|
||||
"quantization_period": 0
|
||||
},
|
||||
"modules": weights
|
||||
"modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
|
||||
}
|
||||
}
|
||||
},
|
||||
"activation_quantization": {
|
||||
"shared_parameters":{
|
||||
"enabled": True,
|
||||
"quantizer_kernel": True,
|
||||
"schedule_offset": 0,
|
||||
"quantize_groups": 64,
|
||||
"quantize_verbose": True,
|
||||
"quantization_type": "symmetric",
|
||||
"range_calibration": "dynamic",
|
||||
"schedule_offset": 0
|
||||
"rounding": "nearest",
|
||||
"quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
|
||||
"fp16_mixed_quantize":{
|
||||
"enabled": False,
|
||||
"quantize_change_ratio": 1
|
||||
}
|
||||
},
|
||||
"different_groups": {
|
||||
"aq1": {
|
||||
"params": {
|
||||
"bits": bits
|
||||
"bits": self.compression_bits,
|
||||
},
|
||||
"modules": weights
|
||||
}
|
||||
"modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
|
||||
}
|
||||
}
|
||||
},
|
||||
} if self.use_compression_training else None,
|
||||
"zero_optimization": {
|
||||
"stage": self.zero_optimization_level,
|
||||
|
@@ -264,7 +382,10 @@ class DeepSpeed:
|
|||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": True
|
||||
}
|
||||
},
|
||||
"zero_quantized_weights": self.use_compression_training,
|
||||
"zero_hpz_partition_size": world_size(),
|
||||
"zero_quantized_gradients": self.use_compression_training,
|
||||
} if self.zero_optimization_level > 0 else None,
|
||||
"comms_logger": {
|
||||
"enabled": False
|
||||
|
@@ -275,113 +396,314 @@ class DeepSpeed:
|
|||
for k in null_keys:
|
||||
del ds_cfg[k]
|
||||
|
||||
if os.path.exists("./config/ds_config.json"):
|
||||
ds_cfg.update(json.load(open("./config/ds_config.json", "r", encoding="utf-8")))
|
||||
if os.path.exists("./data/ds_config.json"):
|
||||
ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
|
||||
else:
|
||||
ds_cfg.update(self.config)
|
||||
|
||||
return ds_cfg
|
||||
|
||||
@dataclass()
|
||||
class Trainer:
|
||||
iterations: int = 100_000
|
||||
iterations: int = 1_000_000 # maximum iterations to train
|
||||
|
||||
save_tag: str = "step"
|
||||
load_tag: str | None = None
|
||||
save_tag: str = "step" # name to save checkpoints under, "step" will save as current step count
|
||||
load_tag: str | None = None # tag to load checkpoint from; if None: will check against contents of `./ckpt/{model-name}/latest` for the checkpoint name
|
||||
|
||||
save_on_oom: bool = True
|
||||
save_on_quit: bool = True
|
||||
save_frequency: int = 100
|
||||
save_on_oom: bool = True # save if an OOM error is raised
|
||||
save_on_quit: bool = True # save when quitting training
|
||||
|
||||
load_state_dict: bool = False
|
||||
load_states: bool = True
|
||||
strict_loading: bool = True
|
||||
restart_step_count: bool = False
|
||||
export_on_save: bool = False # export weights to local `fp32.pth` state_dict on saving a checkpoint
|
||||
export_on_quit: bool = False # export weights to local `fp32.pth` state_dict on quitting training
|
||||
|
||||
aggressive_optimizations: bool = False
|
||||
check_for_oom: bool = True
|
||||
save_frequency: int = 100 # frequency to save every X iterations
|
||||
|
||||
gc_mode: str | None = None
|
||||
keep_last_checkpoints: int = 0 # number of checkpoints to keep, prunes oldest ones
|
||||
|
||||
weight_dtype: str = "float16"
|
||||
load_state_dict: bool = False # loads `fp32.pth` state_dict, will automatically be done if a checkpoint is not found but `fp32.pth` exists
|
||||
load_states: bool = True #
|
||||
strict_loading: bool = False # whether to strictly enforce matching keys when loading the state dict
|
||||
load_module_only: bool = False #
|
||||
restart_step_count: bool = False # clears the training stats when loading a checkpoint
|
||||
resize_modules: bool = False # automatically resizes
|
||||
|
||||
backend: str = "deepspeed"
|
||||
activation_checkpointing: bool | None = None # deprecated; technically this should only cover activations and not the entire gradients, but HF only has gradient checkpointing
|
||||
gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
|
||||
|
||||
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
|
||||
aggressive_optimizations: bool = False # deprecated
|
||||
check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
|
||||
gc_mode: str | None = None # deprecated, but marks when to do GC
|
||||
load_disabled_engines: bool = False # deprecated, but signals to load engines not used for training (for example, for evaluation/validation)
|
||||
|
||||
weight_dtype: str = "float16" # dtype to have the model under
|
||||
|
||||
amp: bool = False # automatic mixed precision
|
||||
ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
|
||||
#scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload)
|
||||
|
||||
load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
|
||||
no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when the logger broke because of BitNet
|
||||
|
||||
backend: str = "local" # training backend to use. currently supports "local" | "deepspeed"
|
||||
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings
|
||||
|
||||
@cached_property
|
||||
def dtype(self):
|
||||
if self.weight_dtype == "float16":
|
||||
return torch.float16
|
||||
if cfg.trainer.weight_dtype == "bfloat16":
|
||||
if self.weight_dtype == "bfloat16":
|
||||
return torch.bfloat16
|
||||
if self.weight_dtype == "float8_e5m2":
|
||||
return torch.float8_e5m2
|
||||
if self.weight_dtype == "float8_e4m3fn":
|
||||
return torch.float8_e4m3fn
|
||||
return torch.float32
|
||||
|
||||
@cached_property
|
||||
def scale_loss(self):
|
||||
# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
|
||||
return self.dtype == torch.float16
|
||||
|
||||
@dataclass()
|
||||
class Inference:
|
||||
use_vocos: bool = True # artifact from the VALL-E trainer
|
||||
backend: str = "local" # backend to use when inferencing
|
||||
weight_dtype: str = "float32" # dtype to load the model under
|
||||
amp: bool = False # automatic mixed precision during inferencing
|
||||
|
||||
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
if self.weight_dtype == "float16":
|
||||
return torch.float16
|
||||
if self.weight_dtype == "bfloat16":
|
||||
return torch.bfloat16
|
||||
if self.weight_dtype == "int8":
|
||||
return torch.int8
|
||||
if self.weight_dtype == "float8_e5m2":
|
||||
return torch.float8_e5m2
|
||||
if self.weight_dtype == "float8_e4m3fn":
|
||||
return torch.float8_e4m3fn
|
||||
return torch.float32
|
||||
|
||||
@dataclass()
|
||||
class BitsAndBytes:
|
||||
enabled: bool = False
|
||||
injects: bool = False
|
||||
class Optimizations:
|
||||
injects: bool = False # overwrites default torch classes (not recommended)
|
||||
replace: bool = False # replaces modules in place with the optimized version (recommended)
|
||||
compile: bool | str = False # runs torch.compile on the model
|
||||
|
||||
linear: bool = False
|
||||
embedding: bool = False
|
||||
linear: bool = True # inject/replace linear for BnB
|
||||
embedding: bool = True # inject/replace embedding for BnB
|
||||
optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
|
||||
|
||||
bitsandbytes: bool = False # use bitsandbytes
|
||||
dadaptation: bool = False # use dadaptation optimizer
|
||||
bitnet: bool = False # use bitnet
|
||||
fp8: bool = False # use fp8
|
||||
|
||||
model_offloading: dict | None = None # automatically splits the model over a list of devices
|
||||
# example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} will have the GPU capped to 6GiB, and offload the remaining layers to CPU
|
||||
# example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
|
||||
# | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2
|
||||
|
||||
tensorrt: bool = False
|
||||
|
||||
@dataclass()
|
||||
class Config(_Config):
|
||||
device: str = "cuda"
|
||||
class Config(BaseConfig):
|
||||
device: str = "cuda" # target device
|
||||
mode: str = "training" # "inferencing"
|
||||
experimental: bool = False # Debug flag, unused now
|
||||
|
||||
dataset: Dataset = field(default_factory=lambda: Dataset)
|
||||
models: Models = field(default_factory=lambda: Models)
|
||||
models: dict | list | None = field(default_factory=lambda: [])
|
||||
loras: dict | list | None = field(default_factory=lambda: [])
|
||||
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
|
||||
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
|
||||
trainer: Trainer = field(default_factory=lambda: Trainer)
|
||||
inference: Inference = field(default_factory=lambda: Inference)
|
||||
bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
|
||||
bitsandbytes: dict | list | None = None # deprecated
|
||||
optimizations: Optimizations = field(default_factory=lambda: Optimizations)
|
||||
|
||||
def get_device(self):
|
||||
return torch.cuda.current_device() if self.device == "cuda" else self.device
|
||||
tokenizer: str | None = None # tokenizer class
|
||||
tokenizer_path: str = "./tokenizer.json" # tokenizer path
|
||||
|
||||
weights_format: str = "pth" # "pth" | "sft"
|
||||
supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
|
||||
|
||||
@property
|
||||
def cache_dir(self):
|
||||
return ".cache" / self.relpath
|
||||
def model(self):
|
||||
for i, model in enumerate(self.models):
|
||||
if model.training:
|
||||
return model
|
||||
|
||||
return self.models[0] if len(self.models) > 0 else None
|
||||
|
||||
# should be renamed to adapters
|
||||
@property
|
||||
def lora(self):
|
||||
for i, lora in enumerate(self.loras):
|
||||
if lora.training:
|
||||
return lora
|
||||
|
||||
return self.loras[0] if len(self.loras) > 0 else None
|
||||
|
||||
@property
|
||||
def distributed(self):
|
||||
return world_size() > 1
|
||||
|
||||
@cached_property
|
||||
def diskcache(self):
|
||||
if self.dataset.cache:
|
||||
if self.yaml_path is not None and self.dataset.cache:
|
||||
return diskcache.Cache(self.cache_dir).memoize
|
||||
return lambda: lambda x: x
|
||||
|
||||
# I don't remember why this is needed
|
||||
def load_yaml( self, config_path ):
|
||||
tmp = Config.from_yaml( config_path )
|
||||
self.__dict__.update(tmp.__dict__)
|
||||
|
||||
def load_hdf5( self, write=False ):
|
||||
if hasattr(self, 'hdf5'):
|
||||
self.hdf5.close()
|
||||
|
||||
if self.distributed:
|
||||
self.dataset.hdf5_flag = "r"
|
||||
try:
|
||||
self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
|
||||
except Exception as e:
|
||||
_logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}")
|
||||
self.dataset.use_hdf5 = False
|
||||
|
||||
# to-do: prune unused keys
|
||||
def format( self, training=True ):
|
||||
if isinstance(self.dataset, type):
|
||||
self.dataset = dict()
|
||||
|
||||
if isinstance(self.models, type):
|
||||
self.models = dict()
|
||||
|
||||
if isinstance(self.loras, type):
|
||||
self.loras = dict()
|
||||
|
||||
if isinstance(self.hyperparameters, type):
|
||||
self.hyperparameters = dict()
|
||||
|
||||
if isinstance(self.evaluation, type):
|
||||
self.evaluation = dict()
|
||||
|
||||
if isinstance(self.trainer, type):
|
||||
self.trainer = dict()
|
||||
|
||||
if isinstance(self.inference, type):
|
||||
self.inference = dict()
|
||||
|
||||
if isinstance(self.optimizations, type):
|
||||
self.optimizations = dict()
|
||||
|
||||
self.dataset = Dataset(**self.dataset)
|
||||
self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
|
||||
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
|
||||
|
||||
self.models = [ Model(**model) for model in self.models ]
|
||||
self.loras = [ LoRA(**lora) for lora in self.loras ]
|
||||
|
||||
if not self.models:
|
||||
self.models = [ Model() ]
|
||||
|
||||
self.hyperparameters = Hyperparameters(**self.hyperparameters)
|
||||
|
||||
self.evaluation = Evaluation(**self.evaluation)
|
||||
|
||||
self.trainer = Trainer(**self.trainer)
|
||||
|
||||
if not isinstance(self.trainer.deepspeed, type):
|
||||
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
|
||||
|
||||
self.inference = Inference(**self.inference)
|
||||
|
||||
if self.bitsandbytes is not None:
|
||||
self.optimizations = Optimizations(**self.bitsandbytes)
|
||||
else:
|
||||
self.optimizations = Optimizations(**self.optimizations)
|
||||
|
||||
if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
|
||||
self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
|
||||
self.hyperparameters.scheduler_type = ""
|
||||
|
||||
# do not combine the two
|
||||
if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation:
|
||||
self.hyperparameters.scheduler = ""
|
||||
|
||||
if self.hyperparameters.scheduler == "":
|
||||
self.hyperparameters.torch_scheduler = True
|
||||
|
||||
if self.trainer.backend == "local" and self.distributed:
|
||||
self.trainer.ddp = True
|
||||
|
||||
if self.trainer.activation_checkpointing is not None:
|
||||
self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing
|
||||
|
||||
if not training:
|
||||
self.dataset.use_hdf5 = False
|
||||
|
||||
# load our HDF5 file if requested here
|
||||
if self.dataset.use_hdf5:
|
||||
self.load_hdf5()
|
||||
|
||||
# load tokenizer
|
||||
if cfg.tokenizer == "naive":
|
||||
cfg.tokenizer = NaiveTokenizer()
|
||||
else:
|
||||
try:
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
tokenizer_path = cfg.rel_path / cfg.tokenizer_path if cfg.yaml_path is not None else None
|
||||
if tokenizer_path and not tokenizer_path.exists():
|
||||
tokenizer_path = Path("./data/") / cfg.tokenizer_path
|
||||
|
||||
if tokenizer_path and tokenizer_path.exists():
|
||||
cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
|
||||
else:
|
||||
cfg.tokenizer = NaiveTokenizer()
|
||||
except Exception as e:
|
||||
cfg.tokenizer = NaiveTokenizer()
|
||||
_logger.warning(f"Error while parsing tokenizer: {str(e)}")
|
||||
pass
|
||||
|
||||
|
||||
# Preserves the old behavior
|
||||
class NaiveTokenizer:
|
||||
def get_vocab( self ):
|
||||
"""
|
||||
if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
|
||||
return json.loads( cfg.hdf5['symmap'].asstr()[()] )
|
||||
"""
|
||||
return { " ": 0, "<s>": 1, "</s>": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 }
|
||||
|
||||
@cached_property
|
||||
def _bos_token( self ):
|
||||
return self.get_vocab()["<s>"]
|
||||
|
||||
@cached_property
|
||||
def _eos_token( self ):
|
||||
return self.get_vocab()["</s>"]
|
||||
|
||||
def encode( self, s ):
|
||||
symmap = self.get_vocab()
|
||||
s = s.replace("O", "0")
|
||||
s = [f"<s>"] + [ p if p in symmap else " " for p in s ] + [f"</s>"]
|
||||
return [*map(symmap.get, s)]
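# worked example (illustrative, not in the source): with the vocab above,
# encode("W8XK") becomes ["<s>", "W", "8", "X", "K", "</s>"] -> [1, 20, 6, 21, 12, 2];
# characters outside the vocab fall back to " " (token 0), and "O" is folded into "0" first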
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
cfg = Config.from_cli()
|
||||
|
||||
# OmegaConf doesn't actually coerce the dicts into the @dataclass decorated classes, for some god forsaken reason, so we coerce them ourselves
|
||||
cfg.dataset = Dataset(**cfg.dataset)
|
||||
cfg.models = Models(**cfg.models)
|
||||
cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
|
||||
cfg.evaluation = Evaluation(**cfg.evaluation)
|
||||
cfg.trainer = Trainer(**cfg.trainer)
|
||||
cfg.inference = Inference(**cfg.inference)
|
||||
cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
|
||||
|
||||
cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
|
||||
|
||||
# cached_property stopped working...
|
||||
if cfg.dataset.use_hdf5:
|
||||
# some safety for remapping deprecated formats and re-coercing uninitialized properties into actual types
|
||||
try:
|
||||
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
|
||||
cfg.format()
|
||||
except Exception as e:
|
||||
print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
|
||||
cfg.dataset.use_hdf5 = False
|
||||
|
||||
if not cfg.dataset.use_hdf5:
|
||||
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
|
||||
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
|
||||
_logger.error(f"Error while parsing config YAML: {str(e)}")
|
||||
raise e # throw an error because I'm tired of silent errors messing things up for me
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(cfg)
|
|
image_classifier/data.py
@@ -1,16 +1,19 @@
|
|||
# todo: clean this mess up
|
||||
|
||||
import copy
|
||||
# import h5py
|
||||
import h5py
|
||||
import json
|
||||
import logging
|
||||
#import numpy as np
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import math
|
||||
import itertools
|
||||
|
||||
from .config import cfg
|
||||
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
|
||||
from .utils.distributed import global_rank, local_rank, world_size
|
||||
from .utils.io import torch_save, torch_load
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import cache, cached_property
|
||||
|
@@ -20,23 +23,57 @@ from typing import Any
|
|||
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader, Dataset as _Dataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from PIL import Image
|
||||
|
||||
# torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
@cache
|
||||
# to-do: clean up this symmap mess
|
||||
def get_symmap():
|
||||
return { " ": 0, "<s>": 1, "</s>": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 }
|
||||
return cfg.tokenizer.get_vocab()
|
||||
|
||||
@cache
|
||||
def _get_symbols( content ):
|
||||
content = content.replace("O", "0")
|
||||
return [f"<s>"] + [ p for p in content ] + [f"</s>"]
|
||||
def tokenize( s ):
|
||||
if isinstance( s, list ):
|
||||
s = "".join( s )
|
||||
return cfg.tokenizer.encode( s )
|
||||
|
||||
"""
|
||||
def _replace_file_extension(path, suffix):
|
||||
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
||||
|
||||
def _get_hdf5_path(path):
|
||||
# to-do: better validation
|
||||
return str(path)
|
||||
|
||||
def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
||||
data_dir = str(data_dir)
|
||||
|
||||
key = f"/{type}/{_get_hdf5_path(data_dir)}"
|
||||
|
||||
return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items()] if key in cfg.hdf5 else []
|
||||
|
||||
def _get_paths_of_extensions( path, validate=False ):
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
|
||||
return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
|
||||
|
||||
def _interleaved_reorder(l, fn):
|
||||
groups = defaultdict(list)
|
||||
for e in l:
|
||||
groups[fn(e)].append(e)
|
||||
groups = {k: groups[k] for k in sorted(groups)}
|
||||
for interleaved in zip_longest(*groups.values()):
|
||||
for value in interleaved:
|
||||
if value is not None:
|
||||
yield value
|
||||
"""
|
||||
|
||||
class Dataset(_Dataset):
|
||||
def __init__(
|
||||
|
@@ -44,43 +81,90 @@ class Dataset(_Dataset):
|
|||
paths,
|
||||
width=300,
|
||||
height=80,
|
||||
stacks=0,
|
||||
|
||||
symmap=get_symmap(),
|
||||
training=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._head = None
|
||||
|
||||
self.paths = paths
|
||||
self.sampler = None
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.stacks = stacks
|
||||
|
||||
self.paths = paths
|
||||
self.image_dtype = cfg.trainer.dtype
|
||||
self.symmap = symmap
|
||||
|
||||
self.training = training
|
||||
self.dataset_type = "training" if self.training else "validation"
|
||||
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
#transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
|
||||
transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
@cached_property
|
||||
def symbols(self):
|
||||
return sorted(set().union(*[_get_symbols(path.stem) for path in self.paths]))
|
||||
# to-do: do not do validation if there's nothing in the validation
|
||||
# this just makes it be happy
|
||||
if len(self.dataset) == 0:
|
||||
self.dataset = cfg.dataset.training
|
||||
|
||||
# split dataset accordingly per GPU
|
||||
if cfg.distributed and self.training:
|
||||
self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]
|
||||
|
||||
if len(self.paths) == 0:
|
||||
raise ValueError(f"No valid path is found for {self.dataset_type}")
|
||||
|
||||
@cached_property
|
||||
def sampler_state_dict_path(self):
|
||||
return cfg.rel_path / f"sampler.rank{global_rank()}.pt"
|
||||
|
||||
def save_state_dict(self, path = None):
|
||||
"""
|
||||
if path is None:
|
||||
path = self.sampler_state_dict_path
|
||||
if self.sampler is not None:
|
||||
state_dict = self.sampler.get_state()
|
||||
elif self.samplers is not None:
|
||||
state_dict = {
|
||||
"samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
|
||||
}
|
||||
torch_save(state_dict, path)
|
||||
"""
|
||||
return
|
||||
|
||||
def load_state_dict(self, path = None):
|
||||
"""
|
||||
if path is None:
|
||||
path = self.sampler_state_dict_path
|
||||
|
||||
if not path.exists():
|
||||
return
|
||||
|
||||
state_dict = torch_load(path)
|
||||
if self.sampler is not None:
|
||||
state_dict = self.sampler.set_state(state_dict)
|
||||
else:
|
||||
for name, sampler in state_dict["samplers"].items():
|
||||
if name not in self.samplers:
|
||||
continue
|
||||
self.samplers[name].set_state( sampler )
|
||||
"""
|
||||
return
|
||||
|
||||
def __getitem__(self, index):
|
||||
path = self.paths[index]
|
||||
tokens = tokenize( path.stem.upper() )
|
||||
text = torch.tensor( tokens ).to(dtype=torch.uint8)
|
||||
|
||||
# stupid try/except from when the original VALL-E training framework was able to insert foreign symbols into the symmap, but that functionality isn't really necessary here
|
||||
try:
|
||||
text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
|
||||
except Exception as e:
|
||||
print("Invalid symbol:", _get_symbols(path.stem), [*map(self.symmap.get, _get_symbols(path.stem))], path.stem)
|
||||
raise e
|
||||
image = Image.open(path).convert('RGB')
|
||||
width, height = image.size
|
||||
|
||||
image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype) # resnet has to be RGB
|
||||
image = self.transform(image).to(dtype=self.image_dtype) # resnet has to be RGB
|
||||
|
||||
return dict(
|
||||
index=index,
|
||||
|
@@ -98,11 +182,6 @@ class Dataset(_Dataset):
|
|||
def __len__(self):
|
||||
return min(len(self.paths), self._head or len(self.paths))
|
||||
|
||||
def pin_memory(self):
|
||||
self.text = self.text.pin_memory()
|
||||
self.image = self.image.pin_memory()
|
||||
return self
|
||||
|
||||
|
||||
def collate_fn(samples: list[dict]):
|
||||
batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
|
||||
|
@@ -111,21 +190,28 @@ def collate_fn(samples: list[dict]):
|
|||
|
||||
def _seed_worker(worker_id):
|
||||
worker_seed = torch.initial_seed() % 2**32
|
||||
#np.random.seed(worker_seed)
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
|
||||
def _create_dataloader(dataset, training):
|
||||
kwargs = dict(
|
||||
shuffle=True,
|
||||
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
|
||||
drop_last=training,
|
||||
sampler=dataset.sampler if training else None,
|
||||
) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
|
||||
batch_sampler=dataset.sampler,
|
||||
)
|
||||
|
||||
return DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
|
||||
shuffle=True, # training
|
||||
drop_last=training,
|
||||
num_workers=cfg.dataset.workers,
|
||||
collate_fn=collate_fn,
|
||||
persistent_workers=cfg.dataset.workers > 0,
|
||||
pin_memory=False, # True,
|
||||
persistent_workers=cfg.dataset.workers > 1,
|
||||
pin_memory=False,
|
||||
worker_init_fn=_seed_worker,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _load_train_val_paths( val_ratio=0.1 ):
|
||||
|
@@ -133,8 +219,8 @@ def _load_train_val_paths( val_ratio=0.1 ):
|
|||
train_paths = []
|
||||
val_paths = []
|
||||
|
||||
print(cfg.dataset.training)
|
||||
for data_dir in cfg.dataset.training:
|
||||
paths.extend(data_dir.rglob("*.jpg"))
|
||||
paths.extend(data_dir.rglob("*.png"))
|
||||
|
||||
if len(paths) > 0:
|
||||
|
@@ -146,12 +232,13 @@ def _load_train_val_paths( val_ratio=0.1 ):
|
|||
val_len = math.floor(len(train_paths) * val_ratio)
|
||||
train_len = math.floor(len(train_paths) * (1 - val_ratio))
|
||||
|
||||
print(val_len, train_len)
|
||||
|
||||
val_paths = train_paths[train_len:]
|
||||
train_paths = train_paths[:train_len]
|
||||
else:
|
||||
paths = []
|
||||
|
||||
for data_dir in cfg.dataset.validation:
|
||||
paths.extend(data_dir.rglob("*.jpg"))
|
||||
paths.extend(data_dir.rglob("*.png"))
|
||||
|
||||
if len(paths) > 0:
|
||||
|
@@ -169,7 +256,6 @@ def _load_train_val_paths( val_ratio=0.1 ):
|
|||
|
||||
return train_paths, val_paths
|
||||
|
||||
@cfg.diskcache()
|
||||
def create_datasets():
|
||||
train_paths, val_paths = _load_train_val_paths()
|
||||
|
||||
|
@@ -187,10 +273,10 @@ def create_datasets():
|
|||
|
||||
return train_dataset, val_dataset
|
||||
|
||||
|
||||
def create_train_val_dataloader():
|
||||
train_dataset, val_dataset = create_datasets()
|
||||
|
||||
# deepcopy is slow
|
||||
subtrain_dataset = copy.deepcopy(train_dataset)
|
||||
subtrain_dataset.head_(cfg.evaluation.size)
|
||||
subtrain_dataset.training_(False)
|
||||
|
@@ -200,8 +286,6 @@ def create_train_val_dataloader():
|
|||
subtrain_dl = _create_dataloader(subtrain_dataset, training=False)
|
||||
|
||||
_logger.info(str(train_dataset.symmap))
|
||||
|
||||
|
||||
_logger.info(f"#samples (train): {len(train_dataset)}.")
|
||||
_logger.info(f"#samples (val): {len(val_dataset)}.")
|
||||
_logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")
|
||||
|
@@ -210,11 +294,305 @@ def create_train_val_dataloader():
|
|||
|
||||
return train_dl, subtrain_dl, val_dl
|
||||
|
||||
# parse dataset into better to sample metadata
|
||||
"""
|
||||
def create_dataset_metadata( skip_existing=True ):
|
||||
symmap = get_symmap()
|
||||
|
||||
root = str(cfg.data_dir)
|
||||
metadata_root = str(cfg.metadata_dir)
|
||||
|
||||
cfg.metadata_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def add( dir, type="training", audios=True, texts=True ):
|
||||
name = str(dir)
|
||||
name = name.replace(root, "")
|
||||
|
||||
speaker_name = name
|
||||
|
||||
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
|
||||
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
|
||||
except Exception as e:
|
||||
metadata = {}
|
||||
|
||||
if not os.path.isdir(f'{root}/{name}/'):
|
||||
return
|
||||
|
||||
# tqdm.write(f'{root}/{name}')
|
||||
files = os.listdir(f'{root}/{name}/')
|
||||
|
||||
# grab IDs for every file
|
||||
ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
|
||||
|
||||
wrote = False
|
||||
|
||||
for id in tqdm(ids, desc=f"Processing {name}"):
|
||||
try:
|
||||
quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}')
|
||||
|
||||
if audios and not quant_path.exists():
|
||||
continue
|
||||
|
||||
key = f'{type}/{speaker_name}/{id}'
|
||||
|
||||
if skip_existing and id in metadata:
|
||||
continue
|
||||
|
||||
wrote = True
|
||||
|
||||
if id not in metadata:
|
||||
metadata[id] = {}
|
||||
|
||||
utterance_metadata = {}
|
||||
if audios:
|
||||
# ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
|
||||
dac = np.load(quant_path, allow_pickle=True)[()]
|
||||
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
|
||||
|
||||
if "text" in dac["metadata"]:
|
||||
utterance_metadata["text"] = dac["metadata"]["text"]
|
||||
if "phonemes" in dac["metadata"]:
|
||||
utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
|
||||
if "language" in dac["metadata"]:
|
||||
utterance_metadata["language"] = dac["metadata"]["language"]
|
||||
if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
|
||||
utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
|
||||
|
||||
for k, v in utterance_metadata.items():
|
||||
metadata[id][k] = v
|
||||
|
||||
except Exception as e:
|
||||
tqdm.write(f'Error while processing {id}: {e}')
|
||||
|
||||
if wrote:
|
||||
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
||||
f.write( json.dumps( metadata ) )
|
||||
|
||||
# training
|
||||
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
|
||||
add( data_dir, type="training" )
|
||||
|
||||
# validation
|
||||
for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
|
||||
add( data_dir, type="validation" )
|
||||
|
||||
# noise
|
||||
for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
|
||||
add( data_dir, type="noise", texts=False )
|
||||
|
||||
# parse yaml to create an hdf5 file
|
||||
def create_dataset_hdf5( skip_existing=True ):
|
||||
cfg.dataset.use_hdf5 = True
|
||||
cfg.load_hdf5(write=True)
|
||||
hf = cfg.hdf5
|
||||
|
||||
symmap = get_symmap()
|
||||
|
||||
root = str(cfg.data_dir)
|
||||
metadata_root = str(cfg.metadata_dir)
|
||||
|
||||
|
||||
def add( dir, type="training", audios=True, texts=True ):
|
||||
name = str(dir)
|
||||
name = name.replace(root, "")
|
||||
|
||||
# yucky
|
||||
speaker_name = name
|
||||
if "LibriTTS-R" in speaker_name:
|
||||
speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
|
||||
|
||||
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
|
||||
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
|
||||
|
||||
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
|
||||
|
||||
if not os.path.isdir(f'{root}/{name}/'):
|
||||
return
|
||||
|
||||
files = os.listdir(f'{root}/{name}/')
|
||||
|
||||
# grab IDs for every file
|
||||
ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
|
||||
|
||||
for id in tqdm(ids, desc=f"Processing {name}"):
|
||||
try:
|
||||
quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
|
||||
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
|
||||
|
||||
if not quant_exists:
|
||||
continue
|
||||
|
||||
key = f'{type}/{speaker_name}/{id}'
|
||||
|
||||
if skip_existing and key in hf:
|
||||
continue
|
||||
|
||||
group = hf.create_group(key) if key not in hf else hf[key]
|
||||
|
||||
if id not in metadata:
|
||||
metadata[id] = {}
|
||||
|
||||
utterance_metadata = {}
|
||||
|
||||
# audio
|
||||
if audios:
|
||||
dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
|
||||
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
|
||||
|
||||
if "text" in dac["metadata"]:
|
||||
utterance_metadata["text"] = dac["metadata"]["text"]
|
||||
if "phonemes" in dac["metadata"]:
|
||||
utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
|
||||
if "language" in dac["metadata"]:
|
||||
utterance_metadata["language"] = dac["metadata"]["language"]
|
||||
if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
|
||||
utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
|
||||
|
||||
if "audio" not in group:
|
||||
group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')
|
||||
|
||||
# text
|
||||
if texts:
|
||||
if not utterance_metadata and text_exists:
|
||||
utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
|
||||
|
||||
phn = "".join(utterance_metadata["phonemes"])
|
||||
phn = cfg.tokenizer.encode(phn)
|
||||
phn = np.array(phn).astype(np.uint8)
|
||||
|
||||
if "text" not in group:
|
||||
group.create_dataset('text', data=phn, compression='lzf')
|
||||
|
||||
for k, v in utterance_metadata.items():
|
||||
group.attrs[k] = v
|
||||
metadata[id][k] = v
|
||||
|
||||
except Exception as e:
|
||||
tqdm.write(f'Error while processing {id}: {e}')
|
||||
|
||||
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
||||
f.write( json.dumps( metadata ) )
|
||||
|
||||
# training
|
||||
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
|
||||
add( data_dir, type="training" )
|
||||
|
||||
# validation
|
||||
for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
|
||||
add( data_dir, type="validation" )
|
||||
|
||||
# noise
|
||||
for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
|
||||
add( data_dir, type="noise", texts=False )
|
||||
|
||||
# write symmap
|
||||
if "symmap" in hf:
|
||||
del hf['symmap']
|
||||
|
||||
hf.create_dataset('symmap', data=json.dumps(symmap))
|
||||
hf.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser("Save trained model to path.")
|
||||
parser.add_argument("--action", type=str)
|
||||
parser.add_argument("--tasks", type=str)
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
task = args.action
|
||||
|
||||
cfg.dataset.workers = 1
|
||||
|
||||
if args.action == "hdf5":
|
||||
create_dataset_hdf5()
|
||||
elif args.action == "list-dataset":
|
||||
dataset = []
|
||||
for group in os.listdir(cfg.data_dir):
|
||||
for name in os.listdir(cfg.data_dir / group):
|
||||
if len(os.listdir(cfg.data_dir / group / name)) == 0:
|
||||
continue
|
||||
dataset.append(f'{group}/{name}')
|
||||
|
||||
_logger.info(json.dumps(dataset))
|
||||
elif args.action == "metadata":
|
||||
create_dataset_metadata()
|
||||
elif args.action == "sample":
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
|
||||
samples = {
|
||||
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
|
||||
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
|
||||
#"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
|
||||
}
|
||||
|
||||
Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for k, v in samples.items():
|
||||
for i in range(len(v)):
|
||||
for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
|
||||
try:
|
||||
decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
|
||||
except Exception as e:
|
||||
_logger.info(f"Error while decoding prom {k}.{i}.{j}.wav: {str(e)}")
|
||||
try:
|
||||
decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
|
||||
except Exception as e:
|
||||
_logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}")
|
||||
v[i]['proms'][j] = v[i]['proms'][j].shape
|
||||
v[i]['resps'][j] = v[i]['resps'][j].shape
|
||||
|
||||
for k, v in samples.items():
|
||||
for i in range(len(v)):
|
||||
_logger.info(f'{k}[{i}]: {v[i]}')
|
||||
elif args.action == "validate":
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
|
||||
missing = set()
|
||||
|
||||
for i in range(len( train_dl.dataset )):
|
||||
batch = train_dl.dataset[i]
|
||||
|
||||
text = batch['text']
|
||||
phonemes = batch['metadata']['phonemes']
|
||||
|
||||
decoded = [ cfg.tokenizer.decode(token) for token in text[1:-1] ]
|
||||
for i, token in enumerate(decoded):
|
||||
if token != "<unk>":
|
||||
continue
|
||||
|
||||
phone = phonemes[i]
|
||||
|
||||
_logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )
|
||||
|
||||
missing |= set([phone])
|
||||
|
||||
_logger.info( f"Missing tokens: {missing}" )
|
||||
|
||||
|
||||
elif args.action == "tasks":
|
||||
index = 0
|
||||
cfg.dataset.tasks_list = args.tasks.split(",")
|
||||
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
sample = train_dl.dataset[0]
|
||||
print(sample)
|
||||
batch = next(iter(train_dl))
|
||||
|
||||
for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
|
||||
if task not in cfg.dataset.tasks_list:
|
||||
continue
|
||||
|
||||
_logger.info( f'{text} {task} {cfg.model.resp_levels}')
|
||||
_logger.info( f'{proms.shape} {resps.shape}' )
|
||||
|
||||
tokens = 0
|
||||
tokens += sum([ text.shape[0] for text in batch["text"] ])
|
||||
tokens += sum([ resps.shape[0] for resps in batch["resps"] ])
|
||||
_logger.info( f'{tokens}' )
|
||||
|
||||
decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
|
||||
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
|
||||
break
|
||||
"""
|
|
@ -1,6 +1,6 @@
|
|||
from ..config import cfg
|
||||
|
||||
from ..utils.distributed import fix_unset_envs
|
||||
from ..utils.distributed import fix_unset_envs, ddp_model
|
||||
fix_unset_envs()
|
||||
|
||||
if cfg.trainer.backend == "deepspeed":
|
||||
|
@ -8,4 +8,211 @@ if cfg.trainer.backend == "deepspeed":
|
|||
elif cfg.trainer.backend == "local":
|
||||
from .base import Engine
|
||||
|
||||
from .base import Engines, TrainFeeder, default_feeder
|
||||
from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
|
||||
|
||||
from ..models import get_models, get_model
|
||||
from ..utils import wrapper as ml
|
||||
from ..utils.io import torch_save, torch_load, pick_path
|
||||
from ..models.lora import apply_lora, lora_load_state_dict
|
||||
|
||||
import torch
|
||||
import re
|
||||
import logging
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
deepspeed_available = False
|
||||
try:
|
||||
import deepspeed
|
||||
deepspeed_available = True
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
from functools import cache
|
||||
|
||||
@cache
|
||||
def load_engines(training=True, **model_kwargs):
|
||||
models = get_models(cfg.models, training=training, **model_kwargs)
|
||||
engines = dict()
|
||||
|
||||
for name, model in models.items():
|
||||
state = None
|
||||
stats = None
|
||||
lora = None
|
||||
|
||||
inferencing = cfg.mode == "inferencing" or not model.config.training or not training
|
||||
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
|
||||
loads_state_dict = cfg.trainer.load_state_dict # or inferencing
|
||||
|
||||
checkpoint_path = cfg.ckpt_dir / name / "latest"
|
||||
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
|
||||
load_path = pick_path( cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
|
||||
|
||||
# actually use the lora-specific checkpoint if available
|
||||
if cfg.lora is not None:
|
||||
checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest"
|
||||
|
||||
# to handle the issue of training with deepspeed, but inferencing with local
|
||||
if checkpoint_path.exists() and backend == "local":
|
||||
tag = open(checkpoint_path).read()
|
||||
checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
|
||||
|
||||
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
|
||||
_logger.warning(f"Checkpoint missing, but weights found: {load_path}")
|
||||
loads_state_dict = True
|
||||
|
||||
# load state early
|
||||
if loads_state_dict:
|
||||
state = torch_load(load_path, device=cfg.device)
|
||||
|
||||
# check if config is defined in state, and re-initialize the model
|
||||
if "config" in state and False:
|
||||
_logger.warning("Model config definition in weights, re-loading...")
|
||||
config_state = state["config"]
|
||||
model = get_model( config=cfg.model.__class__( *config_state ), training=training )
|
||||
|
||||
hyper_config = model.config
|
||||
|
||||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
|
||||
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
|
||||
ddp = cfg.trainer.ddp
|
||||
|
||||
engine_class = LocalEngine if backend == "local" else Engine
|
||||
|
||||
# apply model replacers
|
||||
if cfg.optimizations.replace and cfg.optimizations.linear:
|
||||
model.model = ml.replace_linear( model.model )
|
||||
|
||||
if cfg.optimizations.replace and cfg.optimizations.embedding:
|
||||
model.model = ml.replace_embedding( model.model )
|
||||
|
||||
for lora in cfg.loras:
|
||||
model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
|
||||
|
||||
if inferencing:
|
||||
model.config.training = False
|
||||
|
||||
if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
|
||||
optimizer_class = None
|
||||
scheduler_class = None
|
||||
|
||||
params = {
|
||||
"lr": cfg.hyperparameters.learning_rate,
|
||||
}
|
||||
if cfg.hyperparameters.optimizer.lower() == "adamw":
|
||||
params["betas"] = (0.9, 0.96)
|
||||
params["eps"] = 1e-07
|
||||
params["weight_decay"] = 0.01
|
||||
|
||||
# for dadaptation since it has Adam only
|
||||
if ml.AdamW == ml.Adam:
|
||||
params["decouple"] = True
|
||||
|
||||
optimizer_class = ml.AdamW
|
||||
elif cfg.hyperparameters.optimizer.lower() == "sgd":
|
||||
optimizer_class = ml.SGD
|
||||
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
|
||||
optimizer_class = ml.Prodigy
|
||||
|
||||
params['d_coef'] = params['lr']
|
||||
params['lr'] = 1.0
|
||||
elif cfg.hyperparameters.optimizer.lower() == "adagrad":
|
||||
optimizer_class = ml.Adagrad
|
||||
else:
|
||||
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
|
||||
|
||||
params.update(cfg.hyperparameters.optimizer_params)
|
||||
|
||||
optimizer = optimizer_class(
|
||||
[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
|
||||
**params,
|
||||
)
|
||||
|
||||
if cfg.hyperparameters.scheduler.lower() == "schedulefree":
|
||||
if cfg.hyperparameters.optimizer.lower() == "adamw":
|
||||
scheduler_class = ml.schedulefree.AdamWScheduleFree
|
||||
elif cfg.hyperparameters.optimizer.lower() == "sgd":
|
||||
scheduler_class = ml.schedulefree.SGDScheduleFree
|
||||
else:
|
||||
raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
|
||||
|
||||
optimizer = scheduler_class(
|
||||
[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
|
||||
lr = params['lr'],
|
||||
warmup_steps = cfg.hyperparameters.warmup_steps
|
||||
)
|
||||
|
||||
"""
|
||||
# set up our LR scheduler here
|
||||
"""
|
||||
|
||||
if inferencing:
|
||||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
# load state dict if requested / required
|
||||
if loads_state_dict:
|
||||
# state dict is not just the module, extract the extra trainer details
|
||||
if "stats" in state:
|
||||
stats = state["stats"]
|
||||
|
||||
# do not load stats if we're training a LoRA
|
||||
if cfg.lora is not None or cfg.trainer.restart_step_count:
|
||||
stats = None
|
||||
|
||||
if "module" in state:
|
||||
state = state["module"]
|
||||
|
||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||
|
||||
# load lora weights if exists
|
||||
if cfg.lora is not None:
|
||||
lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
|
||||
if lora_path.exists():
|
||||
_logger.info( f"Loaded LoRA state dict: {lora_path}" )
|
||||
|
||||
state = torch_load(lora_path, device=cfg.device)
|
||||
state = state['lora' if 'lora' in state else 'module']
|
||||
lora_load_state_dict( model, state )
|
||||
|
||||
# wrap if DDP is requested
|
||||
if ddp:
|
||||
model = ddp_model(model)
|
||||
# wrap optimization class
|
||||
elif cfg.optimizations.compile:
|
||||
model = ml.compile_model(model, backend=cfg.optimizations.compile)
|
||||
# deepspeed inferencing
|
||||
elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
|
||||
engine_class = LocalEngine
|
||||
model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
|
||||
|
||||
# use base engine if requested
|
||||
engines[name] = engine_class(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
|
||||
hyper_config=hyper_config,
|
||||
stats=stats
|
||||
)
|
||||
|
||||
|
||||
engines = Engines(engines)
|
||||
engines.setup()
|
||||
|
||||
# this might bite me in the ass, since technically this doesn't handle the case where one engine loads fine but another doesn't
|
||||
if not cfg.trainer.load_state_dict:
|
||||
engines.load_checkpoint(training=not inferencing)
|
||||
|
||||
# freeze requested params
|
||||
for name, engine in engines.items():
|
||||
engine.freeze(freeze_all=False)
|
||||
|
||||
# split models over requested devices
|
||||
if cfg.optimizations.model_offloading:
|
||||
engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )
|
||||
|
||||
return engines
|
||||
|
|
|
@ -28,7 +28,9 @@ def default_feeder(engine, batch):
|
|||
|
||||
from ..config import cfg
|
||||
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
|
||||
from ..utils.distributed import init_distributed, distributed_initialized
|
||||
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size, cleanup_distributed
|
||||
from ..utils.io import torch_save, torch_load
|
||||
from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
@ -39,40 +41,65 @@ import os
|
|||
from torch import Tensor
|
||||
from torch.distributed import all_reduce
|
||||
from typing import Any, Protocol
|
||||
from functools import cached_property
|
||||
|
||||
from .base import TrainFeeder
|
||||
from ..utils import wrapper as ml
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
if not distributed_initialized() and cfg.trainer.backend == "local":
|
||||
def _nop():
|
||||
...
|
||||
fn = _nop if cfg.device == "cpu" else torch.distributed.init_process_group
|
||||
init_distributed(fn)
|
||||
if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1:
|
||||
init_distributed(torch.distributed.init_process_group)
|
||||
|
||||
# A very naive engine implementation using barebones PyTorch
|
||||
# to-do: implement lr_scheduling
|
||||
class Engine():
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
|
||||
if 'hyper_config' in kwargs:
|
||||
self.hyper_config = kwargs['hyper_config']
|
||||
kwargs.pop("hyper_config")
|
||||
|
||||
self.module = kwargs['model'].to(cfg.device).to(torch.float32 if cfg.trainer.amp else cfg.trainer.dtype)
|
||||
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
|
||||
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
|
||||
|
||||
self.global_steps = 0
|
||||
self.micro_steps = 0
|
||||
self.gradient_accumulation_steps = cfg.hyperparameters.gradient_accumulation_steps
|
||||
self.global_steps = kwargs.pop("global_steps", 0)
|
||||
self.micro_steps = kwargs.pop("micro_steps", 0)
|
||||
self.global_samples = kwargs.pop("global_samples", 0)
|
||||
self.tokens_processed = kwargs.pop("tokens_processed", 0)
|
||||
|
||||
def freeze(self):
|
||||
for p in self.module.parameters():
|
||||
if p.requires_grad:
|
||||
p.requires_grad_(False)
|
||||
self._frozen_params.add(p)
|
||||
self._frozen_params = set()
|
||||
|
||||
self.max_nan_losses = 8
|
||||
self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
|
||||
|
||||
self.current_batch_size = 0
|
||||
self._global_grad_norm = None
|
||||
|
||||
def freeze(self, freeze_all=True):
|
||||
# set to freeze
|
||||
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
|
||||
raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
|
||||
|
||||
# freeze non-LoRA params if requested
|
||||
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
|
||||
return freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
|
||||
|
||||
for name, param in self.module.named_parameters():
|
||||
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
|
||||
param.requires_grad_(False)
|
||||
self._frozen_params.add(param)
|
||||
|
||||
def unfreeze(self):
|
||||
for p in self._frozen_params:
|
||||
p.requires_grad_(True)
|
||||
self._frozen_params.clear()
|
||||
|
||||
@property
|
||||
def _training(self):
|
||||
if not hasattr(self, "hyper_config"):
|
||||
return True
|
||||
return self.hyper_config.training
|
||||
|
||||
@property
|
||||
def global_step(self):
|
||||
return self.global_steps
|
||||
|
@ -81,8 +108,17 @@ class Engine():
|
|||
def micro_step(self):
|
||||
return self.micro_steps
|
||||
|
||||
def train_batch_size(self):
|
||||
return cfg.hyperparameters.batch_size
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
|
||||
|
||||
@property
|
||||
def gradient_accumulation_steps(self):
|
||||
return cfg.hyperparameters.gradient_accumulation_steps
|
||||
|
||||
@property
|
||||
def gradient_clipping(self):
|
||||
return cfg.hyperparameters.gradient_clipping
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
return gather_attribute(self.module, *args, **kwargs)
|
||||
|
@ -91,42 +127,74 @@ class Engine():
|
|||
return dispatch_attribute(self.module, *args, **kwargs)
|
||||
|
||||
def save_checkpoint(self, save_dir, tag ):
|
||||
save_path = save_dir / tag / "state.pth"
|
||||
if is_global_leader():
|
||||
module = self.module.state_dict()
|
||||
|
||||
# if training lora
|
||||
# this is a separate path to override saving the weights
|
||||
lora = None
|
||||
if cfg.lora is not None:
|
||||
lora, module = lora_get_state_dict( module, split = True )
|
||||
save_dir = cfg.ckpt_dir / cfg.lora.full_name
|
||||
|
||||
save_path = save_dir / tag / f"state.{cfg.weights_format}"
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save({
|
||||
"global_step": self.global_step,
|
||||
"micro_step": self.micro_step,
|
||||
"module": self.module.state_dict(),
|
||||
|
||||
torch_save({
|
||||
"module": module,
|
||||
"lora": lora,
|
||||
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
|
||||
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
|
||||
|
||||
"stats": {
|
||||
"global_step": self.global_step,
|
||||
"micro_step": self.micro_step,
|
||||
"global_samples": self.global_samples,
|
||||
"tokens_processed": self.tokens_processed,
|
||||
}
|
||||
}, save_path)
|
||||
|
||||
open(save_dir / "latest", 'w').write( tag )
|
||||
|
||||
def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
|
||||
torch.distributed.barrier()
|
||||
|
||||
def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
|
||||
# override to load the lora instead
|
||||
if cfg.lora is not None:
|
||||
load_dir = cfg.ckpt_dir / cfg.lora.full_name
|
||||
|
||||
if tag is None:
|
||||
tag_path = load_dir / "latest"
|
||||
|
||||
if not tag_path.exists():
|
||||
return
|
||||
|
||||
tag = open(tag_path).read()
|
||||
|
||||
load_path = load_dir / tag / "state.pth"
|
||||
load_path = load_dir / tag / f"state.{cfg.weights_format}"
|
||||
|
||||
if not load_path.exists():
|
||||
return
|
||||
|
||||
state = torch.load(load_path)
|
||||
self.global_steps = state['global_step']
|
||||
self.micro_steps = state['micro_step']
|
||||
self.module.load_state_dict(state['module'])
|
||||
state = torch_load(load_path, device=cfg.device)
|
||||
|
||||
self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
|
||||
self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
|
||||
self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
|
||||
self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
|
||||
self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading)
|
||||
|
||||
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
|
||||
load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
|
||||
|
||||
if load_optimizer_states:
|
||||
self.optimizer.load_state_dict(state['optimizer'])
|
||||
self.optimizer.load_state_dict(state['optimizer']) #, device=cfg.device)
|
||||
|
||||
if load_lr_scheduler_states:
|
||||
self.lr_scheduler.load_state_dict(state['lr_scheduler'])
|
||||
self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, device=cfg.device)
|
||||
|
||||
if 'lora' in state:
|
||||
lora_load_state_dict( self.module, state['lora'] )
|
||||
|
||||
def eval(self):
|
||||
return self.module.eval()
|
||||
|
@ -136,46 +204,80 @@ class Engine():
|
|||
|
||||
def to(self, *args, **kwargs):
|
||||
self.module = self.module.to(*args, **kwargs)
|
||||
return self.module
|
||||
if self.optimizer:
|
||||
self.optimizer = self.optimizer.to(*args, **kwargs)
|
||||
|
||||
return self
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
@cached_property
|
||||
def device(self):
|
||||
return next(self.module.parameters()).device
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.module.forward(*args, **kwargs)
|
||||
|
||||
def backward(self, loss):
|
||||
if self.loss_scaler is not None:
|
||||
return self.loss_scaler.scale(loss / self.gradient_accumulation_steps).backward()
|
||||
return (loss / self.gradient_accumulation_steps).backward()
|
||||
|
||||
|
||||
def step(self):
|
||||
with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
|
||||
self.micro_steps += 1
|
||||
self.global_samples += self.batch_size
|
||||
|
||||
if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
|
||||
torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)
|
||||
|
||||
self.global_steps += 1
|
||||
if self.loss_scaler is not None:
|
||||
self.loss_scaler.step(self.optimizer)
|
||||
self.loss_scaler.update()
|
||||
else:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
self._get_grad_norm()
|
||||
|
||||
def _get_grad_norm(self):
|
||||
t = [ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]
|
||||
self._global_grad_norm = torch.cat(t).norm().item() if len(t) else None
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for param_group in self.optimizer.param_groups:
|
||||
if 'lr' in param_group:
|
||||
if 'd_coeff' in param_group:
|
||||
lrs.append(param_group['d_coeff'])
|
||||
elif 'lr' in param_group:
|
||||
lrs.append(param_group['lr'])
|
||||
return lrs
|
||||
|
||||
def set_lr(self, lr):
|
||||
for param_group in self.optimizer.param_groups:
|
||||
if 'lr' in param_group:
|
||||
if 'd_coeff' in param_group:
|
||||
param_group['d_coeff'] = lr
|
||||
elif 'lr' in param_group:
|
||||
param_group['lr'] = lr
|
||||
|
||||
def get_global_grad_norm(self):
|
||||
return 0.0
|
||||
return self._global_grad_norm
|
||||
|
||||
def traverse(self, *args, **kwargs):
|
||||
with ml.autocast():
|
||||
self.forward(*args, **kwargs)
|
||||
|
||||
losses = self.gather_attribute("loss")
|
||||
loss = torch.stack([*losses.values()]).sum()
|
||||
|
||||
if torch.isnan(loss).any():
|
||||
self.max_nan_losses = self.max_nan_losses - 1
|
||||
if self.max_nan_losses < 0:
|
||||
raise RuntimeError("Too many NaN losses detected.")
|
||||
|
||||
stats = {}
|
||||
stats |= {k: v.item() for k, v in losses.items()}
|
||||
stats |= self.gather_attribute("scalar")
|
||||
|
@ -194,6 +296,8 @@ class Engines(dict[str, Engine]):
|
|||
def setup(self):
|
||||
self._global_step = 0
|
||||
self._micro_step = 0
|
||||
self._batch_size = 0
|
||||
self._global_samples = 0
|
||||
|
||||
@property
|
||||
def global_step(self):
|
||||
|
@ -203,6 +307,14 @@ class Engines(dict[str, Engine]):
|
|||
def micro_step(self):
|
||||
return self._micro_step
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
@property
|
||||
def global_samples(self):
|
||||
return self._global_samples
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
ret = {}
|
||||
for engine in self.values():
|
||||
|
@ -213,6 +325,50 @@ class Engines(dict[str, Engine]):
|
|||
for engine in self.values():
|
||||
engine.dispatch_attribute(*args, **kwargs)
|
||||
|
||||
def export(self, userdata={}, callback=None, dtype=None, format=None):
|
||||
if not format:
|
||||
format = cfg.weights_format
|
||||
format = format.lower()
|
||||
|
||||
if dtype is None:
|
||||
dtype = cfg.trainer.dtype
|
||||
|
||||
for name, engine in self.items():
|
||||
module = engine.module.state_dict()
|
||||
lora = None
|
||||
save_path = cfg.ckpt_dir / name / f"fp32.{format}"
|
||||
config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
|
||||
|
||||
# safety
|
||||
for k, v in module.items():
|
||||
module[k] = v.to(dtype)
|
||||
|
||||
if cfg.lora is not None:
|
||||
lora, module = lora_get_state_dict( module, split = True )
|
||||
save_path = cfg.ckpt_dir / cfg.lora.full_name / f"fp32.{format}"
|
||||
|
||||
state_dict = {
|
||||
'module': module,
|
||||
'lora': lora,
|
||||
"stats": {
|
||||
"global_step": engine.global_step,
|
||||
"micro_step": engine.micro_step,
|
||||
"global_samples": engine.global_samples,
|
||||
"tokens_processed": engine.tokens_processed,
|
||||
},
|
||||
"userdata": userdata,
|
||||
"config": config.__dict__
|
||||
}
|
||||
|
||||
if lora is None:
|
||||
del state_dict['lora']
|
||||
|
||||
if callback:
|
||||
state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
|
||||
|
||||
torch_save(state_dict, save_path)
|
||||
_logger.info(f"Exported {name} to {save_path}")
|
||||
|
||||
def save_checkpoint(self, tag=None):
|
||||
if not tag:
|
||||
tag = cfg.trainer.save_tag
|
||||
|
@ -222,47 +378,67 @@ class Engines(dict[str, Engine]):
|
|||
|
||||
cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
for name, engine in self.items():
|
||||
engine.save_checkpoint(cfg.ckpt_dir / name, tag=tag)
|
||||
if not engine._training:
|
||||
continue
|
||||
|
||||
def load_checkpoint(self, tag=None):
|
||||
save_dir = cfg.ckpt_dir / name
|
||||
try:
|
||||
engine.save_checkpoint(save_dir, tag=tag)
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to save checkpoint for engine {name}: {str(e)}')
|
||||
|
||||
# might be better to prune before saving for safety; [:0] returns an empty list, though I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
|
||||
if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
|
||||
checkpoints = [ d for d in list(save_dir.glob("*")) if d.is_dir() ]
|
||||
checkpoints.sort(key=lambda x: x.stat().st_mtime)
|
||||
checkpoints = checkpoints[:-cfg.trainer.keep_last_checkpoints]
|
||||
for d in checkpoints:
|
||||
if not d.is_dir() or not d.exists():
|
||||
continue
|
||||
_logger.info(f"Removing {d}")
|
||||
for p in d.iterdir():
|
||||
p.unlink()
|
||||
d.rmdir()
|
||||
|
||||
def load_checkpoint(self, tag=None, training=True):
|
||||
if not tag:
|
||||
tag = cfg.trainer.load_tag
|
||||
|
||||
for name, engine in self.items():
|
||||
load_dir = cfg.ckpt_dir / name
|
||||
|
||||
engine.load_checkpoint(
|
||||
tag=tag,
|
||||
load_dir=load_dir,
|
||||
load_module_strict=cfg.trainer.strict_loading,
|
||||
load_optimizer_states=cfg.trainer.load_states,
|
||||
load_lr_scheduler_states=cfg.trainer.load_states,
|
||||
load_optimizer_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
|
||||
load_lr_scheduler_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
|
||||
load_module_only=cfg.trainer.load_module_only,
|
||||
)
|
||||
if cfg.trainer.restart_step_count:
|
||||
engine.global_steps = 0
|
||||
engine.micro_steps = 0
|
||||
engine.global_samples = 0
|
||||
engine.tokens_processed = 0
|
||||
|
||||
# update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
|
||||
if cfg.hyperparameters.scheduler_type == "":
|
||||
self.set_lr(cfg.hyperparameters.learning_rate)
|
||||
|
||||
self._update_global_step()
|
||||
self._update_micro_step()
|
||||
self._update()
|
||||
|
||||
def set_lr(self, lr):
|
||||
for engine in self.values():
|
||||
if not engine._training:
|
||||
continue
|
||||
engine.set_lr(lr)
|
||||
|
||||
def _update_global_step(self):
|
||||
def _update(self):
|
||||
for engine in self.values():
|
||||
self._global_step = max(self._global_step, engine.global_step)
|
||||
|
||||
def _update_micro_step(self):
|
||||
for engine in self.values():
|
||||
self._micro_step = max(self._micro_step, engine.micro_step)
|
||||
|
||||
def train_batch_size(self):
|
||||
batch_size = 0
|
||||
for engine in self.values():
|
||||
batch_size = max(batch_size, engine.train_batch_size())
|
||||
self._batch_size = max(self._batch_size, engine.batch_size)
|
||||
self._global_samples = max(self._global_samples, engine.global_samples)
|
||||
|
||||
def eval(self):
|
||||
for engine in self.values():
|
||||
|
@ -279,7 +455,10 @@ class Engines(dict[str, Engine]):
|
|||
stats.update(flatten_dict({ name.split("-")[0]: stat }))
|
||||
return stats
|
||||
|
||||
def step(self, batch, feeder: TrainFeeder = default_feeder, device=cfg.get_device()):
|
||||
def quit(self):
|
||||
cleanup_distributed()
|
||||
|
||||
def step(self, batch, feeder: TrainFeeder = default_feeder):
|
||||
total_elapsed_time = 0
|
||||
|
||||
stats: Any = dict()
|
||||
|
@ -287,10 +466,11 @@ class Engines(dict[str, Engine]):
|
|||
if cfg.trainer.gc_mode == 'step':
|
||||
do_gc()
|
||||
|
||||
batch = to_device(batch, device)
|
||||
|
||||
for name, engine in self.items():
|
||||
#torch.cuda.synchronize()
|
||||
if not engine._training:
|
||||
continue
|
||||
|
||||
device = engine.device
|
||||
|
||||
if cfg.trainer.gc_mode == 'substep':
|
||||
do_gc()
|
||||
|
@ -298,9 +478,8 @@ class Engines(dict[str, Engine]):
|
|||
start_time = time.time()
|
||||
|
||||
tries = 4
|
||||
n_ooms = torch.zeros([], device=cfg.device)
|
||||
n_ooms = torch.zeros([], device=device)
|
||||
|
||||
if cfg.trainer.aggressive_optimizations:
|
||||
batch = to_device(batch, device)
|
||||
|
||||
if not cfg.trainer.check_for_oom:
|
||||
|
@ -311,7 +490,7 @@ class Engines(dict[str, Engine]):
|
|||
res = feeder( engine=engine, batch=batch )
|
||||
break
|
||||
except RuntimeError as e:
|
||||
print("Forward", str(e))
|
||||
_logger.error(f"Forward: {str(e)}")
|
||||
|
||||
if "out of memory" not in str(e):
|
||||
self.save_checkpoint()
|
||||
|
@ -329,6 +508,7 @@ class Engines(dict[str, Engine]):
|
|||
do_gc()
|
||||
continue
|
||||
|
||||
if world_size() > 1:
|
||||
all_reduce(n_ooms)
|
||||
if n_ooms.item() > 0:
|
||||
self.save_checkpoint()
|
||||
|
@ -340,7 +520,7 @@ class Engines(dict[str, Engine]):
|
|||
loss, engine_stats = res
|
||||
engine_stats |= self.gather_attribute("scalar")
|
||||
|
||||
n_ooms = torch.zeros([], device=cfg.device)
|
||||
n_ooms = torch.zeros([], device=device)
|
||||
|
||||
if cfg.trainer.aggressive_optimizations:
|
||||
batch = to_device(batch, 'cpu')
|
||||
|
@ -348,10 +528,11 @@ class Engines(dict[str, Engine]):
|
|||
if not cfg.trainer.check_for_oom:
|
||||
engine.backward(loss)
|
||||
else:
|
||||
# to-do: properly handle when one GPU throws an OOM because it just halts
|
||||
try:
|
||||
engine.backward(loss)
|
||||
except RuntimeError as e:
|
||||
print("Backwards:", str(e))
|
||||
_logger.error(f"Backwards: {str(e)}")
|
||||
|
||||
if "out of memory" not in str(e):
|
||||
self.save_checkpoint()
|
||||
|
@ -359,9 +540,12 @@ class Engines(dict[str, Engine]):
|
|||
|
||||
n_ooms += 1
|
||||
|
||||
if world_size() > 1:
|
||||
all_reduce(n_ooms)
|
||||
|
||||
if n_ooms.item() > 0:
|
||||
self.save_checkpoint()
|
||||
|
||||
raise RuntimeError("Out of memory during backwards pass!")
|
||||
|
||||
engine.step()
|
||||
|
@ -370,27 +554,36 @@ class Engines(dict[str, Engine]):
|
|||
|
||||
elapsed_time = time.time() - start_time
|
||||
total_elapsed_time += elapsed_time
|
||||
grad_norm = engine.get_global_grad_norm()
|
||||
loss_scale = 1
|
||||
if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
|
||||
loss_scale = engine.optimizer.loss_scale
|
||||
|
||||
if grad_norm is not None:
|
||||
grad_norm /= loss_scale
|
||||
|
||||
stats.update(
|
||||
flatten_dict(
|
||||
{
|
||||
name.split("-")[0]: dict(
|
||||
loss=loss.item(),
|
||||
**engine_stats,
|
||||
lr=engine.get_lr()[0],
|
||||
grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
|
||||
grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
|
||||
loss_scale=loss_scale if loss_scale != 1 else None,
|
||||
elapsed_time=elapsed_time,
|
||||
engine_step=engine.global_step,
|
||||
**engine_stats,
|
||||
samples_processed=engine.global_samples,
|
||||
tokens_processed=engine.tokens_processed,
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
self._update_global_step()
|
||||
self._update_micro_step()
|
||||
stats["batch_size"] = self.train_batch_size() # len(batch["text"])
|
||||
self._update()
|
||||
|
||||
if len(self.keys()) > 1:
|
||||
stats["elapsed_time"] = total_elapsed_time
|
||||
stats["wall_time"] = time.time()
|
||||
stats["global_step"] = self.global_step
|
||||
|
||||
stats["it"] = self.global_step
|
||||
|
||||
return stats
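
A minimal sketch of inspecting a checkpoint written by `Engines.export()` above, assuming the default `pth` weights format; the checkpoint path is a placeholder:

# sketch: peek at an exported checkpoint; "./ckpt/model/fp32.pth" is a placeholder path
import torch

state = torch.load("./ckpt/model/fp32.pth", map_location="cpu")
weights = state["module"]                      # plain state dict of the model itself
stats = state.get("stats", {})                 # step/sample counters saved alongside
print(stats.get("global_step"), stats.get("tokens_processed"))
print(list(weights.keys())[:5])                # first few parameter names
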
|
||||
|
|
|
@ -25,29 +25,72 @@ from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distr
|
|||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
from ..utils.distributed import init_distributed, distributed_initialized
|
||||
from ..utils import wrapper as ml
|
||||
|
||||
from ..models.lora import freeze_non_lora_weights
|
||||
|
||||
if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
|
||||
init_distributed(init_deepspeed_dist)
|
||||
|
||||
class Engine(DeepSpeedEngine):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs['config'] = cfg.trainer.deepspeed.get_ds_cfg(model=kwargs['model'])
|
||||
self.hyper_config = None
|
||||
if 'hyper_config' in kwargs:
|
||||
self.hyper_config = kwargs['hyper_config']
|
||||
kwargs.pop("hyper_config")
|
||||
|
||||
kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
|
||||
kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
|
||||
|
||||
stats = {
|
||||
"global_step": 0,
|
||||
"micro_step": 0,
|
||||
"global_samples": 0,
|
||||
"tokens_processed": 0,
|
||||
}
|
||||
|
||||
# kwargs['stats'] = None will return None when popped
|
||||
maybe_stats = kwargs.pop('stats', stats)
|
||||
if maybe_stats is not None:
|
||||
stats = maybe_stats
|
||||
|
||||
super().__init__(None, *args, **kwargs)
|
||||
self._frozen_params = set()
|
||||
|
||||
def freeze(self):
|
||||
for p in self.module.parameters():
|
||||
if p.requires_grad:
|
||||
p.requires_grad_(False)
|
||||
self._frozen_params.add(p)
|
||||
self.global_steps = stats["global_step"]
|
||||
self.micro_steps = stats["micro_step"]
|
||||
self.global_samples = stats["global_samples"]
|
||||
self.tokens_processed = stats["tokens_processed"]
|
||||
|
||||
self.max_nan_losses = 8
|
||||
self.current_batch_size = 0
|
||||
|
||||
def freeze(self, freeze_all=True):
|
||||
# freeze non-LoRA params if requested
|
||||
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
|
||||
frozen_params = freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
|
||||
for param in frozen_params:
|
||||
self._frozen_params.add( param )
|
||||
|
||||
return
|
||||
|
||||
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
|
||||
raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
|
||||
|
||||
for name, param in self.module.named_parameters():
|
||||
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
|
||||
param.requires_grad_(False)
|
||||
self._frozen_params.add(param)
|
||||
|
||||
def unfreeze(self):
|
||||
for p in self._frozen_params:
|
||||
p.requires_grad_(True)
|
||||
for param in self._frozen_params:
|
||||
param.requires_grad_(True)
|
||||
self._frozen_params.clear()
|
||||
|
||||
@property
|
||||
def _training(self):
|
||||
return self.hyper_config.training
|
||||
|
||||
@property
|
||||
def global_step(self):
|
||||
return self.global_steps
|
||||
|
@ -56,6 +99,10 @@ class Engine(DeepSpeedEngine):
|
|||
def micro_step(self):
|
||||
return self.micro_steps
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
return gather_attribute(self.module, *args, **kwargs)
|
||||
|
||||
|
@ -66,17 +113,40 @@ class Engine(DeepSpeedEngine):
|
|||
try:
|
||||
if hasattr(self.optimizer, 'param_groups'):
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
param_group["d_coeff" if "d_coeff" in param_group else "lr"] = lr
|
||||
else:
|
||||
self.optimizer.set_lr(lr)
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
_logger.warning(str(e))
|
||||
|
||||
# we'll just have to live with the LoRA weights living within our main weights
|
||||
# they're easy to extract anyways
|
||||
def load_checkpoint(self, load_dir, **kwargs ):
|
||||
# override to load the lora instead
|
||||
if cfg.lora is not None:
|
||||
load_dir = cfg.ckpt_dir / cfg.lora.full_name
|
||||
|
||||
return super().load_checkpoint( load_dir, **kwargs )
|
||||
|
||||
def save_checkpoint(self, save_dir, **kwargs ):
|
||||
# override to save the lora instead
|
||||
if cfg.lora is not None:
|
||||
save_dir = cfg.ckpt_dir / cfg.lora.full_name
|
||||
|
||||
return super().save_checkpoint( save_dir, **kwargs )
|
||||
|
||||
def traverse(self, *args, **kwargs):
|
||||
with ml.autocast():
|
||||
self.forward(*args, **kwargs)
|
||||
|
||||
losses = self.gather_attribute("loss")
|
||||
loss = torch.stack([*losses.values()]).sum()
|
||||
|
||||
if torch.isnan(loss).any():
|
||||
self.max_nan_losses = self.max_nan_losses - 1
|
||||
if self.max_nan_losses < 0:
|
||||
raise RuntimeError("Too many NaN losses detected.")
|
||||
|
||||
stats = {}
|
||||
stats |= {k: v.item() for k, v in losses.items()}
|
||||
stats |= self.gather_attribute("scalar")
|
||||
|
|
|
@ -1,31 +1,67 @@
|
|||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.nn
|
||||
|
||||
from .data import get_symmap
|
||||
from .train import load_engines
|
||||
from .engines import load_engines
|
||||
from .config import cfg
|
||||
from .models.lora import lora_get_state_dict
|
||||
from .utils.io import torch_save, torch_load
|
||||
|
||||
def load_models():
|
||||
models = {}
|
||||
engines = load_engines()
|
||||
for name in engines:
|
||||
model = engines[name].module.cpu()
|
||||
models[name] = model
|
||||
# yanks a LoRA from the training checkpoint
|
||||
def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
|
||||
if dtype is None:
|
||||
dtype = cfg.inference.dtype
|
||||
|
||||
return models
|
||||
format = save_path.stem[1:]
|
||||
|
||||
lora = state_dict["lora"] if "lora" in state_dict else None
|
||||
# should always be included, but just in case
|
||||
if lora is None and "module" in state_dict:
|
||||
lora, module = lora_get_state_dict( state_dict["module"], split = True )
|
||||
state_dict["module"] = module
|
||||
|
||||
if "lora" in state_dict:
|
||||
state_dict["lora"] = None
|
||||
|
||||
# should raise an exception since there's nothing to extract, or at least a warning
|
||||
if not lora:
|
||||
return state_dict
|
||||
|
||||
# save lora specifically
|
||||
# should probably export other attributes, similar to what SD LoRAs do
|
||||
save_path = save_path.parent / f"lora.{format}"
|
||||
torch_save( {
|
||||
"module": lora,
|
||||
"config": cfg.lora.__dict__ if cfg.lora is not None else None,
|
||||
}, save_path )
|
||||
|
||||
return state_dict
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser("Save trained model to path.")
|
||||
parser.add_argument("path")
|
||||
args = parser.parse_args()
|
||||
parser.add_argument("--module-only", action='store_true')
|
||||
parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
|
||||
parser.add_argument("--format", type=str, default="pth") # set target format to export weights under
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
models = load_models()
|
||||
for name in models:
|
||||
model = models[name]
|
||||
if args.format.lower() not in ["sft", "safetensors", "pt", "pth"]:
|
||||
raise Exception(f"Unknown requested format: {args.format}")
|
||||
|
||||
outpath = f'{args.path}/{name}.pt'
|
||||
torch.save(model, outpath)
|
||||
print(f"Exported {name} to {outpath}")
|
||||
if args.module_only:
|
||||
cfg.trainer.load_module_only = True
|
||||
|
||||
if args.dtype != "auto":
|
||||
cfg.trainer.weight_dtype = args.dtype
|
||||
|
||||
# necessary to ensure we are actually exporting the weights right
|
||||
cfg.inference.backend = cfg.trainer.backend
|
||||
|
||||
engines = load_engines(training=False) # to ignore loading optimizer state
|
||||
|
||||
callback = None
|
||||
engines.export(userdata={"symmap": get_symmap()}, callback=callback, format=args.format)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,53 +1,103 @@
|
|||
import torch
|
||||
import torchaudio
|
||||
import soundfile
|
||||
import time
|
||||
import logging
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
from torch import Tensor
|
||||
from einops import rearrange
|
||||
from pathlib import Path
|
||||
|
||||
from .utils import to_device, set_seed, wrapper as ml
|
||||
from PIL import Image, ImageDraw
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from .config import cfg
|
||||
from .export import load_models
|
||||
from .data import get_symmap, _get_symbols
|
||||
from .config import cfg, Config
|
||||
from .models import get_models
|
||||
from .engines import load_engines, deepspeed_available
|
||||
from .data import get_symmap, tokenize
|
||||
|
||||
if deepspeed_available:
|
||||
import deepspeed
|
||||
|
||||
class Classifier():
|
||||
def __init__( self, width=300, height=80, config=None, ckpt=None, device=cfg.get_device(), dtype="float32" ):
|
||||
def __init__( self, config=None, device=None, amp=None, dtype=None, attention=None ):
|
||||
self.loading = True
|
||||
|
||||
# yes I can just grab **kwargs and forward them here
|
||||
self.load_config( config=config, device=device, amp=amp, dtype=dtype, attention=attention )
|
||||
self.load_model()
|
||||
|
||||
self.loading = False
|
||||
|
||||
def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
|
||||
if config:
|
||||
_logger.info(f"Loading YAML: {config}")
|
||||
cfg.load_yaml( config )
|
||||
|
||||
self.loading = True
|
||||
try:
|
||||
cfg.format( training=False )
|
||||
cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
|
||||
except Exception as e:
|
||||
raise e # throw an error because I'm tired of silent errors messing things up for me
|
||||
|
||||
if amp is None:
|
||||
amp = cfg.inference.amp
|
||||
if dtype is None or dtype == "auto":
|
||||
dtype = cfg.inference.weight_dtype
|
||||
if device is None:
|
||||
device = cfg.device
|
||||
|
||||
cfg.device = device
|
||||
cfg.mode = "inferencing"
|
||||
cfg.trainer.backend = cfg.inference.backend
|
||||
cfg.trainer.weight_dtype = dtype
|
||||
cfg.inference.weight_dtype = dtype
|
||||
|
||||
self.device = device
|
||||
self.dtype = cfg.inference.dtype
|
||||
self.amp = amp
|
||||
|
||||
if ckpt:
|
||||
self.load_model_from_ckpt( ckpt )
|
||||
else:
|
||||
self.load_model_from_cfg( config )
|
||||
self.model_kwargs = {}
|
||||
|
||||
self.model.eval()
|
||||
def load_model( self ):
|
||||
load_engines.cache_clear()
|
||||
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.engines = load_engines(training=False, **self.model_kwargs)
|
||||
for name, engine in self.engines.items():
|
||||
if self.dtype != torch.int8:
|
||||
engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||
|
||||
self.engines.eval()
|
||||
self.symmap = get_symmap()
|
||||
|
||||
self.width = 300
|
||||
self.height = 80
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Resize((self.height, self.width)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
self.loading = False
|
||||
_logger.info("Loaded model")
|
||||
|
||||
def load_model_from_ckpt( self, ckpt ):
|
||||
self.ckpt = ckpt
|
||||
@torch.inference_mode()
|
||||
def inference( self, image, temperature=1.0 ):
|
||||
model = None
|
||||
|
||||
self.model = torch.load(self.ckpt).to(self.device)
|
||||
|
||||
def load_model_from_cfg( self, config_path ):
|
||||
|
||||
models = load_models()
|
||||
for name in models:
|
||||
model = models[name]
|
||||
self.model = model.to(self.device)
|
||||
for name, engine in self.engines.items():
|
||||
model = engine.module
|
||||
break
|
||||
|
||||
def inference( self, image, temperature=1.0 ):
|
||||
image = self.transform(image).to(self.device)
|
||||
image = self.transform(image).to(self.device).to(self.dtype)
|
||||
|
||||
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
|
||||
answer = model( image=[image], sampling_temperature=temperature )
|
||||
|
||||
answer = [ "".join(answer) ]
|
||||
|
||||
answer = self.model( image=[image], sampling_temperature=temperature )
|
||||
answer = answer[0].replace('<s>', "").replace("</s>", "") # it would be better to just slice between these, but I can't be assed
|
||||
|
||||
return answer
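
A minimal usage sketch for the reworked `Classifier`, assuming it is importable from the package's inference module and that the YAML points at a trained checkpoint (the module path and file paths are placeholders):

# sketch: programmatic inference; module path and file paths are assumptions
from PIL import Image
from image_classifier.inference import Classifier

classifier = Classifier(config="./data/config.yaml")
image = Image.open("./data/path-to-your-image.png").convert("RGB")
print(classifier.inference(image, temperature=1.0))
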
|
9
image_classifier/models/__init__.py
Executable file → Normal file
|
@ -1,18 +1,19 @@
|
|||
from .base import Model
|
||||
|
||||
def get_model(cfg):
|
||||
def get_model(cfg, training=False):
|
||||
name = cfg.name
|
||||
|
||||
model = Model(
|
||||
n_tokens=cfg.tokens,
|
||||
n_len=cfg.len,
|
||||
d_model=cfg.dim,
|
||||
d_resnet=cfg.resnet,
|
||||
)
|
||||
model._cfg = cfg
|
||||
model.config = cfg
|
||||
|
||||
print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||
|
||||
return model
|
||||
|
||||
def get_models(models):
|
||||
return { model.full_name: get_model(model) for model in models }
|
||||
def get_models(models, training=False):
|
||||
return { model.full_name: get_model(model, training=training) for model in models }
|
||||
|
|
|
@ -12,7 +12,7 @@ from torch.distributions import Categorical
|
|||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
|
||||
from torchvision.models import resnet18
|
||||
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
|
||||
|
||||
from ..data import get_symmap
|
||||
|
||||
|
@ -20,12 +20,12 @@ class Model(nn.Module):
|
|||
def __init__(
|
||||
self,
|
||||
n_tokens: int = 0, # number of token types
|
||||
n_len: int = 6, # how long a sequence can be
|
||||
n_len: int = 12, # how long a sequence can be
|
||||
d_model: int = 512,
|
||||
d_resnet: int = 18,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
_symmap = get_symmap()
|
||||
self.symmap = { f'{v}': k for k, v in _symmap.items() }
|
||||
self.symmap['0'] = ""
|
||||
|
@ -36,8 +36,26 @@ class Model(nn.Module):
|
|||
self.n_tokens = n_tokens
|
||||
self.n_len = n_len + 2 # start/stop tokens
|
||||
self.d_model = d_model
|
||||
self.d_resnet = d_resnet
|
||||
|
||||
self.resnet = resnet18(pretrained=False)
|
||||
ResNet = resnet18
|
||||
if d_resnet == 18:
|
||||
print("Using resnet18")
|
||||
ResNet = resnet18
|
||||
elif d_resnet == 34:
|
||||
print("Using resnet34")
|
||||
ResNet = resnet34
|
||||
elif d_resnet == 50:
|
||||
print("Using resnet50")
|
||||
ResNet = resnet50
|
||||
elif d_resnet == 101:
|
||||
print("Using resnet101")
|
||||
ResNet = resnet101
|
||||
elif d_resnet == 152:
|
||||
print("Using resnet152")
|
||||
ResNet = resnet152
|
||||
|
||||
self.resnet = ResNet(pretrained=False)
|
||||
self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len )
|
||||
|
||||
self.accuracy_metric = MulticlassAccuracy(
|
||||
|
@ -61,33 +79,29 @@ class Model(nn.Module):
|
|||
|
||||
sampling_temperature: float = 1.0,
|
||||
):
|
||||
x_list = torch.stack( image, dim=0 )
|
||||
logits = self.resnet( torch.stack( image, dim=0 ) )
|
||||
logits = logits.view(logits.size(0), self.n_len, self.n_tokens).permute(1, 0, 2)
|
||||
|
||||
x = self.resnet( x_list )
|
||||
y = x.view(x.size(0), self.n_len, self.n_tokens)
|
||||
|
||||
# either of these should do, but my VALL-E forward pass uses this, so might as well keep to it
|
||||
# pred = y.argmax(dim=2)
|
||||
pred = Categorical(logits=y / sampling_temperature).sample()
|
||||
|
||||
answer = [ "".join([ self.symmap[f'{x.item()}'] for x in t ]) for t in pred ]
|
||||
pred = logits.argmax(dim=2)
|
||||
|
||||
if text is not None:
|
||||
y_list = rearrange(pad_sequence(text), "t b -> b t")
|
||||
|
||||
loss = 0
|
||||
labels = rearrange(pad_sequence(text), "t b -> b t").permute(1, 0)
|
||||
loss = []
|
||||
for i in range(self.n_len):
|
||||
if i >= y_list.shape[1]:
|
||||
if i >= labels.shape[0]:
|
||||
break
|
||||
loss += F.cross_entropy( y[:, i], y_list[:, i] )
|
||||
loss.append( F.cross_entropy(logits[i], labels[i]) )
|
||||
|
||||
self.loss = dict(
|
||||
nll=loss
|
||||
nll = sum( loss ) / len( loss ),
|
||||
)
|
||||
|
||||
self.stats = dict(
|
||||
acc = self.accuracy_metric( pred, y_list ),
|
||||
precision = self.precision_metric( pred, y_list ),
|
||||
acc = self.accuracy_metric( pred, labels ),
|
||||
precision = self.precision_metric( pred, labels ),
|
||||
)
|
||||
|
||||
|
||||
answer = [ "".join([ self.symmap[f'{x.item()}'] for x in t ]) for t in pred ]
|
||||
|
||||
return answer
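
The head above treats the output of `resnet.fc` as `n_len` independent `n_tokens`-way classifications; a standalone sketch of that reshaping with toy sizes (the sizes are arbitrary):

# sketch: how the single fc output is split into per-position logits (toy sizes)
import torch

batch, n_len, n_tokens = 2, 14, 37                              # arbitrary sizes for illustration
fc_out = torch.randn(batch, n_tokens * n_len)                   # what resnet.fc produces per image
logits = fc_out.view(batch, n_len, n_tokens).permute(1, 0, 2)   # (n_len, batch, n_tokens), as in forward()
pred = logits.argmax(dim=2)                                     # one token id per string position
print(pred.shape)                                               # torch.Size([14, 2])
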
|
214
image_classifier/models/lora.py
Normal file
|
@ -0,0 +1,214 @@
|
|||
# Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
from functools import partial
import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize

from transformers.pytorch_utils import Conv1D

from torch import Tensor, nn

import math
from typing import Optional, List

from ..utils import passes_policy

# LoRA Linear for replacement
# Pros: simple, just needs to reuse the replace_linear and copy weights
# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
class LoRALinear(nn.Linear):
	def __init__(
		self,

		in_features: int,
		out_features: int,
		bias: bool = True,

		rank: int = 4,
		alpha: int = 1,

		dropout: float = 0.1,
		merge_weights: bool = False,
		**kwargs,
	):
		super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs)

		self.rank = rank
		self.alpha = alpha
		self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
		self.merge_weights = merge_weights
		self.merged = False
		self.enabled = True

		self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
		self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
		self.scaling = self.alpha / self.rank

		self.weight.requires_grad = False

		self.reset_parameters()

	def reset_parameters(self):
		super().reset_parameters()
		# super silly but necessary because nn.Linear's constructor calls this
		if hasattr(self, 'lora_A'):
			nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
			nn.init.zeros_( self.lora_B )

	def train(self, mode: bool = True):
		super().train(mode)

		# training, separate lora from base weights
		if mode and self.merge_weights and self.merged:
			self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
			self.merged = False

		# not training, merge lora to base weights
		if not mode and self.merge_weights and not self.merged:
			self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
			self.merged = True

	def forward(self, x: torch.Tensor):
		if not self.merged and self.enabled:
			result = F.linear(x, self.weight, bias=self.bias)
			result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
			return result

		return F.linear(x, self.weight, bias=self.bias)

	@classmethod
	def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
		if device is None:
			device = layer.weight.device
		if dtype is None:
			dtype = layer.weight.dtype
		return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)

# Uses parametrization to inject LoRA weights
# Pros: should work with any Linears
# Cons: TBD
class ParameterizedLoRA(nn.Module):
	def __init__(
		self,

		in_features: int,
		out_features: int,
		bias: bool = True,

		rank: int = 4,
		alpha: int = 1,

		dropout: float = 0.1,

		device = None,
		dtype = None
	):
		super().__init__()
		self.rank = rank
		self.alpha = alpha
		self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x

		self.lora_B = nn.Parameter( torch.zeros( (out_features, rank) ) ).to( device=device, dtype=dtype )
		self.lora_A = nn.Parameter( torch.zeros( (rank, in_features) ) ).to( device=device, dtype=dtype )
		self.scaling = self.alpha / self.rank
		self.enabled = True

		self.reset_parameters()

	def reset_parameters(self):
		nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
		nn.init.zeros_( self.lora_B )

	def forward(self, x: torch.Tensor):
		if self.enabled:
			return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
		return x

	@classmethod
	def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
		if device is None:
			device = layer.weight.device
		if dtype is None:
			dtype = layer.weight.dtype
		# swap because we're feeding the output as our input
		# M$'s LoRA class arranges things to where this isn't necessary
		return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)

	@classmethod
	def from_conv1d( cls, layer, device = None, dtype = None, **kwargs ):
		if device is None:
			device = layer.weight.device
		if dtype is None:
			dtype = layer.weight.dtype

		in_channels, out_channels = layer.weight.shape
		# swap because we're feeding the output as our input
		# M$'s LoRA class arranges things to where this isn't necessary
		return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)

def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
	device = next(model.parameters()).device
	dtype = next(model.parameters()).dtype

	modules = [ k.split('.') for k, m in model.named_modules() if passes_policy( policy, k ) ]

	for *parent, k in modules:
		name = '.'.join(parent)
		layer = getattr( model.get_submodule(name), k )

		if isinstance( layer, nn.Linear ):
			target = nn.Linear
			klass = ParameterizedLoRA if use_parametrize else LoRALinear
			replacer = klass.from_linear
		elif isinstance( layer, nn.Conv1d ):
			target = nn.Conv1d
			klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
			replacer = klass.from_conv1d
		elif isinstance( layer, Conv1D ):
			target = Conv1D
			klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
			replacer = klass.from_conv1d
		else:
			continue

		replacement = replacer( layer, device=device, dtype=dtype, **kwargs )

		if use_parametrize:
			parametrize.register_parametrization( layer, "weight", replacement )
		else:
			setattr( model.get_submodule(name), k, replacement )

	return enable_lora( model )

def enable_lora( model, mode = True ):
	for name, module in model.named_modules():
		if not isinstance( module, ParameterizedLoRA ) and not isinstance( module, LoRALinear ):
			continue
		module.enabled = mode
	return model

def disable_lora( model ):
	return enable_lora( model, False )

def freeze_non_lora_weights( model, embeddings = False ):
	frozen_params = []

	for name, param in model.named_parameters():
		should = 'lora_' in name or (embeddings and "_emb" in name)

		param.requires_grad_(should)

		if not should:
			frozen_params.append( param )

	return frozen_params

def lora_get_state_dict( state_dict, split = True ):
	lora = { name: param for name, param in state_dict.items() if "lora_" in name }
	if not split:
		return lora

	return lora, { name: param for name, param in state_dict.items() if "lora_" not in name }

def lora_load_state_dict( model, state_dict ):
	return model.load_state_dict( state_dict, strict = False )
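A rough usage sketch of the helpers above: inject adapters, train only the low-rank weights, then split them out of the state dict. The tiny `nn.Sequential` stand-in model and the `policy` contents are made up for illustration; in the actual trainer these come from the config.

```python
# Hypothetical usage of apply_lora / freeze_non_lora_weights / lora_get_state_dict.
import torch
from torch import nn
from image_classifier.models.lora import (
	apply_lora, disable_lora, freeze_non_lora_weights, lora_get_state_dict,
)

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))  # stand-in model

# replace every Linear whose module name passes the policy, then train only the LoRA weights
model = apply_lora(model, policy={"include": ["0", "2"]}, rank=4, alpha=8)
freeze_non_lora_weights(model)

# export just the low-rank weights, separate from the frozen base weights
lora_sd, base_sd = lora_get_state_dict(model.state_dict(), split=True)

# toggle the adapters off to compare against the unmodified base model
disable_lora(model)
```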
image_classifier/plot.py (new file)
@@ -0,0 +1,120 @@
#!/usr/bin/env python3

import argparse
import json
import re
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd

from .config import cfg

def plot(paths, args):
	dfs = []

	for path in paths:
		with open(path, "r") as f:
			text = f.read()

		rows = []

		pattern = r"(\{.+?\})\.\n"

		for row in re.findall(pattern, text, re.DOTALL):
			try:
				row = json.loads(row)
			except Exception as e:
				continue

			for model in args.models:
				if f'{model.name}.{args.xs}' not in row:
					continue
				rows.append(row)
				break

		df = pd.DataFrame(rows)

		if "name" in df:
			df["name"] = df["name"].fillna("train")
		else:
			df["name"] = "train"

		df["group"] = str(path.parents[args.group_level])
		df["group"] = df["group"] + "/" + df["name"]

		dfs.append(df)

	df = pd.concat(dfs)

	if args.min_x is not None:
		for model in args.models:
			df = df[args.min_x < df[f'{model.name}.{args.xs}']]

	if args.max_x is not None:
		for model in args.models:
			df = df[df[f'{model.name}.{args.xs}'] < args.max_x]

	for gtag, gdf in sorted(
		df.groupby("group"),
		key=lambda p: (p[0].split("/")[-1], p[0]),
	):
		for model in args.models:
			x = f'{model.name}.{args.xs}'
			for ys in args.ys:
				y = f'{model.name}.{ys}'

				if gdf[y].isna().all():
					continue

				if args.min_y is not None:
					gdf = gdf[args.min_y < gdf[y]]
				if args.max_y is not None:
					gdf = gdf[gdf[y] < args.max_y]

				gdf[y] = gdf[y].ewm(10).mean()

				gdf.plot(
					x=x,
					y=y,
					label=f"{y}",
					ax=plt.gca(),
					marker="x" if len(gdf) < 100 else None,
					alpha=0.7,
				)

	plt.gca().legend(
		loc="center left",
		fancybox=True,
		shadow=True,
		bbox_to_anchor=(1.04, 0.5),
	)


if __name__ == "__main__":
	parser = argparse.ArgumentParser()
	parser.add_argument("--xs", default="engine_step")
	parser.add_argument("--ys", nargs="+", default="")
	parser.add_argument("--model", nargs="+", default="*")

	parser.add_argument("--min-x", type=float, default=-float("inf"))
	parser.add_argument("--min-y", type=float, default=-float("inf"))
	parser.add_argument("--max-x", type=float, default=float("inf"))
	parser.add_argument("--max-y", type=float, default=float("inf"))

	parser.add_argument("--filename", default="log.txt")
	parser.add_argument("--group-level", default=1)
	args, unknown = parser.parse_known_args()

	path = cfg.rel_path / "logs"
	paths = path.rglob(f"./*/{args.filename}")

	args.models = [ model for model in cfg.model.get() if model.training and (args.model == "*" or model.name in args.model) ]

	if args.ys == "":
		args.ys = ["loss"]

	plot(paths, args)

	out_path = cfg.rel_path / "metrics.png"
	plt.savefig(out_path, bbox_inches="tight")
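The parser above pulls every trailing `{...}.` blob out of the training log and keeps rows that carry a `<model>.<xs>` key. Roughly, it expects lines like the one below; the specific key names (`model.engine_step`, `model.loss.nll`) are assumptions about how the flattened stats dict is named, not something this commit pins down.

```python
# Illustrative only: the kind of log line the pattern above matches.
import json, re

line = 'Training Metrics: {"model.engine_step": 1000, "model.loss.nll": 2.3145}.\n'
pattern = r"(\{.+?\})\.\n"

row = json.loads(re.findall(pattern, line, re.DOTALL)[0])
assert "model.engine_step" in row
```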
image_classifier/samplers.py (new file)
@@ -0,0 +1,204 @@
import math
import torch
import torch.nn.functional as F
import numpy as np

from torch import Tensor, einsum, nn

# Simple filter to modify a token's probability if it shows up in the past
# `one_time` will only apply the penalty once
# `decay` is a factor that will exponentially apply to how far away it is
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ):
	if factor == 1.0 or previous is None:
		return logits

	unique = set()
	priors = reversed(previous)
	for distance, token in enumerate(priors):
		# skip if we're only applying the decay once
		if one_time and token in unique:
			continue

		distance += 1
		logits[:, token] /= factor * (distance ** decay)

		# add to set if we care about it
		if one_time:
			unique.add(token)

	return logits

# Simple "filter" that modifies the logit for the stop token, based on the sequence length
# `length` is the length of the sequence currently
# `factor` is the power the length is raised to, so values > 0 will yield longer sequences, values < 0 will yield shorter sequences
# `token` is the stop token.
def length_penalize( logits, length, factor=0.0, token=-1 ):
	if factor == 0.0:
		return logits

	logits[:, token] /= (length ** factor)
	return logits

# Simple way to ban tokens
def ban_tokens( logits, tokens ):
	for token in tokens:
		# token not in logits
		if logits.shape[-1] >= token:
			continue
		logits[:, token] = -float("inf")
	return logits

# Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
	"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
	Args:
		logits: logits distribution shape (batch size, vocabulary size)
		if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
		if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
			Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
		Make sure we keep at least min_tokens per batch example in the output
	"""
	if top_k > 0:
		top_k = min(max(top_k, min_tokens), logits.size(-1))  # Safety check
		# Remove all tokens with a probability less than the last token of the top-k
		indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
		logits[indices_to_remove] = filter_value

	if top_p < 1.0:
		sorted_logits, sorted_indices = torch.sort(logits, descending=True)
		cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

		# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
		sorted_indices_to_remove = cumulative_probs > top_p
		if min_tokens > 1:
			# Keep at least min_tokens (set to min_tokens-1 because we add the first one below)
			sorted_indices_to_remove[..., :min_tokens] = 0
		# Shift the indices to the right to keep also the first token above the threshold
		sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
		sorted_indices_to_remove[..., 0] = 0

		# scatter sorted tensors to original indexing
		indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
		logits[indices_to_remove] = filter_value

	return logits

# credit to https://github.com/LostRuins/koboldcpp/pull/464 // https://github.com/kalomaze/koboldcpp/tree/dynamic-temp
def dynamic_temperature( logits, temperature=1.0, min_temperature = 0.0, k = 10, sigmoidCenterPoint = 0.5 ):
	# loop over logits[:], as the NAR will have logits.shape[0] > 1
	for i in range(logits.shape[0]):
		sum_exp = 0.0
		maximum = torch.max( logits[i] )
		for logit in logits[i]:
			sum_exp += math.exp( logit - maximum )

		prob_max_token_before_temp = 1.0 / sum_exp
		dynamic_temperature = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (prob_max_token_before_temp - sigmoidCenterPoint)))

		logits[i] /= dynamic_temperature

	return logits



# picks the top K tokens amongst a batch of logits
# logits: [Tensor] list of logits
# candidates: [(batch, token)] list, where batch indicates the index of the logits the given token is from
def top_k_logits_list( logits_list, k ):
	# ( batch, tokens ) => ( batch x tokens )
	logits = torch.cat( logits_list )
	candidates = list(torch.topk(logits.flatten(), k).indices.tolist()) # perform top-k across all logits
	for i, index in enumerate(candidates):
		t = []
		N = np.prod(logits.size())
		for n in logits.size():
			N //= n
			t.append(index // N)
			index %= N
		candidates[i] = tuple(t)
	return candidates


# Credit to: https://github.com/basusourya/mirostat/
# performs mirostat-based sampling
# logits: Tensor of logit probabilities
# state: the mirostat state
def mirostat_sample( logits, state = None ):
	def compute_k(prob, n, tau):
		num = 0
		den = 0
		for i in range(100):
			b = prob[i]/prob[i+1]
			t = (i+2)/(i+1)
			num += math.log(b)*math.log(t)
			den += math.log(t)**2

		s = num/den
		eps = s-1
		k = ((eps*(2**(tau)))/(1-n**(-eps)))**(1/s)
		k = round(k)
		return k

	if "max_surprise" not in state:
		state["max_surprise"] = state["tau"] * 2

	if "error_surprise" not in state:
		state["error_surprise"] = 0

	if "running_total_surprise" not in state:
		state["running_total_surprise"] = 0

	sorted_logits, sorted_indices = torch.sort( logits[-1, :], descending=True )
	prob_original = torch.softmax( sorted_logits, dim=-1 ).tolist()

	k = compute_k(prob_original, state["n"], state["max_surprise"]) + 1

	sorted_logits = sorted_logits[0:k]
	sorted_indices = sorted_indices[0:k]
	prob_topk = torch.softmax(sorted_logits, dim = 0)
	prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)

	state["index_surprise"] = math.log2(1/prob_original[prev_i])
	state["running_total_surprise"] += state["index_surprise"]
	state["error_surprise"] = state["index_surprise"] - state["tau"]
	state["max_surprise"] -= state["eta"] * state["error_surprise"]
	state["token"] = sorted_indices[prev_i]

	return state

# Credits to: https://github.com/oobabooga/text-generation-webui/pull/5677
# performs DRY sampling
# * (honestly it looks close to rep pen anyways but what do I know)
# `logits` are the scores used to sample against
# `previous` are the prior tokens to penalize with
# `factor` is the scalar multiplier
# `base` is the base number to raise to the (length - allowed_length)th power
# `allowed_length` limits the range to apply DRY to
def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2 ):
	if factor == 0.0 or previous is None:
		return logits

	lengths = {}
	for i, token in enumerate( previous ):
		length = 1
		while length < max(allowed_length, 50):
			j = i - length

			# Start of input reached.
			if j < 0:
				break

			# Start of match reached.
			if previous[j] != previous[-length-1]:
				break

			length += 1

		lengths[token] = max(length, lengths[token]) if token in lengths else length

	for token, length in lengths.items():
		if length < allowed_length:
			break
		logits[:, token] -= factor * base ** (length - allowed_length)

	return logits
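A quick sketch of how these filters are meant to compose into one sampling step. It only uses functions defined in this file; the shapes, penalty settings, and the token history are arbitrary stand-ins, and the actual model decides which filters it calls and in what order.

```python
# Hypothetical sampling step chaining the filters defined above.
import torch
from torch.distributions import Categorical
from image_classifier.samplers import reptition_penalize, top_k_top_p_filtering

batch, n_symbols = 1, 26
logits = torch.randn(batch, n_symbols)
previous = [3, 7, 7, 12]            # tokens sampled so far

logits = reptition_penalize(logits, previous, factor=1.2, decay=0.5)  # dampen repeats
logits = top_k_top_p_filtering(logits, top_k=8, top_p=0.9)            # prune the tail
token = Categorical(logits=logits / 1.0).sample()                     # temperature of 1.0
```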
@ -4,7 +4,7 @@ from .config import cfg
|
|||
from .data import create_train_val_dataloader
|
||||
|
||||
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
|
||||
from .utils.trainer import load_engines
|
||||
from .utils.distributed import is_global_leader
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
@ -12,14 +12,22 @@ import random
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import traceback
|
||||
import shutil
|
||||
|
||||
from collections import defaultdict
|
||||
from PIL import Image
|
||||
|
||||
from tqdm import tqdm
|
||||
import argparse
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def train_feeder(engine, batch):
|
||||
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
||||
batch_size = len(batch["text"])
|
||||
engine.current_batch_size = batch_size
|
||||
|
||||
engine( image=batch["image"], text=batch["text"] )
|
||||
|
||||
losses = engine.gather_attribute("loss")
|
||||
|
@ -31,34 +39,16 @@ def train_feeder(engine, batch):
|
|||
stats |= {k: v.item() for k, v in losses.items()}
|
||||
stats |= {k: v.item() for k, v in stat.items()}
|
||||
|
||||
engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
|
||||
|
||||
return loss, stats
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_eval(engines, eval_name, dl):
|
||||
engines_stats = {
|
||||
'eval': eval_name
|
||||
}
|
||||
|
||||
model = None
|
||||
names = []
|
||||
for name, engine in engines.items():
|
||||
names.append(name)
|
||||
model = engine
|
||||
break
|
||||
|
||||
|
||||
stats = defaultdict(list)
|
||||
stats['loss'] = []
|
||||
|
||||
def process( name, batch, resps_list ):
|
||||
for path, ref, hyp in zip(batch["path"], batch["text"], hyp):
|
||||
continue
|
||||
|
||||
for batch in tqdm(dl):
|
||||
batch: dict = to_device(batch, cfg.device)
|
||||
|
||||
res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
|
||||
|
||||
def process( name, batch, res, loss ):
|
||||
for path, ref, hyp in zip(batch["path"], batch["text"], res):
|
||||
hyp = hyp.replace('<s>', "").replace("</s>", "")
|
||||
hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / hyp).with_suffix(".png")
|
||||
|
@ -67,36 +57,74 @@ def run_eval(engines, eval_name, dl):
|
|||
image = Image.open(path).convert('RGB')
|
||||
image.save(hyp_path)
|
||||
|
||||
stats['loss'].append(loss)
|
||||
|
||||
processed = 0
|
||||
while processed < cfg.evaluation.size:
|
||||
batch = to_device(next(iter(dl)), cfg.device)
|
||||
|
||||
# limit to eval batch size in the event we somehow have a weird dataloader
|
||||
for key in batch.keys():
|
||||
batch[key] = batch[key][:cfg.evaluation.batch_size]
|
||||
|
||||
processed += len(batch["text"])
|
||||
|
||||
for name in engines:
|
||||
engine = engines[name]
|
||||
|
||||
res = engine( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
|
||||
losses = engine.gather_attribute("loss")
|
||||
loss = torch.stack([*losses.values()]).sum().item()
|
||||
|
||||
stats['loss'].append(loss)
|
||||
process( name, batch, res, loss )
|
||||
|
||||
|
||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||
engines_stats.update(flatten_dict({ name: stats }))
|
||||
|
||||
iteration = engines.global_step
|
||||
engines_stats['it'] = iteration
|
||||
engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
|
||||
engines_stats = {
|
||||
f'{name}.{eval_name}': stats,
|
||||
"it": engines.global_step,
|
||||
}
|
||||
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
|
||||
|
||||
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
||||
|
||||
|
||||
def main():
|
||||
def train():
|
||||
parser = argparse.ArgumentParser("ResNet Image Classifier")
|
||||
parser.add_argument("--eval", action="store_true", default=None)
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
# create log folder
|
||||
setup_logging(cfg.log_dir)
|
||||
# copy config yaml to backup
|
||||
if cfg.yaml_path is not None and is_global_leader():
|
||||
shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
|
||||
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
|
||||
def eval_fn(engines):
|
||||
do_gc()
|
||||
engines.eval()
|
||||
# wrapped in a try block because it's sometimes prone to breaking
|
||||
try:
|
||||
run_eval(engines, "subtrain", subtrain_dl)
|
||||
run_eval(engines, "val", val_dl)
|
||||
except Exception as e:
|
||||
print("Error occurred while performing eval:", str(e))
|
||||
print(traceback.format_exc())
|
||||
_logger.warning(f"Error occurred while performing eval: {str(e)}")
|
||||
_logger.warning(traceback.format_exc())
|
||||
|
||||
engines.train()
|
||||
do_gc()
|
||||
|
||||
if args.eval:
|
||||
return eval_fn(engines=trainer.load_engines())
|
||||
|
||||
"""
|
||||
if cfg.trainer.load_webui:
|
||||
from .webui import start
|
||||
start(lock=False)
|
||||
"""
|
||||
|
||||
trainer.train(
|
||||
train_dl=train_dl,
|
||||
train_feeder=train_feeder,
|
||||
|
@ -104,4 +132,5 @@ def main():
|
|||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# to-do: for DDP, spawn multiprocess instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"`
|
||||
train()
|
||||
|
|
|
@@ -7,4 +7,7 @@ from .utils import (
	to_device,
	tree_map,
	do_gc,
	set_seed,
	passes_policy,
	get_devices
)
@ -8,6 +8,10 @@ import socket
|
|||
from functools import cache, wraps
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
def get_free_port():
|
||||
sock = socket.socket()
|
||||
sock.bind(("", 0))
|
||||
|
@ -15,13 +19,18 @@ def get_free_port():
|
|||
|
||||
|
||||
_distributed_initialized = False
|
||||
def init_distributed( fn ):
|
||||
fn()
|
||||
def init_distributed( fn, *args, **kwargs ):
|
||||
torch.cuda.set_device(local_rank())
|
||||
fn(*args, **kwargs)
|
||||
_distributed_initialized = True
|
||||
|
||||
def distributed_initialized():
|
||||
return _distributed_initialized
|
||||
|
||||
def cleanup_distributed():
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
@cache
|
||||
def fix_unset_envs():
|
||||
envs = dict(
|
||||
|
@ -44,10 +53,12 @@ def fix_unset_envs():
|
|||
def local_rank():
|
||||
return int(os.getenv("LOCAL_RANK", 0))
|
||||
|
||||
|
||||
def global_rank():
|
||||
return int(os.getenv("RANK", 0))
|
||||
|
||||
def world_size():
|
||||
return int(os.getenv("WORLD_SIZE", 1))
|
||||
|
||||
|
||||
def is_local_leader():
|
||||
return local_rank() == 0
|
||||
|
@ -87,3 +98,6 @@ def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable:
|
|||
return wrapper
|
||||
|
||||
return wrapper(fn)
|
||||
|
||||
def ddp_model(model):
|
||||
return DDP(model.to(device='cuda'), [local_rank()], find_unused_parameters=True)
|
image_classifier/utils/io.py (new file)
@@ -0,0 +1,88 @@
import torch
import json

from pathlib import Path
from safetensors import safe_open as sft_load
from safetensors.torch import save_file as sft_save

def coerce_path( path ):
	return path if isinstance( path, Path ) else Path(path)

def pick_path( path, *suffixes ):
	suffixes = [*suffixes]

	for suffix in suffixes:
		p = path.with_suffix( suffix )
		if p.exists():
			return p

	return path

def is_dict_of( d, t ):
	if not isinstance( d, dict ):
		return False

	return all([ isinstance(v, torch.Tensor) for v in d.values() ])

# handles converting the usual pth state_dict into just the dict with the tensors + a dict of JSON strings, for safetensors
def state_dict_to_tensor_metadata( data: dict, module_key=None ):
	metadata = None

	# is a state_dict, no need to coerce
	if is_dict_of( data, torch.Tensor ):
		return data, metadata

	# is maybe a dict with a state dict + metadata, coerce it
	metadata = {}
	target = module_key
	if not target:
		for k, v in data.items():
			# is a dict of tensors, our target
			if is_dict_of( v, torch.Tensor ):
				target = k
				continue # continue to iterate to grab other metadata

			# not a dict of tensors, put it as metadata
			try:
				metadata[k] = json.dumps(v)
			except Exception as e:
				pass

	if not target:
		raise Exception(f'Requesting to save safetensors of a state dict, but state dict contains no key of torch.Tensor: {path}')

	return data[target], metadata

def torch_save( data, path, module_key=None ):
	path = coerce_path(path)
	ext = path.suffix

	if ext in [".safetensor", ".sft"]:
		data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key )

		return sft_save( data, path, metadata )

	return torch.save( data, path )

def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=True, module_key="module" ):
	path = coerce_path(path)
	ext = path.suffix

	if ext in [".safetensor", ".sft"]:
		state_dict = {}
		with sft_load(path, framework=framework, device=device) as f:
			for k in f.keys():
				state_dict[k] = f.get_tensor(k)

			if load_metadata:
				metadata = f.metadata()
				for k, v in metadata.items():
					try:
						metadata[k] = json.loads( v )
					except Exception as e:
						pass
				state_dict = { module_key: state_dict } | metadata

		return state_dict

	return torch.load( path, map_location=torch.device(device), weights_only=not unsafe )
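As a rough example of the intended round trip: plain `.pth` paths fall straight through to `torch.save`/`torch.load`, while `.sft` paths split the checkpoint into a tensor dict plus JSON metadata. The file names and the `{"module": ..., "symmap": ...}` layout below are assumptions for illustration, not something this commit pins down.

```python
# Hypothetical save/load round trip through the helpers above.
import torch
from image_classifier.utils.io import torch_save, torch_load

weights = {"fc.weight": torch.zeros(4, 4), "fc.bias": torch.zeros(4)}

# plain torch checkpoint
torch_save(weights, "./data/ckpt.pth")
state = torch_load("./data/ckpt.pth")

# safetensors: the non-tensor entry is folded into the file's JSON metadata
torch_save({"module": weights, "symmap": {"0": "a"}}, "./data/ckpt.sft")
bundle = torch_load("./data/ckpt.sft")   # -> {"module": {...tensors...}, "symmap": {"0": "a"}}
```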
188
image_classifier/utils/sampler.py
Executable file → Normal file
188
image_classifier/utils/sampler.py
Executable file → Normal file
|
@ -1,48 +1,164 @@
|
|||
"""
|
||||
A sampler that balances data by key_fns.
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Zhe Niu
|
||||
|
||||
niuzhe.nz@outlook.com
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
class Sampler:
|
||||
def __init__(self, l, key_fns):
|
||||
self.tree = self._build(l, key_fns)
|
||||
from .distributed import global_rank, local_rank, world_size
|
||||
|
||||
def _build(self, l, key_fns) -> dict[dict, list]:
|
||||
if not key_fns:
|
||||
return l
|
||||
# Randomly picks an index from an array of indices
|
||||
class PoolSampler():
|
||||
def __init__( self, pool = [], keep_all = False, shuffle = False ):
|
||||
self.length = len(pool)
|
||||
self.shuffle = shuffle
|
||||
self.global_pool = pool if keep_all else None
|
||||
self.global_indices = [ i for i in range(self.length) ]
|
||||
self.reset()
|
||||
|
||||
tree = {}
|
||||
def reset(self):
|
||||
self.current_pool = [ i for i in self.global_indices ]
|
||||
if self.shuffle:
|
||||
random.shuffle(self.current_pool)
|
||||
|
||||
key_fn, *key_fns = key_fns
|
||||
def sample(self, pool = None):
|
||||
if pool is None:
|
||||
pool = self.global_pool
|
||||
# check if we need to reset
|
||||
index = random.choice( self.current_pool )
|
||||
# remove from pool
|
||||
self.current_pool.remove(index)
|
||||
# reset if needed
|
||||
if len(self.current_pool) == 0:
|
||||
self.reset()
|
||||
# map indices to our real values
|
||||
return pool[index] if pool is not None else index
|
||||
|
||||
for x in l:
|
||||
k = key_fn(x)
|
||||
def __len__(self):
|
||||
return self.length # len(self.current_pool)
|
||||
|
||||
if k in tree:
|
||||
tree[k].append(x)
|
||||
else:
|
||||
tree[k] = [x]
|
||||
def __iter__(self):
|
||||
while len(self.current_pool) > 0:
|
||||
yield self.sample()
|
||||
|
||||
for k in tree:
|
||||
tree[k] = self._build(tree[k], key_fns)
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.sample(*args, **kwargs)
|
||||
|
||||
return tree
|
||||
def get_state(self):
|
||||
return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
|
||||
|
||||
def _sample(self, tree: dict | list):
|
||||
if isinstance(tree, list):
|
||||
ret = random.choice(tree)
|
||||
else:
|
||||
key = random.choice([*tree.keys()])
|
||||
ret = self._sample(tree[key])
|
||||
return ret
|
||||
def set_state(self, state):
|
||||
self.length = state["length"]
|
||||
self.global_pool = state["global_pool"]
|
||||
self.global_indices = state["global_indices"]
|
||||
self.current_pool = state["current_pool"]
|
||||
|
||||
def sample(self):
|
||||
return self._sample(self.tree)
|
||||
# "Samples" through a fixed sequence from 0 to length
|
||||
# Necessary for our "shuffle+sort by duration+interleave" sampling method
|
||||
# Allows saving and loading state
|
||||
class OrderedSampler(Sampler):
|
||||
def __init__( self, length ):
|
||||
self.position = 0
|
||||
self.length = length
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __iter__(self):
|
||||
if self.position >= self.length:
|
||||
self.position = 0
|
||||
|
||||
while self.position < self.length:
|
||||
yield self.position
|
||||
self.position += 1
|
||||
|
||||
def get_state(self):
|
||||
return { "position": self.position, "length": self.length }
|
||||
|
||||
def set_state(self, state):
|
||||
self.position = state["position"]
|
||||
self.length = state["length"]
|
||||
|
||||
# Like the above, but will batch based on token count
|
||||
class BatchedOrderedSampler(Sampler):
|
||||
def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
|
||||
self.position = 0
|
||||
self.batches = []
|
||||
self.shuffle = shuffle
|
||||
|
||||
assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
|
||||
|
||||
current_batch = []
|
||||
current_size = 0
|
||||
current_index = 0
|
||||
for key, bucket in buckets.items():
|
||||
for path, duration in bucket:
|
||||
# flush
|
||||
should_flush = False
|
||||
if max_duration > 0 and current_size + duration > max_duration:
|
||||
should_flush = True
|
||||
elif max_batch_size > 0 and len(current_batch) >= max_batch_size:
|
||||
should_flush = True
|
||||
|
||||
if should_flush and len(current_batch) > 0:
|
||||
self.batches.append( current_batch )
|
||||
current_batch = []
|
||||
current_size = 0
|
||||
|
||||
current_batch.append( current_index )
|
||||
current_index += 1
|
||||
current_size += duration
|
||||
|
||||
if self.shuffle:
|
||||
random.shuffle(self.batches)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.batches)
|
||||
|
||||
def __iter__(self):
|
||||
if self.position >= len(self.batches):
|
||||
self.position = 0
|
||||
if self.shuffle:
|
||||
random.shuffle(self.batches)
|
||||
|
||||
while self.position < len(self.batches):
|
||||
yield self.batches[self.position]
|
||||
self.position += 1
|
||||
|
||||
def get_state(self):
|
||||
return { "position": self.position, "batches": self.batches }
|
||||
|
||||
def set_state(self, state):
|
||||
self.position = state["position"]
|
||||
self.batches = state["batches"]
|
||||
|
||||
# Randomly samples indices from a given sequence from 0 to length
|
||||
# Allows saving and loading state
|
||||
class RandomSampler(Sampler):
|
||||
def __init__( self, length ):
|
||||
self.position = 0
|
||||
self.length = length
|
||||
|
||||
self.generator = torch.Generator()
|
||||
self.perm = torch.randperm(self.length, generator=self.generator)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __iter__(self):
|
||||
if self.position >= self.length:
|
||||
self.position = 0
|
||||
self.perm = torch.randperm(self.length, generator=self.generator)
|
||||
|
||||
while self.position < self.length:
|
||||
yield self.perm[self.position]
|
||||
self.position += 1
|
||||
|
||||
def get_state(self):
|
||||
return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }
|
||||
|
||||
def set_state(self, state):
|
||||
self.position = state["position"]
|
||||
self.length = state["length"]
|
||||
self.perm = state["perm"]
|
||||
self.generator.set_state(state["generator"])
|
|
@ -4,12 +4,13 @@
|
|||
|
||||
import humanize
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import numpy as np
|
||||
import random
|
||||
import selectors
|
||||
import sys
|
||||
import torch
|
||||
import os
|
||||
|
||||
from functools import cache
|
||||
from torch.distributed import broadcast_object_list
|
||||
|
@ -18,9 +19,10 @@ from tqdm import tqdm
|
|||
from typing import Protocol
|
||||
|
||||
from ..config import cfg
|
||||
from .distributed import init_distributed, distributed_initialized
|
||||
from .distributed import (
|
||||
fix_unset_envs,
|
||||
init_distributed,
|
||||
distributed_initialized,
|
||||
world_size,
|
||||
global_leader_only,
|
||||
global_rank,
|
||||
is_global_leader,
|
||||
|
@ -28,73 +30,15 @@ from .distributed import (
|
|||
local_leader_only,
|
||||
)
|
||||
|
||||
from ..engines import Engine, Engines, TrainFeeder, default_feeder
|
||||
from ..models import get_models
|
||||
from ..engines import Engine, Engines, TrainFeeder, default_feeder, load_engines
|
||||
|
||||
from .utils import to_device, do_gc
|
||||
from .utils import to_device, do_gc, truncate_json
|
||||
from ..utils import wrapper as ml
|
||||
from ..data import get_symmap # should decouple from this trainer script
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_engines: Engines
|
||||
_command: str
|
||||
|
||||
def get_global_step():
|
||||
try:
|
||||
return _engines.global_step
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_micro_step():
|
||||
try:
|
||||
return _engines.micro_step
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_cmd():
|
||||
try:
|
||||
return _command
|
||||
except:
|
||||
raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
|
||||
|
||||
|
||||
get_iteration = get_global_step
|
||||
|
||||
def load_engines():
|
||||
models = get_models(cfg.models.get())
|
||||
engines = dict()
|
||||
|
||||
for name in models:
|
||||
model = models[name]
|
||||
|
||||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
if cfg.hyperparameters.optimizer.lower() == "adamw":
|
||||
optimizer = ml.AdamW(
|
||||
model.parameters(),
|
||||
lr=cfg.hyperparameters.learning_rate,
|
||||
betas=(0.9, 0.96),
|
||||
eps=1e-07,
|
||||
weight_decay=0.01,
|
||||
)
|
||||
|
||||
if cfg.trainer.load_state_dict:
|
||||
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||
model.load_state_dict(torch.load(load_path))
|
||||
|
||||
engines[name] = Engine(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
|
||||
engines = Engines(engines)
|
||||
engines.setup()
|
||||
|
||||
if not cfg.trainer.load_state_dict:
|
||||
engines.load_checkpoint()
|
||||
|
||||
return engines
|
||||
|
||||
class EvalFn(Protocol):
|
||||
def __call__(self, *, engines: Engines):
|
||||
|
@ -151,17 +95,16 @@ def _non_blocking_input():
|
|||
|
||||
l[0] = s
|
||||
|
||||
if distributed_initialized():
|
||||
if world_size() > 1:
|
||||
broadcast_object_list(l, src=0)
|
||||
_command = l[0]
|
||||
return _command
|
||||
|
||||
|
||||
|
||||
def _make_infinite_epochs(dl):
|
||||
while True:
|
||||
_logger.info("New epoch starts.")
|
||||
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
|
||||
#_logger.info("New epoch starts.")
|
||||
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
|
||||
|
||||
|
||||
@local_leader_only(default=None)
|
||||
|
@ -172,30 +115,32 @@ def logger(data):
|
|||
def seed(seed):
|
||||
# Set up random seeds, after fork()
|
||||
random.seed(seed + global_rank())
|
||||
#np.random.seed(seed + global_rank())
|
||||
np.random.seed(seed + global_rank())
|
||||
torch.manual_seed(seed + global_rank())
|
||||
|
||||
|
||||
def train(
|
||||
train_dl: DataLoader,
|
||||
train_feeder: TrainFeeder = default_feeder,
|
||||
eval_fn: EvalFn = lambda x: ...,
|
||||
logger: Logger = logger,
|
||||
):
|
||||
fix_unset_envs()
|
||||
|
||||
engines = load_engines()
|
||||
|
||||
# validate if there's at least one model to train
|
||||
found = False
|
||||
for name, engine in engines.items():
|
||||
if engine.training:
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
raise Exception('Training, but no model loaded set to train...')
|
||||
|
||||
"""
|
||||
if is_local_leader():
|
||||
cfg.dump()
|
||||
_logger.info(cfg)
|
||||
"""
|
||||
|
||||
# Setup global engines
|
||||
global _engines
|
||||
_engines = engines
|
||||
|
||||
events = []
|
||||
|
||||
eval_fn = global_leader_only(eval_fn)
|
||||
|
@ -203,15 +148,20 @@ def train(
|
|||
# Pre-loop command
|
||||
command = _non_blocking_input()
|
||||
if command in ["eval", "eval_quit"]:
|
||||
engines.eval()
|
||||
eval_fn(engines=engines)
|
||||
engines.train()
|
||||
|
||||
if command in ["quit", "eval_quit"]:
|
||||
engines.quit()
|
||||
return
|
||||
|
||||
last_save_step = engines.global_step
|
||||
last_eval_step = 0
|
||||
|
||||
"""
|
||||
if cfg.distributed:
|
||||
train_dl.sampler.set_epoch(int(engines.global_samples / len(train_dl.dataset.paths)))
|
||||
"""
|
||||
|
||||
# Training loop
|
||||
for batch in _make_infinite_epochs(train_dl):
|
||||
if engines.global_step >= cfg.trainer.iterations:
|
||||
|
@ -219,17 +169,15 @@ def train(
|
|||
|
||||
#batch = to_device(batch, torch.cuda.current_device())
|
||||
stats = engines.step(batch=batch, feeder=train_feeder)
|
||||
|
||||
iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
|
||||
stats['it'] = iteration
|
||||
stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
|
||||
|
||||
del stats['batch_size']
|
||||
del stats['wall_time']
|
||||
del stats['global_step']
|
||||
stats['epoch'] = engines.global_samples / (len(train_dl.dataset.paths) * world_size())
|
||||
|
||||
elapsed_time = stats.get("elapsed_time", 0)
|
||||
_logger.info(f"Training Metrics: {json.dumps(stats)}.")
|
||||
try:
|
||||
metrics = json.dumps(stats)
|
||||
except Exception as e:
|
||||
metrics = str(stats)
|
||||
|
||||
_logger.info(f"Training Metrics: {truncate_json(metrics)}.")
|
||||
|
||||
command = _non_blocking_input()
|
||||
|
||||
|
@ -267,29 +215,48 @@ def train(
|
|||
|
||||
if "lr" in command:
|
||||
rate = float(command.split(" ")[-1])
|
||||
try:
|
||||
engines.set_lr(rate)
|
||||
print("Updating LR to:", rate)
|
||||
_logger.info(f"Updating LR to: {rate}")
|
||||
except Exception as e:
|
||||
_logger.warning(f"Failed to set LR rate to: {rate}, {str(e)}")
|
||||
|
||||
if "export" in command:
|
||||
train_dl.dataset.save_state_dict()
|
||||
engines.save_checkpoint()
|
||||
last_save_step = engines.global_step
|
||||
|
||||
if is_global_leader():
|
||||
engines.export(userdata={"symmap": get_symmap()})
|
||||
|
||||
save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
|
||||
|
||||
saving_commands = ["save"]
|
||||
export_commands = ["export"]
|
||||
|
||||
if cfg.trainer.save_on_quit:
|
||||
saving_commands.append("quit")
|
||||
|
||||
if cfg.trainer.export_on_quit:
|
||||
export_commands.append("quit")
|
||||
|
||||
if cfg.trainer.export_on_save:
|
||||
export_commands.append("save")
|
||||
|
||||
if engines.global_step != last_save_step:
|
||||
if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
|
||||
train_dl.dataset.save_state_dict()
|
||||
engines.save_checkpoint()
|
||||
last_save_step = engines.global_step
|
||||
|
||||
if command in export_commands and is_global_leader():
|
||||
engines.export(userdata={"symmap": get_symmap()})
|
||||
|
||||
if engines.global_step != last_eval_step:
|
||||
if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
|
||||
do_gc()
|
||||
|
||||
engines.eval()
|
||||
eval_fn(engines=engines)
|
||||
engines.train()
|
||||
last_eval_step = engines.global_step
|
||||
eval_fn(engines=engines)
|
||||
|
||||
if command in ["quit"]:
|
||||
engines.quit()
|
||||
return
|
|
@ -7,8 +7,16 @@ from .distributed import global_rank, local_rank, global_leader_only
|
|||
import gc
|
||||
import logging
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import re
|
||||
import torch
|
||||
import random
|
||||
import time
|
||||
import psutil
|
||||
import math
|
||||
import logging
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
from coloredlogs import ColoredFormatter
|
||||
from logging import StreamHandler
|
||||
|
@ -16,9 +24,16 @@ from pathlib import Path
|
|||
from torch import Tensor, nn
|
||||
from tqdm.auto import tqdm
|
||||
from typing import Callable, TypeVar, overload
|
||||
|
||||
from contextlib import contextmanager
|
||||
T = TypeVar("T")
|
||||
|
||||
def truncate_json( str ):
|
||||
|
||||
def fun( match ):
|
||||
return "{:.4f}".format(float(match.group()))
|
||||
|
||||
return re.sub(r"\d+\.\d{8,}", fun, str)
|
||||
|
||||
def do_gc():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -28,6 +43,14 @@ def flatten_dict(d):
|
|||
return records[0] if records else {}
|
||||
|
||||
|
||||
def set_seed(seed=None):
|
||||
if not seed:
|
||||
seed = int(time.time())
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
def _get_named_modules(module, attrname):
|
||||
for name, module in module.named_modules():
|
||||
if hasattr(module, attrname):
|
||||
|
@ -155,5 +178,363 @@ def tree_map(fn: Callable, x):
|
|||
return x
|
||||
|
||||
|
||||
def to_device(x: T, device) -> T:
|
||||
return tree_map(lambda t: t.to(device), x)
|
||||
def to_device(x: T | None, *args, **kwargs) -> T:
|
||||
if x is None:
|
||||
return
|
||||
|
||||
return tree_map(lambda t: t.to(*args, **kwargs), x)
|
||||
|
||||
def coalese( *arg, return_last=True ):
|
||||
return [ x for x in arg if x is not None ][-1 if return_last else 0]
|
||||
|
||||
# checks if a module name is within a given whitelist/blacklist policy dict
|
||||
def passes_policy( policy, name ):
|
||||
if policy is None:
|
||||
return True
|
||||
|
||||
if "exclude" in policy:
|
||||
for term in policy["exclude"]:
|
||||
if term in name:
|
||||
return False
|
||||
|
||||
if "include" in policy:
|
||||
for term in policy["include"]:
|
||||
if term in name:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
|
||||
@contextmanager
|
||||
def autocast(input, from_dtype, to_dtype):
|
||||
if input.dtype == from_dtype:
|
||||
input = input.to(to_dtype)
|
||||
yield input
|
||||
input = input.to(from_dtype)
|
||||
else:
|
||||
yield input
|
||||
|
||||
@contextmanager
|
||||
def autocasts(input, from_dtype, to_dtype):
|
||||
if input.dtype in from_dtype:
|
||||
from_dtype = input.dtype
|
||||
input = input.to(to_dtype)
|
||||
yield input
|
||||
input = input.to(from_dtype)
|
||||
else:
|
||||
yield input
|
||||
|
||||
# handles temporarily upcasting 'index tensors' so torch will stop bitching
|
||||
def autocast_forward( func ):
|
||||
def wrapper( self, input, *args, **kwargs ):
|
||||
with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
|
||||
return func( self, k, *args, **kwargs )
|
||||
return wrapper
|
||||
|
||||
# handles migrating an input tensor to a given devicve
|
||||
def auto_align_inputs_forward( module, device=None, name = None ):
|
||||
func = module.forward
|
||||
|
||||
if device is None:
|
||||
if hasattr( module, 'device' ):
|
||||
device = module.device
|
||||
else:
|
||||
try:
|
||||
device = next(module.parameters() if [*module.parameters()] else module.buffers()).device
|
||||
except Exception as e:
|
||||
return func
|
||||
|
||||
|
||||
def wrapper( *args, **kwargs ):
|
||||
args = [*args]
|
||||
# search through args and kwargs for any Tensor arguments
|
||||
for i, arg in enumerate(args):
|
||||
if not isinstance( arg, torch.Tensor ):
|
||||
continue
|
||||
args[i] = arg.to( device=device )
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if not isinstance( v, torch.Tensor ):
|
||||
continue
|
||||
kwargs[k] = v.to( device=device )
|
||||
|
||||
# disgusting patch
|
||||
if "position_embeddings" in kwargs:
|
||||
kwargs["position_embeddings"] = tuple([ t.to(device=device) for t in kwargs["position_embeddings"] ])
|
||||
|
||||
return func( *args, **kwargs )
|
||||
return wrapper
|
||||
|
||||
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
|
||||
# generalizing this would be super sugoi but the there's no catch all for arguments
|
||||
def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ):
|
||||
bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
|
||||
|
||||
device = next(model.parameters()).device
|
||||
dtype = next(model.parameters()).dtype
|
||||
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
|
||||
|
||||
for *parent, k in modules:
|
||||
name = '.'.join(parent)
|
||||
|
||||
m = getattr( model.get_submodule(name), k )
|
||||
|
||||
if isinstance(m, klass):
|
||||
continue
|
||||
|
||||
kwargs = dict(
|
||||
in_features = m.in_features,
|
||||
out_features = m.out_features,
|
||||
bias = m.bias is not None,
|
||||
) if not bnb else dict(
|
||||
input_features=m.in_features,
|
||||
output_features=m.out_features,
|
||||
bias=m.bias is not None,
|
||||
)
|
||||
|
||||
# overwrite
|
||||
setattr(
|
||||
model.get_submodule(name), k,
|
||||
klass( **kwargs ).to(device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
if verbose:
|
||||
_logger.info(f"Replacing {name}.{k} to: {klass}")
|
||||
|
||||
return model
|
||||
|
||||
def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ):
|
||||
device = next(model.parameters()).device
|
||||
dtype = next(model.parameters()).dtype
|
||||
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
|
||||
|
||||
for *parent, k in modules:
|
||||
name = '.'.join(parent)
|
||||
|
||||
m = getattr( model.get_submodule(name), k )
|
||||
|
||||
if isinstance(m, klass):
|
||||
continue
|
||||
|
||||
kwargs = dict(
|
||||
num_embeddings=m.num_embeddings,
|
||||
embedding_dim=m.embedding_dim,
|
||||
padding_idx=m.padding_idx,
|
||||
max_norm=m.max_norm,
|
||||
norm_type=m.norm_type,
|
||||
scale_grad_by_freq=m.scale_grad_by_freq,
|
||||
sparse=m.sparse,
|
||||
)
|
||||
|
||||
# overwrite
|
||||
setattr(
|
||||
model.get_submodule(name), k,
|
||||
klass( **kwargs ).to(device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
if verbose:
|
||||
_logger.info(f"Replacing {name}.{k} to: {klass}")
|
||||
|
||||
return model
|
||||
|
||||
# cannot feasibly do default arguments here sad
|
||||
def replace_attention( model, klass, target, mode="math", verbose=False ):
|
||||
device = next(model.parameters()).device
|
||||
dtype = next(model.parameters()).dtype
|
||||
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
|
||||
|
||||
for *parent, k in modules:
|
||||
name = '.'.join(parent)
|
||||
|
||||
m = getattr( model.get_submodule(name), k )
|
||||
|
||||
if isinstance(m, klass):
|
||||
continue
|
||||
|
||||
kwargs = dict(
|
||||
config = m.config,
|
||||
layer_idx = m.layer_idx,
|
||||
mode = mode,
|
||||
)
|
||||
# overwrite
|
||||
setattr(
|
||||
model.get_submodule(name), k,
|
||||
klass( **kwargs ).to(device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
if verbose:
|
||||
_logger.info(f"Replacing {name}.{k} to: {klass}")
|
||||
|
||||
return model
|
||||
|
||||
# trim/expand a tensor (for example, in a state dict)
|
||||
def resize_weight( weight, target, dim=0, random=True ):
|
||||
# trim
|
||||
if target < weight.shape[dim]:
|
||||
return weight[:target]
|
||||
# expand
|
||||
if target > weight.shape[dim]:
|
||||
fn = torch.rand if random else torch.zeros
|
||||
return torch.stack(
|
||||
[ x for x in weight ] +
|
||||
[ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ]
|
||||
)
|
||||
|
||||
return weight
|
||||
|
||||
def get_devices():
|
||||
return [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu']
|
||||
|
||||
# grabs the memory properties of a given device
|
||||
def get_device_properties( device ):
|
||||
if 'cuda' in device:
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
free, total = torch.cuda.mem_get_info(device)
|
||||
else:
|
||||
props = psutil.virtual_memory()
|
||||
free, total = props.available, props.total
|
||||
|
||||
return {"name": device, "total": total, "free": free, "props": props}
|
||||
|
||||
# gets the rough size for a given module's parameters
|
||||
def get_module_size( module ):
|
||||
param_size = sum([p.nelement() * p.element_size() for p in module.parameters()])
|
||||
buffer_size = sum([b.nelement() * b.element_size() for b in module.buffers()])
|
||||
return param_size + buffer_size
|
||||
|
||||
# to-do: rewrite all this shit, I don't know what I was thinking when implementing it this way
|
||||
# it'd be better to just attach to layers itself rather than every single module
|
||||
|
||||
# assigns modules to requested devices for a given policy
|
||||
def get_model_offload_policy(module, policy=None):
|
||||
# handle any other weird values this is set to
|
||||
if not isinstance(policy, dict):
|
||||
policy = {}
|
||||
|
||||
# default to only include the core model, and not the other modules (embeddings) in the splitting policy
|
||||
if "include" not in policy:
|
||||
policy["include"] = ["model"]
|
||||
|
||||
if "limits" not in policy:
|
||||
policy["limits"] = []
|
||||
|
||||
if "assign" not in policy:
|
||||
policy["assign"] = []
|
||||
|
||||
if "devices" not in policy:
|
||||
policy["devices"] = get_devices() # + cpu to spill the remainder on CPU if overbudget
|
||||
|
||||
# create initial device info
|
||||
devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ]
|
||||
modules = [ (name, get_module_size(module)) for name, module in module.named_modules() if not [*module.named_children()] and passes_policy( policy, name ) ]
|
||||
# filter
|
||||
modules = [ (name, size) for name, size in modules if name and size ]
|
||||
|
||||
total_size = sum([size for name, size in modules])
|
||||
|
||||
# set caps if requested in the policy
|
||||
for i, cap in enumerate(policy["limits"]):
|
||||
# no limit, skip
|
||||
if cap <= 0:
|
||||
continue
|
||||
# is fractional, scale to total size
|
||||
if cap < 1:
|
||||
cap = math.floor(total_size * cap)
|
||||
# available space is below cap, don't set
|
||||
if devices[i]["free"] < cap:
|
||||
continue
|
||||
# cap to requested size
|
||||
devices[i]["free"] = cap
|
||||
|
||||
# assign if specific parts of the model are requested for assignment
|
||||
if policy["assign"]:
|
||||
discarded = []
|
||||
# yuck, there has to be a better way
|
||||
for device_index, includes in enumerate( policy["assign"] ):
|
||||
device = devices[device_index]
|
||||
|
||||
buffered_modules = []
|
||||
buffered_size = device["free"]
|
||||
|
||||
# iterate through list of modules to compare against includes
|
||||
for name, size in modules:
|
||||
# doesn't pass policy
|
||||
if not passes_policy( {"include": includes}, name ):
|
||||
continue
|
||||
# check if within budget
|
||||
if buffered_size - size >= 0:
|
||||
# add to buffer
|
||||
buffered_modules.append( (name, size) )
|
||||
buffered_size -= size
|
||||
# budget exceeded, flush buffer
|
||||
else:
|
||||
discarded += buffered_modules
|
||||
buffered_modules = []
|
||||
buffered_size = 0
|
||||
break
|
||||
|
||||
if buffered_modules and buffered_size:
|
||||
device["modules"] += [ name for name, size in buffered_modules ]
|
||||
device["free"] = buffered_size
|
||||
|
||||
modules = discarded
|
||||
|
||||
device_index = 0
|
||||
module_index = 0
|
||||
# assign modules to each device
|
||||
while module_index < len(modules) and device_index < len(devices):
|
||||
device = devices[device_index]
|
||||
name, size = modules[module_index]
|
||||
|
||||
# fits within budget
|
||||
if device["free"] - size >= 0:
|
||||
device["modules"].append( name )
|
||||
device["free"] -= size
|
||||
module_index += 1
|
||||
# does not fit in budget, increase device index
|
||||
else:
|
||||
device_index += 1
|
||||
_logger.info(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")
|
||||
|
||||
# to-do: check that all modules are exhausted
|
||||
assert module_index >= len(modules)
|
||||
|
||||
# only return devices with modules assigned
|
||||
return [ device for device in devices if device["modules"] ]
|
||||
|
||||
# handles naively splitting a model's layers across multiple devices
|
||||
# this apparently works for training too? the test trainer seemed fine with it split between GPU and CPU
|
||||
def offload_model( model, policy=None ):
|
||||
policy = get_model_offload_policy(model, policy=policy)
|
||||
|
||||
# move modules to respective devices
|
||||
for i, device in enumerate( policy ):
|
||||
# nothing assigned, skip
|
||||
if not device["modules"]:
|
||||
continue
|
||||
|
||||
for name in device["modules"]:
|
||||
module = model.get_submodule(name)
|
||||
module = module.to( device["name"] )
|
||||
module.device = device['name']
|
||||
|
||||
# wrap modules with forward to ensure all inputs are matched to its device
|
||||
for name, module in model.named_modules():
|
||||
if not hasattr( module, 'forward' ):
|
||||
continue
|
||||
|
||||
module.forward = auto_align_inputs_forward(module)
|
||||
|
||||
"""
|
||||
# Validate that the layers are all in the right spot
|
||||
for name, module in model.named_modules():
|
||||
if not not [*module.named_children()]:
|
||||
continue
|
||||
try:
|
||||
_logger.info( name, next(module.parameters()).device )
|
||||
except Exception as e:
|
||||
_logger.info( name, "?" )
|
||||
pass
|
||||
"""
|
||||
|
||||
return model
|
|
@@ -1,20 +1,39 @@
from contextlib import contextmanager

import math
import torch
import torch.nn.functional as F
import logging

from ..config import cfg

_logger = logging.getLogger(__name__)

Embedding = torch.nn.Embedding
Linear = torch.nn.Linear

if cfg.bitsandbytes.enabled:
Adam = torch.optim.Adam
AdamW = torch.optim.AdamW
SGD = torch.optim.SGD
Adagrad = torch.optim.Adagrad

# https://github.com/kyegomez/BitNet
if cfg.optimizations.bitnet:
	from bitnet import BitLinear

if cfg.optimizations.bitsandbytes:
	import bitsandbytes as bnb

	if cfg.bitsandbytes.linear:
	if cfg.optimizations.linear:

		if cfg.optimizations.bitnet:
			Linear = BitLinear
		else:
			Linear = bnb.nn.Linear8bitLt

	if cfg.bitsandbytes.embedding:
		Embedding = bnb.nn.StableEmbedding
	if cfg.optimizations.embedding:
		Embedding = bnb.nn.modules.Embedding
		"""
		Embedding.forward = lambda self, input: ( self.norm(F.embedding(
			input,
			self.weight,
@@ -24,52 +43,101 @@ if cfg.bitsandbytes.enabled:
			self.scale_grad_by_freq,
			self.sparse,
		)).to(self.weight.dtype) )
		"""

Adam = torch.optim.Adam
AdamW = torch.optim.AdamW
	if cfg.optimizations.optimizers:
		Adam = bnb.optim.Adam8bit
		AdamW = bnb.optim.AdamW8bit
		SGD = bnb.optim.SGD8bit
		Adagrad = bnb.optim.Adagrad8bit

if cfg.bitsandbytes.enabled:
	import bitsandbytes as bnb
elif cfg.optimizations.dadaptation:
	import dadaptation

	Adam = bnb.optim.Adam
	AdamW = bnb.optim.AdamW
	if cfg.optimizations.optimizers:
		Adam = dadaptation.DAdaptAdam
		AdamW = dadaptation.DAdaptAdam
		SGD = dadaptation.DAdaptSGD
		Adagrad = dadaptation.DAdaptAdaGrad

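# A minimal sketch (not part of this commit) of why the optimizers are aliased at module
# level: trainer code can construct whatever `AdamW` currently points to (plain torch,
# bitsandbytes 8-bit, or dadaptation) without knowing which backend is active. `params`
# and `lr` here are hypothetical.
def _build_optimizer_demo( params, lr=1.0e-4 ):
	return AdamW( params, lr=lr )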
# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
@contextmanager
def autocast(input, from_dtype, to_dtype):
	if input.dtype == from_dtype:
		input = input.to(to_dtype)
		yield input
		input = input.to(from_dtype)
	else:
		yield input
if cfg.optimizations.fp8:
	import transformer_engine.pytorch as te

	Linear = te.Linear

@contextmanager
def autocasts(input, from_dtype, to_dtype):
	if input.dtype in from_dtype:
		from_dtype = input.dtype
		input = input.to(to_dtype)
		yield input
		input = input.to(from_dtype)
	def autocast():
		yield te.fp8_autocast(enabled=True)
	else:
		yield input
	@contextmanager
	def autocast():
		yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)

# handles temporarily upcasting 'index tensors' so torch will stop bitching
def autocast_forward( func ):
	def wrapper( self, input, *args, **kwargs ):
		with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k:
			return func( self, k, *args, **kwargs )
		"""
		if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8:
			return func( self, input.to(torch.int32), *args, **kwargs )
		return func( self, input, *args, **kwargs )
		"""
	return wrapper
Embedding.forward = autocast_forward(Embedding.forward)

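# A minimal sketch (not part of this commit) of the problem the wrapper solves: torch
# embeddings only take int32/int64 indices, so int16/int8/uint8 index tensors are upcast
# before the wrapped forward() runs. The demo function and its names are illustrative.
def _index_upcast_demo():
	emb = torch.nn.Embedding( 10, 4 )
	indices = torch.tensor( [1, 2, 3], dtype=torch.int16 )
	wrapped = autocast_forward( torch.nn.Embedding.forward )
	return wrapped( emb, indices ).shape  # emb(indices) without the wrapper raises a dtype error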
if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
if cfg.optimizations.injects:
	if cfg.optimizations.linear:
		torch.nn.Linear = Linear

	if cfg.optimizations.embedding:
		torch.nn.Embedding = Embedding

	if cfg.optimizations.optimizers:
		torch.optim.Adam = Adam
		torch.optim.AdamW = AdamW
		torch.optim.SGD = SGD

AVAILABLE_COMPILE_BACKENDS = []

try:
	AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
except Exception as e:
	pass


if cfg.optimizations.tensorrt:
	try:
		import torch_tensorrt
		AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
	except Exception as e:
		_logger.warning(f'Error while importing TensorRT: {str(e)}')
		pass

def compile_model(model, backend="auto"):
	if not backend or backend == "auto":
		backend = AVAILABLE_COMPILE_BACKENDS[0]

	if backend not in AVAILABLE_COMPILE_BACKENDS:
		return torch.compile(model)

	return torch.compile(model, backend=backend)

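# A minimal usage sketch (not part of this commit); available backend names depend on what
# torch._dynamo reports on the host, with "inductor" being the usual default.
def _compile_demo():
	net = torch.nn.Linear( 8, 8 )
	return compile_model( net, backend="inductor" )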
# https://github.com/konstmish/prodigy
try:
	from prodigyopt import Prodigy
except Exception as e:
	_logger.warning(f'Error while importing Prodigyopt: {str(e)}')
	pass

# https://github.com/facebookresearch/schedule_free/
try:
	import schedulefree
except Exception as e:
	_logger.warning(f'Error while importing Schedule_Free: {str(e)}')
	pass

# backwards compat
from .utils import (
	autocast_forward,
	replace_linear as replace_linear_old,
	replace_embedding as replace_embedding_old,
	replace_attention,
	resize_weight,
	offload_model,
)

# wrapped here so we can maintain default args
def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
	return replace_linear_old( model, klass, target, verbose )
def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
	return replace_embedding_old( model, klass, target, verbose )

Embedding.forward = autocast_forward(Embedding.forward)
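# A minimal usage sketch (not part of this commit): swap every torch.nn.Linear in a toy
# model for whatever Linear currently resolves to (bitsandbytes 8-bit, BitLinear, te.Linear, ...).
def _replace_linear_demo():
	net = torch.nn.Sequential( torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2) )
	return replace_linear( net )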
220
image_classifier/webui.py
Normal file
@@ -0,0 +1,220 @@
import os
import re
import argparse
import random
import tempfile
import functools
from datetime import datetime

import gradio as gr

from time import perf_counter
from pathlib import Path
from PIL import Image

from .inference import Classifier, cfg
from .train import train
from .utils import get_devices

classifier = None

layout = {}
layout["inference"] = {}
layout["training"] = {}
layout["settings"] = {}

for k in layout.keys():
	layout[k]["inputs"] = { "progress": None }
	layout[k]["outputs"] = {}
	layout[k]["buttons"] = {}

# there's got to be a better way to go about this
def gradio_wrapper(inputs):
	def decorated(fun):
		@functools.wraps(fun)
		def wrapped_function(*args, **kwargs):
			for i, key in enumerate(inputs):
				kwargs[key] = args[i]
			try:
				return fun(**kwargs)
			except Exception as e:
				raise gr.Error(str(e))
		return wrapped_function
	return decorated
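# A minimal sketch (not part of this commit) of what the decorator does: Gradio passes
# component values positionally, and the wrapper re-keys them by the layout's declared
# input names so handlers can take keyword arguments. Names here are illustrative.
def _gradio_wrapper_demo():
	@gradio_wrapper(inputs=["image", "temp"])
	def fake_handler( image=None, temp=1.0 ):
		return f"{image} @ {temp}"

	# equivalent to fake_handler(image="cat.png", temp=0.9)
	return fake_handler( "cat.png", 0.9 )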

class timer:
	def __init__(self, msg="Elapsed time:"):
		self.msg = msg

	def __enter__(self):
		self.start = perf_counter()
		return self

	def __exit__(self, type, value, traceback):
		msg = f'{self.msg} {(perf_counter() - self.start):.3f}s'

		gr.Info(msg)
		print(f'[{datetime.now().isoformat()}] {msg}')

# returns a list of models, assuming the models are placed under ./training/ or ./models/
def get_model_paths( paths=[Path("./data/"), Path("./training/"), Path("./models/")] ):
	yamls = []

	for path in paths:
		if not path.exists():
			continue

		for yaml in path.glob("**/*.yaml"):
			if "/logs/" in str(yaml):
				continue

			yamls.append( yaml )

	return yamls

def get_dtypes():
	return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]


#@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
def load_model( yaml, device, dtype ):
	gr.Info(f"Loading: {yaml}")
	try:
		init_classifier( yaml=Path(yaml), restart=True, device=device, dtype=dtype )
	except Exception as e:
		raise gr.Error(e)
	gr.Info(f"Loaded model")

def init_classifier(yaml=None, restart=False, device="cuda", dtype="auto"):
	global classifier

	if classifier is not None:
		if not restart:
			return classifier

		del classifier
		classifier = None

	parser = argparse.ArgumentParser(allow_abbrev=False)
	parser.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
	parser.add_argument("--device", type=str, default=device)
	parser.add_argument("--amp", action="store_true")
	parser.add_argument("--dtype", type=str, default=dtype)
	args, unknown = parser.parse_known_args()

	classifier = Classifier( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
	return classifier

@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
	if not cfg.yaml_path:
		raise Exception("No YAML loaded.")

	parser = argparse.ArgumentParser(allow_abbrev=False)
	# I'm very sure I can procedurally generate this list
	parser.add_argument("--image", type=str, default=kwargs["image"])
	parser.add_argument("--temp", type=float, default=kwargs["temp"])
	args, unknown = parser.parse_known_args()

	classifier = init_classifier()

	args.image = Image.open(args.image).convert('RGB')

	gr.Info("Inferencing...")
	with timer("Inferenced in") as t:
		answer = classifier.inference(
			image=args.image,
			temperature=args.temp,
		)

	return answer

# setup args
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
|
||||
parser.add_argument("--share", action="store_true")
|
||||
parser.add_argument("--render_markdown", action="store_true", default="CLASSIFIER_YAML" in os.environ)
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
args.listen_host = None
|
||||
args.listen_port = None
|
||||
args.listen_path = None
|
||||
if args.listen:
|
||||
try:
|
||||
match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]
|
||||
|
||||
args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
|
||||
args.listen_port = match[1] if match[1] != "" else None
|
||||
args.listen_path = match[2] if match[2] != "" else "/"
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
if args.listen_port is not None:
|
||||
args.listen_port = int(args.listen_port)
|
||||
if args.listen_port == 0:
|
||||
args.listen_port = None
|
||||
|
||||
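# A worked example (not part of this commit) of what the --listen spec above accepts;
# the sample strings are illustrative:
#   "0.0.0.0:7860"   -> host "0.0.0.0", port 7860, path "/"
#   "127.0.0.1:0/ui" -> host "127.0.0.1", port None (0 means "use Gradio's default"), path "/ui"
#   "/ui"            -> host "127.0.0.1", port None, path "/ui"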
# setup gradio
ui = gr.Blocks()
with ui:
	with gr.Tab("Inference"):
		with gr.Row():
			with gr.Column(scale=4):
				layout["inference"]["inputs"]["image"] = gr.Image(label="Input Image", sources=["upload"], type="filepath")
				layout["inference"]["outputs"]["output"] = gr.Textbox(label="Output")
			with gr.Column(scale=4):
				with gr.Row():
					layout["inference"]["inputs"]["temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature", info="Modifies the randomness from the samples. (0 to greedy sample)")
				layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")

		layout["inference"]["buttons"]["inference"].click(
			fn=do_inference,
			inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
			outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
		)

	"""
	with gr.Tab("Training"):
		with gr.Row():
			with gr.Column(scale=1):
				layout["training"]["outputs"]["console"] = gr.Textbox(lines=8, label="Console Log")
		with gr.Row():
			with gr.Column(scale=1):
				layout["training"]["buttons"]["train"] = gr.Button(value="Train")

		layout["training"]["buttons"]["train"].click(
			fn=do_training,
			outputs=[ x for x in layout["training"]["outputs"].values() if x is not None],
		)
	"""

	with gr.Tab("Settings"):
		with gr.Row():
			with gr.Column(scale=7):
				with gr.Row():
					layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
					layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
					layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
			with gr.Column(scale=1):
				layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")

		layout["settings"]["buttons"]["load"].click(
			fn=load_model,
			inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
			outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
		)

	if os.path.exists("README.md") and args.render_markdown:
		md = open("README.md", "r", encoding="utf-8").read()
		# remove HF's metadata
		if md.startswith("---\n"):
			md = "".join(md.split("---")[2:])
		gr.Markdown(md)

def start( lock=True ):
	ui.queue(max_size=8)
	ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)

if __name__ == "__main__":
	start()
0
scripts/run.sh
Executable file → Normal file
36
setup.py
@@ -1,5 +1,5 @@
import subprocess

import sys
from pathlib import Path
from datetime import datetime
from setuptools import setup, find_packages
@@ -8,7 +8,6 @@ def shell(*args):
	out = subprocess.check_output(args)
	return out.decode("ascii").strip()


def write_version(version_core, pre_release=True):
	if pre_release:
		time = shell("git", "log", "-1", "--format=%cd", "--date=iso")
@@ -23,8 +22,7 @@ def write_version(version_core, pre_release=True):

	return version


with open("README.md", "r", encoding="utf-8") as f:
with open("README.md", "r") as f:
	long_description = f.read()

setup(
@@ -37,17 +35,37 @@ setup(
	long_description=long_description,
	long_description_content_type="text/markdown",
	packages=find_packages(),
	install_requires=[
	install_requires=(
		# training backends
		["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else [])
		+ [
		# logging niceties
		"coloredlogs>=15.0.1",
		"humanize>=4.4.0",
		"matplotlib>=3.6.0",
		"pandas>=1.5.0",

		# boiler plate niceties
		"diskcache>=5.4.0",
		"einops>=0.6.0",
		"omegaconf==2.0.6",
		"tqdm>=4.64.1",
		"humanize>=4.4.0",
		"tqdm",

		"pandas>=1.5.0",
		# HF bloat
		"tokenizers",
		"transformers",
		"safetensors",

		# training bloat
		"h5py",
		"prodigyopt @ git+https://github.com/konstmish/prodigy",

		# practically the reason to use python
		"numpy",
		"torch>=1.13.0",
		"torchmetrics",

		"simple_http_server",
		"pillow"
	],
	url="https://git.ecker.tech/mrq/resnet-classifier",
)