updated framework to use the saner framework that mrq/vall-e uses these days
This commit is contained in:
parent 5cb28a210e
commit 5610bb3bb3
0
.gitignore
vendored
Executable file → Normal file
40
README.md
Executable file → Normal file
|
@ -1,10 +1,10 @@
|
|||
# Tentative Title For A ResNet-Based Image Classifier
|
||||
|
||||
This is a simple ResNet based image classifier for """specific images""", using a similar training framework I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/).
|
||||
This is a simple ResNet-based image classifier for images, using a training framework similar to the one I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/).
|
||||
|
||||
## Premise
|
||||
|
||||
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
|
||||
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I faced.
|
||||
|
||||
This is by no means state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
|
||||
|
||||
|
@ -16,44 +16,14 @@ This is by no ways state of the art, as it just leverages an existing ResNet arc
|
|||
|
||||
3. Install using `pip3 install -e ./image_classifier/`.
|
||||
|
||||
4. Train using `python3 -m image_classifier.train yaml='./data/config.yaml'`.
|
||||
4. Train using `python3 -m image_classifier.train --yaml='./data/config.yaml'`.
|
||||
|
||||
5. Wait.
|
||||
|
||||
## Inferencing
|
||||
|
||||
Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
|
||||
Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" --yaml="./data/config.yaml"`
|
||||
|
||||
### Continuous Usage
|
||||
|
||||
If you're looking to continuously classifier trained images, use `python3 -m image_classifier --listen --port=7860 yaml="./data/config.yaml" --temp=1.0` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
|
||||
|
||||
## Known Issues
|
||||
|
||||
* Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
|
||||
* Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
|
||||
* Web server doesn't emit `content-type: application/json`, nor does it accept JSON `POST`s at the moment.
|
||||
|
||||
## Strawmen
|
||||
|
||||
>\> UGH... Why *another* training framework!!! Just subjugate [DLAS](https://git.ecker.tech/mrq/DL-Art-School) even more!!!
|
||||
|
||||
I want my own code to own. The original VALL-E implementation had a rather nice and clean setup that *mostly* just made sense. DLAS was a nightmare to comb through for the gorillion models it supports.
|
||||
|
||||
>\> OK. But how do I use it for `[thing that isn't the specific usecase only I know/care about]`
|
||||
|
||||
Simply provide your own symmap under `./image_classifier/data.py`, and be sure to set the delimiter (where exactly is an exercise left to the reader).
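As a rough sketch, assuming you keep the default ` `/`<s>`/`</s>` special tokens at ids 0-2 (the default map lives in `NaiveTokenizer.get_vocab()` in `config.py`, and `data.py` pulls it through `cfg.tokenizer.get_vocab()`), a custom vocabulary for purely numeric labels could look like:

```python
# hypothetical symmap: digits only, keeping the default padding/start/stop tokens at ids 0-2
custom_symmap = { " ": 0, "<s>": 1, "</s>": 2 }
custom_symmap.update({ str(d): d + 3 for d in range(10) })
```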
|
||||
|
||||
Because this is for a ***very specific*** use-case. I don't really care right now to make this a *little* more generalized, despite most of the bits and bobs for it to generalize being there.
|
||||
|
||||
>\> ur `[a slur]` for using a ResNet... why not use `[CRNN / some other meme arch]`??
|
||||
|
||||
I don't care; I'd rather keep the copypasting from other people's code to a minimum. Lazily adapting my phoneme tokenizer from my VALL-E implementation into something practically fixed-length by introducing start/stop tokens should be grounds for me to use a CRNN, or at least something recurrent, but again, I don't care; it just works for my use case at the moment.
|
||||
|
||||
>\> UGH!!! What are you talking about """specific images"""???
|
||||
|
||||
[ひみつ](https://files.catbox.moe/csuh49.webm)
|
||||
|
||||
>\> NOOOO!!!! WHY AREN'T YOU USING `[cuck license]`???
|
||||
|
||||
:)
|
||||
If you're looking to continuously classify images, use `python3 -m image_classifier --listen --port=7860 --yaml="./data/config.yaml"` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
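For example, a minimal client sketch (the port matches the command above, and the `{"answer": ...}` response shape matches the route handler in `__main__.py`; `requests` is not a dependency of this repo and is only used here for illustration):

```python
import base64
import requests  # assumption: any HTTP client works, requests is just convenient

with open("./data/path-to-your-image.png", "rb") as f:
    b64 = base64.b64encode(f.read()).decode("utf-8")

# the server reads the image from the `b64` query parameter and returns {"answer": "<label>"}
response = requests.get("http://127.0.0.1:7860/", params={"b64": b64})
print(response.json()["answer"])
```

Per the known issues above, the response won't carry `content-type: application/json`, but the body itself is still JSON, so `response.json()` parses it fine.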
|
131
data/config.yaml
Executable file → Normal file
|
@ -1,85 +1,84 @@
|
|||
dataset:
|
||||
training: [
|
||||
"./data/images/"
|
||||
]
|
||||
|
||||
validation: []
|
||||
|
||||
use_hdf5: False
|
||||
|
||||
workers: 0
|
||||
cache: True
|
||||
weights_format: sft
|
||||
|
||||
models:
|
||||
_models:
|
||||
- name: "classifier"
|
||||
tokens: 0
|
||||
len: 6
|
||||
- name: "classifier"
|
||||
tokens: 0
|
||||
len: 6
|
||||
dim: 512
|
||||
resnet: 34
|
||||
#loras:
|
||||
#- name : "lora"
|
||||
# rank: 128
|
||||
# alpha: 128
|
||||
# training: True
|
||||
# rvq_levels: []
|
||||
|
||||
hyperparameters:
|
||||
batch_size: 256
|
||||
gradient_accumulation_steps: 64
|
||||
gradient_clipping: 100
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_clipping: 1.0
|
||||
warmup_steps: 10
|
||||
|
||||
optimizer: Prodigy
|
||||
learning_rate: 1.0
|
||||
torch_optimizer: True
|
||||
|
||||
optimizer: Adamw
|
||||
learning_rate: 1.0e-3
|
||||
|
||||
scheduler_type: ""
|
||||
#scheduler_type: OneCycle
|
||||
#scheduler_params:
|
||||
# cycle_first_step_size: 10_000
|
||||
# cycle_first_stair_count: 10_000
|
||||
|
||||
# cycle_second_step_size: 15_000
|
||||
# cycle_second_stair_count: 15_000
|
||||
|
||||
# decay_step_size: 5_000
|
||||
|
||||
# cycle_min_lr: 2.5e-4 # 1.0e-5
|
||||
# cycle_max_lr: 2.5e-4 # 1.0e-4
|
||||
# decay_lr_rate: 0.0
|
||||
|
||||
# cycle_min_mom: 0.90
|
||||
# cycle_max_mom: 0.99
|
||||
# decay_mom_rate: 0.0
|
||||
scheduler: "" # ScheduleFree
|
||||
torch_scheduler: True
|
||||
|
||||
evaluation:
|
||||
batch_size: 32
|
||||
frequency: 250
|
||||
size: 32
|
||||
batch_size: 64
|
||||
frequency: 100
|
||||
size: 64
|
||||
|
||||
steps: 300
|
||||
temperature: 1.0
|
||||
steps: 450
|
||||
temperature: 0.0
|
||||
|
||||
trainer:
|
||||
iterations: 100_000
|
||||
|
||||
save_tag: step
|
||||
save_on_oom: True
|
||||
save_on_quit: True
|
||||
iterations: 1_000_000
|
||||
save_frequency: 100
|
||||
|
||||
aggressive_optimizations: False
|
||||
|
||||
check_for_oom: False
|
||||
|
||||
#load_tag: "9500"
|
||||
#load_state_dict: True
|
||||
#load_states: False
|
||||
#strict_loading: False
|
||||
#restart_step_count: True
|
||||
keep_last_checkpoints: 32
|
||||
|
||||
gc_mode: None # "global_step"
|
||||
check_for_oom: False
|
||||
gradient_checkpointing: True
|
||||
|
||||
weight_dtype: float32
|
||||
weight_dtype: bfloat16
|
||||
amp: True
|
||||
|
||||
backend: local
|
||||
backend: deepspeed
|
||||
deepspeed:
|
||||
zero_optimization_level: 0
|
||||
use_compression_training: True
|
||||
inferencing: False
|
||||
amp: False
|
||||
|
||||
inference:
|
||||
use_vocos: True
|
||||
backend: local
|
||||
|
||||
bitsandbytes:
|
||||
enabled: false
|
||||
weight_dtype: bfloat16
|
||||
amp: True
|
||||
|
||||
optimizations:
|
||||
injects: False
|
||||
replace: True
|
||||
|
||||
linear: False
|
||||
embedding: False
|
||||
optimizers: True
|
||||
|
||||
bitsandbytes: False
|
||||
dadaptation: False
|
||||
bitnet: False
|
||||
fp8: False
|
||||
|
||||
dataset:
|
||||
use_hdf5: True
|
||||
hdf5_flag: r
|
||||
|
||||
workers: 1
|
||||
cache: True
|
||||
|
||||
training: [
|
||||
"./data/images/"
|
||||
]
|
||||
validation: [
|
||||
"./data/validation/"
|
||||
]
image_classifier/__main__.py
@ -12,35 +12,55 @@ def main():
|
|||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
parser.add_argument("--listen", action='store_true')
|
||||
parser.add_argument("--port", type=int, default=9090)
|
||||
|
||||
parser.add_argument("--yaml", type=Path, default=None)
|
||||
parser.add_argument("--ckpt", type=Path, default=None)
|
||||
parser.add_argument("--temp", type=float, default=1.0)
|
||||
parser.add_argument("--device", default="cuda")
|
||||
parser.add_argument("--device", type=str, default=None)
|
||||
parser.add_argument("--amp", action="store_true")
|
||||
parser.add_argument("--dtype", type=str, default=None)
|
||||
|
||||
parser.add_argument("--temp", type=float, default=0.0)
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
classifier = Classifier( config=args.yaml, ckpt=args.ckpt, device=args.device )
|
||||
classifier = Classifier( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
|
||||
if args.listen:
|
||||
@route("/")
|
||||
def inference( b64, temperature=1.0 ):
|
||||
def inference( b64, temperature=args.temp ):
|
||||
image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
|
||||
return { "answer": classifier.inference( image=image, temperature=args.temp ) }
|
||||
return { "answer": classifier.inference( image=image, temperature=temperature ) }
|
||||
server.start(port=args.port)
|
||||
else:
|
||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
parser.add_argument("--path", type=Path)
|
||||
parser.add_argument("--base64", type=str)
|
||||
parser.add_argument("--write", type=Path)
|
||||
parser.add_argument("--temp", type=float, default=1.0)
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
|
||||
images = []
|
||||
if args.path:
|
||||
image = Image.open(args.path).convert('RGB')
|
||||
if args.path.is_dir():
|
||||
for p in args.path.rglob("./*.jpg"):
|
||||
image = Image.open(p).convert('RGB')
|
||||
images.append(image)
|
||||
for p in args.path.rglob("./*.png"):
|
||||
image = Image.open(p).convert('RGB')
|
||||
images.append(image)
|
||||
else:
|
||||
image = Image.open(args.path).convert('RGB')
|
||||
images.append(image)
|
||||
elif args.base64:
|
||||
image = Image.open(BytesIO(base64.b64decode(args.base64))).convert("RGB")
|
||||
images.append(image)
|
||||
else:
|
||||
raise ValueError("Specify a --path or --base64.") # raising a bare string is a TypeError in Python 3
|
||||
|
||||
answer = classifier.inference( image=image, temperature=args.temp )
|
||||
print("Answer:", answer)
|
||||
for image in images:
|
||||
answer = classifier.inference( image=image, temperature=args.temp )
|
||||
print("Answer:", answer)
|
||||
if args.write:
|
||||
args.write.mkdir(exist_ok=True)
|
||||
image.save( args.write / f"{answer}.jpg")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
image_classifier/config.py
@ -6,31 +6,61 @@ import os
|
|||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from functools import cached_property, cache
|
||||
from pathlib import Path
|
||||
from omegaconf import OmegaConf
|
||||
import argparse
|
||||
import yaml
|
||||
import random
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
|
||||
from .utils.distributed import world_size
|
||||
|
||||
|
||||
def set_seed(seed=None):
|
||||
if not seed:
|
||||
seed = time.time()
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
@dataclass()
|
||||
class _Config:
|
||||
cfg_path: str | None = None
|
||||
class BaseConfig:
|
||||
yaml_path: str | None = None # path passed in through --yaml
|
||||
|
||||
@property
|
||||
def relpath(self):
|
||||
def cfg_path(self):
|
||||
return Path(self.yaml_path.parent) if self.yaml_path is not None else None
|
||||
|
||||
@property
|
||||
def rel_path(self):
|
||||
return Path(self.cfg_path)
|
||||
|
||||
@property
|
||||
def cache_dir(self):
|
||||
return self.rel_path / ".cache"
|
||||
|
||||
@property
|
||||
def data_dir(self):
|
||||
return self.rel_path / "data"
|
||||
|
||||
@property
|
||||
def metadata_dir(self):
|
||||
return self.rel_path / "metadata"
|
||||
|
||||
@property
|
||||
def ckpt_dir(self):
|
||||
return self.relpath / "ckpt"
|
||||
return self.rel_path / "ckpt"
|
||||
|
||||
@property
|
||||
def log_dir(self):
|
||||
return self.relpath / "logs" / str(self.start_time)
|
||||
return self.rel_path / "logs" / str(self.start_time)
|
||||
|
||||
@cached_property
|
||||
def start_time(self):
|
||||
|
@ -64,39 +94,28 @@ class _Config:
|
|||
with open(path, "w") as f:
|
||||
f.write(self.dumps())
|
||||
|
||||
@staticmethod
|
||||
def _is_cfg_argv(s):
|
||||
return "=" in s and "--" not in s
|
||||
|
||||
@classmethod
|
||||
def from_yaml( cls, yaml_path ):
|
||||
return cls.from_cli( [f'yaml="{yaml_path}"'] )
|
||||
state = {}
|
||||
state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
|
||||
state.setdefault("yaml_path", yaml_path)
|
||||
return cls(**state)
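# usage sketch (assumption: mirrors how from_cli() below hands off to this method):
#   cfg = Config.from_yaml( Path("./data/config.yaml") )
#   cfg.format() # coerces the raw YAML dicts into the dataclasses, as done at the bottom of this file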
|
||||
|
||||
@classmethod
|
||||
def from_cli(cls, args=sys.argv):
|
||||
cli_cfg = OmegaConf.from_cli([s for s in args if cls._is_cfg_argv(s)])
|
||||
# legacy support for yaml=`` format
|
||||
for i, arg in enumerate(args):
|
||||
if arg.startswith("yaml"):
|
||||
args[i] = f'--{arg}'
|
||||
|
||||
# Replace argv to ensure there are no omegaconf options, for compatibility with argparse.
|
||||
sys.argv = [s for s in sys.argv if not cls._is_cfg_argv(s)]
|
||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
|
||||
args, unknown = parser.parse_known_args(args=args)
|
||||
|
||||
if cli_cfg.get("help"):
|
||||
print(f"Configurable hyperparameters with their default values:")
|
||||
print(json.dumps(asdict(cls()), indent=2, default=str))
|
||||
exit()
|
||||
if args.yaml:
|
||||
return cls.from_yaml( args.yaml )
|
||||
|
||||
if "yaml" in cli_cfg:
|
||||
yaml_cfg = OmegaConf.load(cli_cfg.yaml)
|
||||
yaml_path = Path(cli_cfg.yaml).absolute()
|
||||
cfg_path = Path(*yaml_path.relative_to(Path.cwd()).parts[:-1])
|
||||
cfg_path = cfg_path.with_suffix("")
|
||||
cfg_path = f'./{cfg_path}'
|
||||
|
||||
yaml_cfg.setdefault("cfg_path", cfg_path)
|
||||
cli_cfg.pop("yaml")
|
||||
else:
|
||||
yaml_cfg = {}
|
||||
merged = OmegaConf.merge(yaml_cfg, cli_cfg)
|
||||
return cls(**dict(merged))
|
||||
return cls(**{})
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
@ -106,104 +125,195 @@ class _Config:
|
|||
|
||||
@dataclass()
|
||||
class Dataset:
|
||||
training: list[Path] = field(default_factory=lambda: [])
|
||||
validation: list[Path] = field(default_factory=lambda: [])
|
||||
|
||||
temp: list[Path] = field(default_factory=lambda: [])
|
||||
|
||||
# de-implemented, because the data isn't that large to facilitate HDF5
|
||||
hdf5_name: str = "data.h5"
|
||||
use_hdf5: bool = False
|
||||
training: list[Path] = field(default_factory=lambda: []) # paths to load into the training dataset
|
||||
validation: list[Path] = field(default_factory=lambda: []) # paths to load into the validation dataset
|
||||
|
||||
workers: int = 8
|
||||
cache: bool = True
|
||||
hdf5_name: str = "data.h5" # file name to load the HDF5 dataset
|
||||
use_hdf5: bool = False # whether to load from an HDF5 dataset
|
||||
hdf5_flag: str = "a" # flag to load the HDF5 file, automatically adjusted anyways
|
||||
|
||||
validate: bool = True # validate each utterance on whether it can be included based on duration range caps
|
||||
workers: int = 8 # number of dataloader workers to spawn
|
||||
cache: bool = True # use diskcache to cache the dataset
|
||||
|
||||
# I really need to clean this up
|
||||
@dataclass()
|
||||
class Model:
|
||||
name: str = ""
|
||||
name: str = "classifier"
|
||||
|
||||
tokens: int = 0 # number of token types
|
||||
len: int = 1 # how long a sequence can be
|
||||
dim: int = 512
|
||||
resnet: int = 18
|
||||
|
||||
width: int = 300
|
||||
height: int = 80
|
||||
|
||||
version: int = 1
|
||||
training: bool = True
|
||||
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
|
||||
|
||||
@property
|
||||
def full_name(self):
|
||||
return self.name
|
||||
|
||||
@dataclass()
|
||||
class Models:
|
||||
_models: list[Model] = field(default_factory=lambda: [
|
||||
Model(name="captcha"),
|
||||
])
|
||||
|
||||
def get(self, name=None):
|
||||
if not name:
|
||||
return [ Model(**model) for model in self._models ]
|
||||
return [ self ] if not name or self.name == name else []
|
||||
|
||||
def loss_factor(self, k):
|
||||
return self.loss_factors[k] if k in self.loss_factors else 1.0
|
||||
|
||||
for model in self._models:
|
||||
if model.name == name:
|
||||
return model
|
||||
@property
|
||||
# required for fp8 as the lengths needs to be divisible by 8
|
||||
def input_alignment(self):
|
||||
return 8 if cfg.optimizations.fp8 else 0
|
||||
|
||||
raise ValueError
|
||||
@property
|
||||
def activation_checkpointing(self):
|
||||
return cfg.trainer.activation_checkpointing
|
||||
|
||||
@property
|
||||
def gradient_checkpointing(self):
|
||||
return cfg.trainer.gradient_checkpointing
|
||||
|
||||
@property
|
||||
def lora_policy(self):
|
||||
include = ["model"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
|
||||
exclude = []
|
||||
|
||||
if self.arch_type == "llama":
|
||||
include = ["self_attn", "mlp"] # target only the attention + mlp
|
||||
exclude = ["self_attn.k_proj"] # common literature says to ignore it
|
||||
if self.arch_type == "retnet":
|
||||
include = ["layers."] # target the core layers of the RetNet and ignore the auxiliary stuff
|
||||
exclude = ["retention.k_proj"] # attention-based transformers ignore the K, so might as well ignore it for the retnet
|
||||
|
||||
return dict(include=include, exclude=exclude)
|
||||
|
||||
# should be renamed to Adapters
|
||||
@dataclass()
|
||||
class LoRA:
|
||||
name: str = "lora" # vanity name
|
||||
# to-do: find sane default values
|
||||
rank: int = 128 # rank for the LoRA
|
||||
alpha: int = 128 # alpha (scaling factor) for the LoRA
|
||||
training: bool = True #
|
||||
embeddings: bool = False # train the embedding too
|
||||
parametrize: bool = False # whether to use the parameterized pathway for LoRAs or not
|
||||
rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
|
||||
|
||||
@property
|
||||
def full_name(self):
|
||||
name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
|
||||
return "-".join(name)
|
||||
|
||||
# actually not needed anymore
|
||||
def active_level( self, level ):
|
||||
if not self.rvq_levels:
|
||||
return True
|
||||
return level in self.rvq_levels
|
||||
|
||||
@dataclass()
|
||||
class Hyperparameters:
|
||||
batch_size: int = 8
|
||||
gradient_accumulation_steps: int = 32
|
||||
gradient_clipping: int = 100 # to be implemented in the local backend
|
||||
batch_size: int = 8 # number of samples per training batch
|
||||
gradient_accumulation_steps: int = 32 # number of steps to accumulate gradients before updating
|
||||
gradient_clipping: int | float = 10 # largest size a gradient norm can be
|
||||
|
||||
optimizer: str = "Adamw"
|
||||
learning_rate: float = 3.25e-4
|
||||
optimizer: str = "Adamw" # optimizer to use, should be 'Prodigyopt' now
|
||||
optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||
|
||||
learning_rate: float = 3.25e-4 # should be 1.0 for ProdigyOpt
|
||||
warmup_steps: int = 0 # number of steps to warm up the optimizer before performing updates (I think this is just passed to deepspeed)
|
||||
|
||||
scheduler_type: str = "" # to be implemented in the local backend
|
||||
scheduler_params: dict = field(default_factory=lambda: {})
|
||||
scheduler: str = "" # scheduler to use, currently don't ever use one so this doesn't really matter
|
||||
scheduler_type: str = "" # deprecated
|
||||
scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||
|
||||
autotune: bool = False # to do deepspeed's autotuning
|
||||
autotune_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||
|
||||
torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed supplied
|
||||
torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied
|
||||
|
||||
@dataclass()
|
||||
class Evaluation:
|
||||
batch_size: int = 64
|
||||
frequency: int = 250
|
||||
size: int = 64
|
||||
|
||||
batch_size: int = 64 # number of samples per batch during eval / val
|
||||
frequency: int = 250 # do eval / val every X iterations
|
||||
size: int = 64 # number of samples to generate during eval / val
|
||||
|
||||
steps: int = 500
|
||||
temperature: float = 1.0
|
||||
temperature: float = 1.0 # AR temp for inferencing
|
||||
|
||||
load_disabled_engines: bool = True # see the other load_disabled_engines
|
||||
|
||||
@dataclass()
|
||||
class DeepSpeed:
|
||||
zero_optimization_level: int = 0
|
||||
use_compression_training: bool = False
|
||||
zero_optimization_level: int = 0 # doesn't seem to work
|
||||
use_compression_training: bool = False # cope
|
||||
compression_bits: int = 8 # cope
|
||||
inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
|
||||
|
||||
amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)
|
||||
|
||||
def get_ds_cfg(self, model):
|
||||
weights = [ name[0] for name in model.named_parameters() ]
|
||||
bits = 8
|
||||
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||
|
||||
scheduler_params = {}
|
||||
for k in cfg.hyperparameters.scheduler_params:
|
||||
scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
|
||||
@cached_property
|
||||
def ds_cfg(self):
|
||||
optimizer_params = cfg.hyperparameters.optimizer_params
|
||||
|
||||
if 'lr' not in optimizer_params:
|
||||
optimizer_params["lr"] = cfg.hyperparameters.learning_rate # no trailing comma, or the learning rate becomes a tuple
|
||||
|
||||
if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
|
||||
scheduler_params = cfg.hyperparameters.scheduler_params
|
||||
if 'warmup_num_steps' not in scheduler_params:
|
||||
scheduler_params['warmup_num_steps'] = cfg.hyperparameters.warmup_steps
|
||||
|
||||
if 'total_num_steps' not in scheduler_params:
|
||||
scheduler_params['total_num_steps'] = cfg.trainer.iterations
|
||||
|
||||
autotune_params = cfg.hyperparameters.autotune_params
|
||||
|
||||
if "enabled" not in autotune_params:
|
||||
autotune_params['enabled'] = True
|
||||
|
||||
if "results_dir" not in autotune_params:
|
||||
autotune_params['results_dir'] = str( cfg.rel_path / "autotune" / "results" )
|
||||
|
||||
if "exps_dir" not in autotune_params:
|
||||
autotune_params['exps_dir'] = str( cfg.rel_path / "autotune" / "exps_" )
|
||||
|
||||
# DeepSpeed fp16 is incompatible with its AMP
|
||||
if cfg.trainer.weight_dtype.lower() == "float16":
|
||||
self.amp = False
|
||||
|
||||
# disable local AMP
|
||||
if self.amp:
|
||||
cfg.trainer.amp = False
|
||||
|
||||
ds_cfg = {
|
||||
"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
|
||||
"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
|
||||
"optimizer": {
|
||||
"type": cfg.hyperparameters.optimizer,
|
||||
"params": {
|
||||
"lr": cfg.hyperparameters.learning_rate,
|
||||
}
|
||||
},
|
||||
"params": optimizer_params,
|
||||
} if not cfg.hyperparameters.torch_optimizer else None,
|
||||
"scheduler": {
|
||||
"type": cfg.hyperparameters.scheduler_type,
|
||||
"type": cfg.hyperparameters.scheduler,
|
||||
"params": scheduler_params,
|
||||
} if cfg.hyperparameters.scheduler_type != "" else None,
|
||||
} if not cfg.hyperparameters.torch_scheduler else None,
|
||||
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"auto_cast": True,
|
||||
} if cfg.trainer.weight_dtype.lower() == "float16" else None,
|
||||
"bf16": {
|
||||
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16"
|
||||
"enabled": cfg.trainer.weight_dtype.lower() == "float16",
|
||||
"auto_cast": True, # ???
|
||||
"loss_scale": 0.0 if cfg.trainer.scale_loss else 1.0,
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
|
||||
},
|
||||
"amp": {
|
||||
"enabled": self.amp,
|
||||
},
|
||||
"autotuning": autotune_params if cfg.hyperparameters.autotune else None,
|
||||
"compression_training": {
|
||||
"weight_quantization": {
|
||||
"shared_parameters":{
|
||||
|
@ -214,7 +324,7 @@ class DeepSpeed:
|
|||
"quantize_verbose": True,
|
||||
"quantization_type": "symmetric",
|
||||
"rounding": "nearest",
|
||||
"quantize_weight_in_forward": True,
|
||||
"quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
|
||||
"fp16_mixed_quantize":{
|
||||
"enabled": False,
|
||||
"quantize_change_ratio": 1
|
||||
|
@ -223,30 +333,38 @@ class DeepSpeed:
|
|||
"different_groups": {
|
||||
"wq1": {
|
||||
"params": {
|
||||
"start_bits": bits,
|
||||
"target_bits": bits,
|
||||
"start_bits": self.compression_bits,
|
||||
"target_bits": self.compression_bits,
|
||||
"quantization_period": 0
|
||||
},
|
||||
"modules": weights
|
||||
"modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
|
||||
}
|
||||
}
|
||||
},
|
||||
"activation_quantization": {
|
||||
"shared_parameters":{
|
||||
"enabled": True,
|
||||
"quantizer_kernel": True,
|
||||
"schedule_offset": 0,
|
||||
"quantize_groups": 64,
|
||||
"quantize_verbose": True,
|
||||
"quantization_type": "symmetric",
|
||||
"range_calibration": "dynamic",
|
||||
"schedule_offset": 0
|
||||
"rounding": "nearest",
|
||||
"quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
|
||||
"fp16_mixed_quantize":{
|
||||
"enabled": False,
|
||||
"quantize_change_ratio": 1
|
||||
}
|
||||
},
|
||||
"different_groups": {
|
||||
"aq1": {
|
||||
"params": {
|
||||
"bits": bits
|
||||
"bits": self.compression_bits,
|
||||
},
|
||||
"modules": weights
|
||||
"modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
} if self.use_compression_training else None,
|
||||
"zero_optimization": {
|
||||
"stage": self.zero_optimization_level,
|
||||
|
@ -264,7 +382,10 @@ class DeepSpeed:
|
|||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": True
|
||||
}
|
||||
},
|
||||
"zero_quantized_weights": self.use_compression_training,
|
||||
"zero_hpz_partition_size": world_size(),
|
||||
"zero_quantized_gradients": self.use_compression_training,
|
||||
} if self.zero_optimization_level > 0 else None,
|
||||
"comms_logger": {
|
||||
"enabled": False
|
||||
|
@ -275,113 +396,314 @@ class DeepSpeed:
|
|||
for k in null_keys:
|
||||
del ds_cfg[k]
|
||||
|
||||
if os.path.exists("./config/ds_config.json"):
|
||||
ds_cfg.update(json.load(open("./config/ds_config.json", "r", encoding="utf-8")))
|
||||
if os.path.exists("./data/ds_config.json"):
|
||||
ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
|
||||
else:
|
||||
ds_cfg.update(self.config)
|
||||
|
||||
return ds_cfg
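# consumption sketch (assumption: the engine wrapper does roughly this when trainer.backend == "deepspeed"):
#   import deepspeed
#   engine, optimizer, _, _ = deepspeed.initialize( model=model, model_parameters=model.parameters(), config=cfg.trainer.deepspeed.ds_cfg )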
|
||||
|
||||
@dataclass()
|
||||
class Trainer:
|
||||
iterations: int = 100_000
|
||||
iterations: int = 1_000_000 # maximum iterations to train
|
||||
|
||||
save_tag: str = "step"
|
||||
load_tag: str | None = None
|
||||
save_tag: str = "step" # name to save checkpoints under, "step" will save as current step count
|
||||
load_tag: str | None = None # tag to load checkpoint from; if None: will check against contents of `./ckpt/{model-name}/latest` for the checkpoint name
|
||||
|
||||
save_on_oom: bool = True
|
||||
save_on_quit: bool = True
|
||||
save_frequency: int = 100
|
||||
save_on_oom: bool = True # save if an OOM error is raised
|
||||
save_on_quit: bool = True # save when quitting training
|
||||
|
||||
export_on_save: bool = False # export weights to local `fp32.pth` state_dict on saving a checkpoint
|
||||
export_on_quit: bool = False # export weights to local `fp32.pth` state_dict on quitting training
|
||||
|
||||
save_frequency: int = 100 # frequency to save every X iterations
|
||||
|
||||
load_state_dict: bool = False
|
||||
load_states: bool = True
|
||||
strict_loading: bool = True
|
||||
restart_step_count: bool = False
|
||||
keep_last_checkpoints: int = 0 # number of checkpoints to keep, prunes oldest ones
|
||||
|
||||
aggressive_optimizations: bool = False
|
||||
check_for_oom: bool = True
|
||||
load_state_dict: bool = False # loads `fp32.pth` state_dict, will automatically be done if a checkpoint is not found but `fp32.pth` exists
|
||||
load_states: bool = True #
|
||||
strict_loading: bool = False # sets strict_loading=True when loading the state dict
|
||||
load_module_only: bool = False #
|
||||
restart_step_count: bool = False # clears the training stats when loading a checkpoint
|
||||
resize_modules: bool = False # automatically resizes
|
||||
|
||||
gc_mode: str | None = None
|
||||
activation_checkpointing: bool | None = None # deprecated; should technically apply only to activations and not the entire gradients, but HF only has gradient checkpointing
|
||||
gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
|
||||
|
||||
weight_dtype: str = "float16"
|
||||
aggressive_optimizations: bool = False # deprecated
|
||||
check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
|
||||
gc_mode: str | None = None # deprecated, but marks when to do GC
|
||||
load_disabled_engines: bool = False # deprecated, but signals to load engines not used for training (for example, for evaluation/validation)
|
||||
|
||||
backend: str = "deepspeed"
|
||||
weight_dtype: str = "float16" # dtype to have the model under
|
||||
|
||||
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
|
||||
amp: bool = False # automatic mixed precision
|
||||
ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
|
||||
#scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload)
|
||||
|
||||
load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
|
||||
no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when logger broke because of BitNet
|
||||
|
||||
backend: str = "local" # training backend to use. currently supports "local" | "deepspeed"
|
||||
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings
|
||||
|
||||
@cached_property
|
||||
def dtype(self):
|
||||
if self.weight_dtype == "float16":
|
||||
return torch.float16
|
||||
if cfg.trainer.weight_dtype == "bfloat16":
|
||||
if self.weight_dtype == "bfloat16":
|
||||
return torch.bfloat16
|
||||
if self.weight_dtype == "float8_e5m2":
|
||||
return torch.float8_e5m2
|
||||
if self.weight_dtype == "float8_e4m3fn":
|
||||
return torch.float8_e4m3fn
|
||||
return torch.float32
|
||||
|
||||
@cached_property
|
||||
def scale_loss(self):
|
||||
# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
|
||||
return self.dtype == torch.float16
|
||||
|
||||
@dataclass()
|
||||
class Inference:
|
||||
use_vocos: bool = True # artifact from the VALL-E trainer
|
||||
backend: str = "local" # backend to use when inferencing
|
||||
weight_dtype: str = "float32" # dtype to load the model under
|
||||
amp: bool = False # automatic mixed precision during inferencing
|
||||
|
||||
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
if self.weight_dtype == "float16":
|
||||
return torch.float16
|
||||
if self.weight_dtype == "bfloat16":
|
||||
return torch.bfloat16
|
||||
if self.weight_dtype == "int8":
|
||||
return torch.int8
|
||||
if self.weight_dtype == "float8_e5m2":
|
||||
return torch.float8_e5m2
|
||||
if self.weight_dtype == "float8_e4m3fn":
|
||||
return torch.float8_e4m3fn
|
||||
return torch.float32
|
||||
|
||||
@dataclass()
|
||||
class BitsAndBytes:
|
||||
enabled: bool = False
|
||||
injects: bool = False
|
||||
class Optimizations:
|
||||
injects: bool = False # overwrites default torch classes (not recommended)
|
||||
replace: bool = False # replaces modules in place with the optimized version (recommended)
|
||||
compile: bool | str = False # runs torch.compile on the model
|
||||
|
||||
linear: bool = False
|
||||
embedding: bool = False
|
||||
linear: bool = True # inject/replace linear for BnB
|
||||
embedding: bool = True # inject/replace embedding for BnB
|
||||
optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
|
||||
|
||||
bitsandbytes: bool = False # use bitsandbytes
|
||||
dadaptation: bool = False # use dadaptation optimizer
|
||||
bitnet: bool = False # use bitnet
|
||||
fp8: bool = False # use fp8
|
||||
|
||||
model_offloading: dict | None = None # automatically splits the model over a list of devices
|
||||
# example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} will have the GPU capped to 6GiB, and offload the remaining layers to CPU
|
||||
# example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
|
||||
# | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2
|
||||
|
||||
tensorrt: bool = False
|
||||
|
||||
@dataclass()
|
||||
class Config(_Config):
|
||||
device: str = "cuda"
|
||||
class Config(BaseConfig):
|
||||
device: str = "cuda" # target device
|
||||
mode: str = "training" # "inferencing"
|
||||
experimental: bool = False # Debug flag, unused now
|
||||
|
||||
dataset: Dataset = field(default_factory=lambda: Dataset)
|
||||
models: Models = field(default_factory=lambda: Models)
|
||||
models: dict | list | None = field(default_factory=lambda: [])
|
||||
loras: dict | list | None = field(default_factory=lambda: [])
|
||||
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
|
||||
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
|
||||
trainer: Trainer = field(default_factory=lambda: Trainer)
|
||||
inference: Inference = field(default_factory=lambda: Inference)
|
||||
bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
|
||||
bitsandbytes: dict | list | None = None # deprecated
|
||||
optimizations: Optimizations = field(default_factory=lambda: Optimizations)
|
||||
|
||||
def get_device(self):
|
||||
return torch.cuda.current_device() if self.device == "cuda" else self.device
|
||||
tokenizer: str | None = None # tokenizer class
|
||||
tokenizer_path: str = "./tokenizer.json" # tokenizer path
|
||||
|
||||
weights_format: str = "pth" # "pth" | "sft"
|
||||
supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
|
||||
|
||||
@property
|
||||
def cache_dir(self):
|
||||
return ".cache" / self.relpath
|
||||
def model(self):
|
||||
for i, model in enumerate(self.models):
|
||||
if model.training:
|
||||
return model
|
||||
|
||||
return self.models[0] if len(self.models) > 0 else None
|
||||
|
||||
# should be renamed to adapters
|
||||
@property
|
||||
def lora(self):
|
||||
for i, lora in enumerate(self.loras):
|
||||
if lora.training:
|
||||
return lora
|
||||
|
||||
return self.loras[0] if len(self.loras) > 0 else None
|
||||
|
||||
@property
|
||||
def distributed(self):
|
||||
return world_size() > 1
|
||||
|
||||
@cached_property
|
||||
def diskcache(self):
|
||||
if self.dataset.cache:
|
||||
if self.yaml_path is not None and self.dataset.cache:
|
||||
return diskcache.Cache(self.cache_dir).memoize
|
||||
return lambda: lambda x: x
|
||||
|
||||
# I don't remember why this is needed
|
||||
def load_yaml( self, config_path ):
|
||||
tmp = Config.from_yaml( config_path )
|
||||
self.__dict__.update(tmp.__dict__)
|
||||
|
||||
def load_hdf5( self, write=False ):
|
||||
if hasattr(self, 'hdf5'):
|
||||
self.hdf5.close()
|
||||
|
||||
if self.distributed:
|
||||
self.dataset.hdf5_flag = "r"
|
||||
try:
|
||||
self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
|
||||
except Exception as e:
|
||||
_logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}")
|
||||
self.dataset.use_hdf5 = False
|
||||
|
||||
# to-do: prune unused keys
|
||||
def format( self, training=True ):
|
||||
if isinstance(self.dataset, type):
|
||||
self.dataset = dict()
|
||||
|
||||
if isinstance(self.models, type):
|
||||
self.models = dict()
|
||||
|
||||
if isinstance(self.loras, type):
|
||||
self.loras = dict()
|
||||
|
||||
if isinstance(self.hyperparameters, type):
|
||||
self.hyperparameters = dict()
|
||||
|
||||
if isinstance(self.evaluation, type):
|
||||
self.evaluation = dict()
|
||||
|
||||
if isinstance(self.trainer, type):
|
||||
self.trainer = dict()
|
||||
|
||||
if isinstance(self.inference, type):
|
||||
self.inference = dict()
|
||||
|
||||
if isinstance(self.optimizations, type):
|
||||
self.optimizations = dict()
|
||||
|
||||
self.dataset = Dataset(**self.dataset)
|
||||
self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
|
||||
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
|
||||
|
||||
self.models = [ Model(**model) for model in self.models ]
|
||||
self.loras = [ LoRA(**lora) for lora in self.loras ]
|
||||
|
||||
if not self.models:
|
||||
self.models = [ Model() ]
|
||||
|
||||
self.hyperparameters = Hyperparameters(**self.hyperparameters)
|
||||
|
||||
self.evaluation = Evaluation(**self.evaluation)
|
||||
|
||||
self.trainer = Trainer(**self.trainer)
|
||||
|
||||
if not isinstance(self.trainer.deepspeed, type):
|
||||
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
|
||||
|
||||
self.inference = Inference(**self.inference)
|
||||
|
||||
if self.bitsandbytes is not None:
|
||||
self.optimizations = Optimizations(**self.bitsandbytes)
|
||||
else:
|
||||
self.optimizations = Optimizations(**self.optimizations)
|
||||
|
||||
if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
|
||||
self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
|
||||
self.hyperparameters.scheduler_type = ""
|
||||
|
||||
# do not combine the two
|
||||
if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation:
|
||||
self.hyperparameters.scheduler = ""
|
||||
|
||||
if self.hyperparameters.scheduler == "":
|
||||
self.hyperparameters.torch_scheduler = True
|
||||
|
||||
if self.trainer.backend == "local" and self.distributed:
|
||||
self.trainer.ddp = True
|
||||
|
||||
if self.trainer.activation_checkpointing is not None:
|
||||
self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing
|
||||
|
||||
if not training:
|
||||
self.dataset.use_hdf5 = False
|
||||
|
||||
# load our HDF5 file if requested here
|
||||
if self.dataset.use_hdf5:
|
||||
self.load_hdf5()
|
||||
|
||||
# load tokenizer
|
||||
if cfg.tokenizer == "naive":
|
||||
cfg.tokenizer = NaiveTokenizer()
|
||||
else:
|
||||
try:
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
tokenizer_path = cfg.rel_path / cfg.tokenizer_path if cfg.yaml_path is not None else None
|
||||
if tokenizer_path and not tokenizer_path.exists():
|
||||
tokenizer_path = Path("./data/") / cfg.tokenizer_path
|
||||
|
||||
if tokenizer_path and tokenizer_path.exists():
|
||||
cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
|
||||
else:
|
||||
cfg.tokenizer = NaiveTokenizer()
|
||||
except Exception as e:
|
||||
cfg.tokenizer = NaiveTokenizer()
|
||||
_logger.warning(f"Error while parsing tokenizer: {str(e)}")
|
||||
pass
|
||||
|
||||
|
||||
# Preserves the old behavior
|
||||
class NaiveTokenizer:
|
||||
def get_vocab( self ):
|
||||
"""
|
||||
if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
|
||||
return json.loads( cfg.hdf5['symmap'].asstr()[()] )
|
||||
"""
|
||||
return { " ": 0, "<s>": 1, "</s>": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 }
|
||||
|
||||
@cached_property
|
||||
def _bos_token( self ):
|
||||
return self.get_vocab()["<s>"]
|
||||
|
||||
@cached_property
|
||||
def _eos_token( self ):
|
||||
return self.get_vocab()["</s>"]
|
||||
|
||||
def encode( self, s ):
|
||||
symmap = self.get_vocab()
|
||||
s = s.replace("O", "0")
|
||||
s = [f"<s>"] + [ p if p in symmap else " " for p in s ] + [f"</s>"]
|
||||
return [*map(symmap.get, s)]
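# behavior sketch (ids follow the vocab above; labels come from the image file stems, upper-cased in data.py):
#   NaiveTokenizer().encode("W8X2") -> [1, 20, 6, 21, 4, 2] # <s> W 8 X 2 </s>
#   "O" gets folded into "0", and anything not in the vocab falls back to " " (id 0)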
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
cfg = Config.from_cli()
|
||||
|
||||
# OmegaConf doesn't actually coerce the dicts into the @dataclass decorated classes, for some god forsaken reason, so we coerce them ourselves
|
||||
cfg.dataset = Dataset(**cfg.dataset)
|
||||
cfg.models = Models(**cfg.models)
|
||||
cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
|
||||
cfg.evaluation = Evaluation(**cfg.evaluation)
|
||||
cfg.trainer = Trainer(**cfg.trainer)
|
||||
cfg.inference = Inference(**cfg.inference)
|
||||
cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
|
||||
|
||||
cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
|
||||
|
||||
# cached_property stopped working...
|
||||
if cfg.dataset.use_hdf5:
|
||||
try:
|
||||
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
|
||||
except Exception as e:
|
||||
print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
|
||||
cfg.dataset.use_hdf5 = False
|
||||
|
||||
if not cfg.dataset.use_hdf5:
|
||||
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
|
||||
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
|
||||
# some safety for remapping deprecated formats and re-coercing uninitialized properties into actual types
|
||||
try:
|
||||
cfg.format()
|
||||
except Exception as e:
|
||||
_logger.error(f"Error while parsing config YAML: {str(e)}")
|
||||
raise e # throw an error because I'm tired of silent errors messing things up for me
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(cfg)
|
||||
|
||||
|
image_classifier/data.py
@ -1,16 +1,19 @@
|
|||
# todo: clean this mess up
|
||||
|
||||
import copy
|
||||
# import h5py
|
||||
import h5py
|
||||
import json
|
||||
import logging
|
||||
#import numpy as np
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import math
|
||||
import itertools
|
||||
|
||||
from .config import cfg
|
||||
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
|
||||
from .utils.distributed import global_rank, local_rank, world_size
|
||||
from .utils.io import torch_save, torch_load
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import cache, cached_property
|
||||
|
@ -20,23 +23,57 @@ from typing import Any
|
|||
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader, Dataset as _Dataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from PIL import Image
|
||||
|
||||
# torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
@cache
|
||||
# to-do: clean up this symmap mess
|
||||
def get_symmap():
|
||||
return { " ": 0, "<s>": 1, "</s>": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 }
|
||||
return cfg.tokenizer.get_vocab()
|
||||
|
||||
@cache
|
||||
def _get_symbols( content ):
|
||||
content = content.replace("O", "0")
|
||||
return [f"<s>"] + [ p for p in content ] + [f"</s>"]
|
||||
def tokenize( s ):
|
||||
if isinstance( s, list ):
|
||||
s = "".join( s )
|
||||
return cfg.tokenizer.encode( s )
|
||||
|
||||
"""
|
||||
def _replace_file_extension(path, suffix):
|
||||
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
||||
|
||||
def _get_hdf5_path(path):
|
||||
# to-do: better validation
|
||||
return str(path)
|
||||
|
||||
def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
||||
data_dir = str(data_dir)
|
||||
|
||||
key = f"/{type}/{_get_hdf5_path(data_dir)}"
|
||||
|
||||
return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items()] if key in cfg.hdf5 else []
|
||||
|
||||
def _get_paths_of_extensions( path, validate=False ):
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
|
||||
return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
|
||||
|
||||
def _interleaved_reorder(l, fn):
|
||||
groups = defaultdict(list)
|
||||
for e in l:
|
||||
groups[fn(e)].append(e)
|
||||
groups = {k: groups[k] for k in sorted(groups)}
|
||||
for interleaved in zip_longest(*groups.values()):
|
||||
for value in interleaved:
|
||||
if value is not None:
|
||||
yield value
|
||||
"""
|
||||
|
||||
class Dataset(_Dataset):
|
||||
def __init__(
|
||||
|
@ -44,43 +81,90 @@ class Dataset(_Dataset):
|
|||
paths,
|
||||
width=300,
|
||||
height=80,
|
||||
stacks=0,
|
||||
|
||||
symmap=get_symmap(),
|
||||
training=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._head = None
|
||||
|
||||
self.paths = paths
|
||||
self.sampler = None
|
||||
self.width = width
|
||||
self.height = height
|
||||
|
||||
self.stacks = stacks
|
||||
|
||||
self.paths = paths
|
||||
self.image_dtype = cfg.trainer.dtype
|
||||
self.symmap = symmap
|
||||
|
||||
self.training = training
|
||||
self.dataset_type = "training" if self.training else "validation"
|
||||
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
#transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
|
||||
transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
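# net effect of the transform: a loaded RGB image becomes a float tensor of shape [3, height, width]
# (80x300 by default), normalized with the ImageNet mean/std that torchvision's pretrained ResNets expect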
|
||||
|
||||
@cached_property
|
||||
def symbols(self):
|
||||
return sorted(set().union(*[_get_symbols(path.stem) for path in self.paths]))
|
||||
# to-do: do not do validation if there's nothing in the validation
|
||||
# this just makes it be happy
|
||||
if len(self.dataset) == 0:
|
||||
self.dataset = cfg.dataset.training
|
||||
|
||||
# split dataset accordingly per GPU
|
||||
if cfg.distributed and self.training:
|
||||
self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == global_rank() ] # each rank keeps its own interleaved shard
|
||||
|
||||
if len(self.paths) == 0:
|
||||
raise ValueError(f"No valid path is found for {self.dataset_type}")
|
||||
|
||||
@cached_property
|
||||
def sampler_state_dict_path(self):
|
||||
return cfg.rel_path / f"sampler.rank{global_rank()}.pt"
|
||||
|
||||
def save_state_dict(self, path = None):
|
||||
"""
|
||||
if path is None:
|
||||
path = self.sampler_state_dict_path
|
||||
if self.sampler is not None:
|
||||
state_dict = self.sampler.get_state()
|
||||
elif self.samplers is not None:
|
||||
state_dict = {
|
||||
"samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
|
||||
}
|
||||
torch_save(state_dict, path)
|
||||
"""
|
||||
return
|
||||
|
||||
def load_state_dict(self, path = None):
|
||||
"""
|
||||
if path is None:
|
||||
path = self.sampler_state_dict_path
|
||||
|
||||
if not path.exists():
|
||||
return
|
||||
|
||||
state_dict = torch_load(path)
|
||||
if self.sampler is not None:
|
||||
state_dict = self.sampler.set_state(state_dict)
|
||||
else:
|
||||
for name, sampler in state_dict["samplers"].items():
|
||||
if name not in self.samplers:
|
||||
continue
|
||||
self.samplers[name].set_state( sampler )
|
||||
"""
|
||||
return
|
||||
|
||||
def __getitem__(self, index):
|
||||
path = self.paths[index]
|
||||
tokens = tokenize( path.stem.upper() )
|
||||
text = torch.tensor( tokens ).to(dtype=torch.uint8)
|
||||
|
||||
# stupid try/except from when the original VALL-E training framework was able to insert foreign symbols into the symmap; that functionality isn't really necessary here
|
||||
try:
|
||||
text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
|
||||
except Exception as e:
|
||||
print("Invalid symbol:", _get_symbols(path.stem), [*map(self.symmap.get, _get_symbols(path.stem))], path.stem)
|
||||
raise e
|
||||
image = Image.open(path).convert('RGB')
|
||||
width, height = image.size
|
||||
|
||||
image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype) # resnet has to be RGB
|
||||
image = self.transform(image).to(dtype=self.image_dtype) # resnet has to be RGB
|
||||
|
||||
return dict(
|
||||
index=index,
|
||||
|
@ -98,11 +182,6 @@ class Dataset(_Dataset):
|
|||
def __len__(self):
|
||||
return min(len(self.paths), self._head or len(self.paths))
|
||||
|
||||
def pin_memory(self):
|
||||
self.text = self.text.pin_memory()
|
||||
self.image = self.image.pin_memory()
|
||||
return self
|
||||
|
||||
|
||||
def collate_fn(samples: list[dict]):
|
||||
batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
|
||||
|
@ -111,21 +190,28 @@ def collate_fn(samples: list[dict]):
|
|||
|
||||
def _seed_worker(worker_id):
|
||||
worker_seed = torch.initial_seed() % 2**32
|
||||
#np.random.seed(worker_seed)
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
|
||||
def _create_dataloader(dataset, training):
|
||||
kwargs = dict(
|
||||
shuffle=True,
|
||||
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
|
||||
drop_last=training,
|
||||
sampler=dataset.sampler if training else None,
|
||||
) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
|
||||
batch_sampler=dataset.sampler,
|
||||
)
|
||||
|
||||
return DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
|
||||
shuffle=True, # training
|
||||
drop_last=training,
|
||||
num_workers=cfg.dataset.workers,
|
||||
collate_fn=collate_fn,
|
||||
persistent_workers=cfg.dataset.workers > 0,
|
||||
pin_memory=False, # True,
|
||||
persistent_workers=cfg.dataset.workers > 1,
|
||||
pin_memory=False,
|
||||
worker_init_fn=_seed_worker,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _load_train_val_paths( val_ratio=0.1 ):
|
||||
|
@ -133,8 +219,8 @@ def _load_train_val_paths( val_ratio=0.1 ):
|
|||
train_paths = []
|
||||
val_paths = []
|
||||
|
||||
print(cfg.dataset.training)
|
||||
for data_dir in cfg.dataset.training:
|
||||
paths.extend(data_dir.rglob("*.jpg"))
|
||||
paths.extend(data_dir.rglob("*.png"))
|
||||
|
||||
if len(paths) > 0:
|
||||
|
@ -146,12 +232,13 @@ def _load_train_val_paths( val_ratio=0.1 ):
|
|||
val_len = math.floor(len(train_paths) * val_ratio)
|
||||
train_len = math.floor(len(train_paths) * (1 - val_ratio))
|
||||
|
||||
print(val_len, train_len)
|
||||
|
||||
val_paths = train_paths[-val_len:] # hold out the last val_len paths for validation
|
||||
train_paths = train_paths[:train_len]
|
||||
else:
|
||||
paths = []
|
||||
|
||||
for data_dir in cfg.dataset.validation:
|
||||
paths.extend(data_dir.rglob("*.jpg"))
|
||||
paths.extend(data_dir.rglob("*.png"))
|
||||
|
||||
if len(paths) > 0:
|
||||
|
@ -169,7 +256,6 @@ def _load_train_val_paths( val_ratio=0.1 ):
|
|||
|
||||
return train_paths, val_paths
|
||||
|
||||
@cfg.diskcache()
|
||||
def create_datasets():
|
||||
train_paths, val_paths = _load_train_val_paths()
|
||||
|
||||
|
@ -187,10 +273,10 @@ def create_datasets():
|
|||
|
||||
return train_dataset, val_dataset
|
||||
|
||||
|
||||
def create_train_val_dataloader():
|
||||
train_dataset, val_dataset = create_datasets()
|
||||
|
||||
# deepcopy is slow
|
||||
subtrain_dataset = copy.deepcopy(train_dataset)
|
||||
subtrain_dataset.head_(cfg.evaluation.size)
|
||||
subtrain_dataset.training_(False)
|
||||
|
@ -200,8 +286,6 @@ def create_train_val_dataloader():
|
|||
subtrain_dl = _create_dataloader(subtrain_dataset, training=False)
|
||||
|
||||
_logger.info(str(train_dataset.symmap))
|
||||
|
||||
|
||||
|