diff --git a/.gitignore b/.gitignore
index b6dd66a..42edb8b 100755
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,6 @@ __pycache__
/.cache
/config
/*.egg-info
-/vall_e/version.py
+/image_classifier/version.py
/build
/.cache
\ No newline at end of file
diff --git a/README.md b/README.md
index 8e9647c..347e629 100755
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ This is a simple ResNet based image classifier for """specific images""", using
## Premise
-This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
+This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
This is by no means state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
@@ -22,13 +22,17 @@ This is by no ways state of the art, as it just leverages an existing ResNet arc
## Inferencing
-Simply invoke the inferencer with the following command: `python3 -m image_classifier "./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
+Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
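+
+The `Classifier` class can also be driven directly from Python. A rough sketch, mirroring what `image_classifier/__main__.py` does under the hood (paths here are placeholders):
+
+```python
+from pathlib import Path
+from PIL import Image
+
+from image_classifier.inference import Classifier
+
+# load the model / config the same way the CLI does
+classifier = Classifier( config=Path("./data/config.yaml") )
+
+# ResNet expects RGB input, so convert just in case
+image = Image.open("./data/path-to-your-image.png").convert("RGB")
+print( classifier.inference( image=image, temperature=1.0 ) )
+```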
+
+### Continuous Usage
+
+If you're looking to continuously classify images against a trained model, use `python3 -m image_classifier --listen --port=7860 yaml="./data/config.yaml" --temp=1.0` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
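+
+For example, a barebones client is just a `GET` with the base64-encoded image. A rough sketch using `requests` (assuming the server was started with the command above, and reading the label from the `answer` field the route returns):
+
+```python
+import base64
+import requests
+
+# base64-encode the image file and pass it as the b64 query parameter
+with open("./data/path-to-your-image.png", "rb") as f:
+    b64 = base64.b64encode( f.read() ).decode("utf-8")
+
+response = requests.get( "http://127.0.0.1:7860/", params={ "b64": b64 } )
+print( response.json()["answer"] )
+```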
## Known Issues
* Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
-* The evaluation / validation routine doesn't quite work.
* Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
+* The web server doesn't emit `content-type: application/json`, nor does it accept JSON `POST`s at the moment.
## Strawmen
diff --git a/data/config.yaml b/data/config.yaml
index 86ee1bb..5e51278 100755
--- a/data/config.yaml
+++ b/data/config.yaml
@@ -1,6 +1,6 @@
dataset:
training: [
- "./data/captchas/"
+ "./data/images/"
]
validation: []
@@ -12,17 +12,17 @@ dataset:
models:
_models:
- - name: "captcha"
+ - name: "classifier"
tokens: 0
len: 6
hyperparameters:
batch_size: 256
- gradient_accumulation_steps: 5
+ gradient_accumulation_steps: 64
gradient_clipping: 100
optimizer: Adamw
- learning_rate: 5.0e-5
+ learning_rate: 1.0e-3
scheduler_type: ""
#scheduler_type: OneCycle
diff --git a/image_classifier/__main__.py b/image_classifier/__main__.py
index 6f7019b..d9907fe 100755
--- a/image_classifier/__main__.py
+++ b/image_classifier/__main__.py
@@ -3,13 +3,13 @@ import base64
from io import BytesIO
from pathlib import Path
-from .inference import CAPTCHA
+from .inference import Classifier
from PIL import Image
from simple_http_server import route, server
def main():
- parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
+ parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--listen", action='store_true')
parser.add_argument("--port", type=int, default=9090)
parser.add_argument("--yaml", type=Path, default=None)
@@ -18,15 +18,15 @@ def main():
parser.add_argument("--device", default="cuda")
args, unknown = parser.parse_known_args()
- captcha = CAPTCHA( config=args.yaml, ckpt=args.ckpt, device=args.device )
+ classifier = Classifier( config=args.yaml, ckpt=args.ckpt, device=args.device )
if args.listen:
@route("/")
def inference( b64, temperature=1.0 ):
image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
- return { "answer": captcha.inference( image=image, temperature=args.temp ) }
+ return { "answer": classifier.inference( image=image, temperature=args.temp ) }
server.start(port=args.port)
else:
- parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
+ parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--path", type=Path)
parser.add_argument("--base64", type=str)
parser.add_argument("--temp", type=float, default=1.0)
@@ -39,7 +39,7 @@ def main():
else:
raise ValueError("Specify a --path or --base64.")
- answer = captcha.inference( image=image, temperature=args.temp )
+ answer = classifier.inference( image=image, temperature=args.temp )
print("Answer:", answer)
if __name__ == "__main__":
diff --git a/image_classifier/config.py b/image_classifier/config.py
index 1c94ac8..2f4cfce 100755
--- a/image_classifier/config.py
+++ b/image_classifier/config.py
@@ -111,6 +111,7 @@ class Dataset:
temp: list[Path] = field(default_factory=lambda: [])
+	# de-implemented, because the data isn't large enough to warrant HDF5
hdf5_name: str = "data.h5"
use_hdf5: bool = False
@@ -149,12 +150,12 @@ class Models:
class Hyperparameters:
batch_size: int = 8
gradient_accumulation_steps: int = 32
- gradient_clipping: int = 100
+ gradient_clipping: int = 100 # to be implemented in the local backend
optimizer: str = "Adamw"
learning_rate: float = 3.25e-4
- scheduler_type: str = ""
+ scheduler_type: str = "" # to be implemented in the local backend
scheduler_params: dict = field(default_factory=lambda: {})
@dataclass()
@@ -317,7 +318,7 @@ class Trainer:
@dataclass()
class Inference:
- use_vocos: bool = True
+ use_vocos: bool = True # artifact from the VALL-E trainer
@dataclass()
class BitsAndBytes:
diff --git a/image_classifier/data.py b/image_classifier/data.py
index 02d6df1..23004f1 100755
--- a/image_classifier/data.py
+++ b/image_classifier/data.py
@@ -1,7 +1,7 @@
# todo: clean this mess up
import copy
-import h5py
+# import h5py
import json
import logging
import numpy as np
@@ -60,7 +60,7 @@ class Dataset(_Dataset):
self.training = training
self.transform = transforms.Compose([
- transforms.Resize((self.height, self.width)),
+ #transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
@@ -73,13 +73,14 @@ class Dataset(_Dataset):
def __getitem__(self, index):
path = self.paths[index]
+	# stupid try/except from when the original VALL-E training framework was able to insert foreign symbols into the symmap; that functionality isn't really necessary here
try:
text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
except Exception as e:
print("Invalid symbol:", _get_symbols(path.stem), [*map(self.symmap.get, _get_symbols(path.stem))], path.stem)
raise e
- image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype)
+ image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype) # resnet has to be RGB
return dict(
index=index,
@@ -192,7 +193,7 @@ def create_train_val_dataloader():
subtrain_dataset = copy.deepcopy(train_dataset)
subtrain_dataset.head_(cfg.evaluation.size)
- #subtrain_dataset.training_(False)
+ subtrain_dataset.training_(False)
train_dl = _create_dataloader(train_dataset, training=True)
val_dl = _create_dataloader(val_dataset, training=False)
@@ -209,9 +210,11 @@ def create_train_val_dataloader():
return train_dl, subtrain_dl, val_dl
+"""
if __name__ == "__main__":
create_dataset_hdf5()
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
sample = train_dl.dataset[0]
print(sample)
+"""
diff --git a/image_classifier/engines/base.py b/image_classifier/engines/base.py
index 8b9dc04..db246bd 100755
--- a/image_classifier/engines/base.py
+++ b/image_classifier/engines/base.py
@@ -48,6 +48,7 @@ if not distributed_initialized() and cfg.trainer.backend == "local":
init_distributed(torch.distributed.init_process_group)
# A very naive engine implementation using barebones PyTorch
+# to-do: implement lr_scheduling
class Engine():
def __init__(self, *args, **kwargs):
self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
diff --git a/image_classifier/inference.py b/image_classifier/inference.py
index 43e4824..07949f4 100755
--- a/image_classifier/inference.py
+++ b/image_classifier/inference.py
@@ -6,7 +6,7 @@ from .config import cfg
from .export import load_models
from .data import get_symmap, _get_symbols
-class CAPTCHA():
+class Classifier():
def __init__( self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32" ):
self.loading = True
self.device = device
@@ -48,6 +48,6 @@ class CAPTCHA():
image = self.transform(image).to(self.device)
answer = self.model( image=[image], sampling_temperature=temperature )
-	answer = answer[0].replace('<s>', "").replace("</s>", "")
+	answer = answer[0].replace('<s>', "").replace("</s>", "") # it would be better to just slice between these, but I can't be assed
return answer
\ No newline at end of file
diff --git a/image_classifier/models/base.py b/image_classifier/models/base.py
index 4c7dd2d..2f884c6 100755
--- a/image_classifier/models/base.py
+++ b/image_classifier/models/base.py
@@ -16,28 +16,6 @@ from torchvision.models import resnet18
from ..data import get_symmap
-def _create_mask(l, device):
- """1 is valid region and 0 is invalid."""
- seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
- stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
- return (seq < stop).float() # (b t)
-
-def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
- """
- Args:
- x_list: [(t d)]
- Returns:
- x: (? ? ?)
- m: (? ? ?), same as x
- """
- l = list(map(len, x_list))
- x = rearrange(pad_sequence(x_list), pattern)
- m = _create_mask(l, x_list[0].device)
- m = m.t().unsqueeze(-1) # (t b 1)
- m = rearrange(m, pattern)
- m = m.to(x)
- return x, m
-
class Model(nn.Module):
def __init__(
self,
@@ -61,9 +39,20 @@ class Model(nn.Module):
self.resnet = resnet18(pretrained=False)
self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len )
-
- self.criterion = nn.CTCLoss(zero_infinity=True)
+ self.accuracy_metric = MulticlassAccuracy(
+ n_tokens,
+ #top_k=10,
+ average="micro",
+ multidim_average="global",
+ )
+
+ self.precision_metric = MulticlassPrecision(
+ n_tokens,
+ #top_k=10,
+ average="micro",
+ multidim_average="global",
+ )
def forward(
self,
@@ -77,6 +66,7 @@ class Model(nn.Module):
x = self.resnet( x_list )
y = x.view(x.size(0), self.n_len, self.n_tokens)
+ # either of these should do, but my VALL-E forward pass uses this, so might as well keep to it
# pred = y.argmax(dim=2)
pred = Categorical(logits=y / sampling_temperature).sample()
@@ -84,87 +74,20 @@ class Model(nn.Module):
if text is not None:
y_list = rearrange(pad_sequence(text), "t b -> b t")
-
+
loss = 0
for i in range(self.n_len):
+ if i >= y_list.shape[1]:
+ break
loss += F.cross_entropy( y[:, i], y_list[:, i] )
self.loss = dict(
nll=loss
)
- return answer
+ self.stats = dict(
+ acc = self.accuracy_metric( pred, y_list ),
+ precision = self.precision_metric( pred, y_list ),
+ )
-def example_usage():
- from ..config import cfg
- cfg.trainer.backend = "local"
- cfg.trainer.check_for_oom = False
-
- from functools import partial
-
- from einops import repeat
-
- from ..emb.qnt import decode_to_file
- from ..engines import Engine, Engines
- from tqdm import tqdm, trange
-
- from .ar import AR
- from .nar import NAR
-
- device = "cpu"
- x8 = partial(repeat, pattern="t -> t l", l=2)
- symmap = {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
- def tokenize(content, lang_marker="en"):
- split = content.split(" ")
- phones = [f""] + [ " " if not p else p for p in split ] + [f""]
- return torch.tensor([*map(symmap.get, phones)]).to()
-
- kwargs = {
- 'n_tokens': 1024,
- 'd_model': 1024,
- 'n_heads': 16,
- 'n_layers': 12,
- }
- models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }
- engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
-
- train = True
-
-
-
-
- def sample( name, steps=400 ):
- AR = None
- NAR = None
-
- engines.eval()
- for name, engine in engines.items():
- if name[:2] == "ar":
- AR = engine
- elif name[:3] == "nar":
- NAR = engine
-
- resps_list = AR(text_list, proms_list, max_steps=steps, sampling_temperature=1.0)
- resps_list = [r.unsqueeze(-1) for r in resps_list]
- codes = NAR( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
-
- decode_to_file(resps_list[0], f"./data/ar.{name}.wav", device=device)
- decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device)
-
- if train:
- sample("init", 15)
-
- engines.train()
- t = trange(60)
- for i in t:
- stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list}, device="cpu")
- t.set_description(f"{stats}")
- else:
- for name, engine in engines.items():
- engine.module.load_state_dict(torch.load(f"./data/{name}.pth"))
-
- sample("final")
-
-
-if __name__ == "__main__":
- example_usage()
+ return answer
\ No newline at end of file
diff --git a/image_classifier/train.py b/image_classifier/train.py
index 075014c..70de575 100755
--- a/image_classifier/train.py
+++ b/image_classifier/train.py
@@ -23,11 +23,13 @@ def train_feeder(engine, batch):
engine( image=batch["image"], text=batch["text"] )
losses = engine.gather_attribute("loss")
+ stat = engine.gather_attribute("stats")
loss = torch.stack([*losses.values()]).sum()
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
+ stats |= {k: v.item() for k, v in stat.items()}
return loss, stats
@@ -55,7 +57,6 @@ def run_eval(engines, eval_name, dl):
for batch in tqdm(dl):
batch: dict = to_device(batch, cfg.device)
- # if we're training both models, provide output for both
res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
for path, ref, hyp in zip(batch["path"], batch["text"], res):
diff --git a/image_classifier/version.py b/image_classifier/version.py
deleted file mode 100755
index 2e57d81..0000000
--- a/image_classifier/version.py
+++ /dev/null
@@ -1 +0,0 @@
-__version__ = "0.0.1-dev20230804142130"