diff --git a/.gitignore b/.gitignore
index b6dd66a..42edb8b 100755
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,6 @@ __pycache__
 /.cache
 /config
 /*.egg-info
-/vall_e/version.py
+/image_classifier/version.py
 /build
 /.cache
\ No newline at end of file
diff --git a/README.md b/README.md
index 8e9647c..347e629 100755
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ This is a simple ResNet based image classifier for """specific images""", using
 
 ## Premise
 
-This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
+This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
 
 This is by no ways state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
 
@@ -22,13 +22,17 @@ This is by no ways state of the art, as it just leverages an existing ResNet arc
 
 ## Inferencing
 
-Simply invoke the inferencer with the following command: `python3 -m image_classifier "./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
+Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
+
+### Continuous Usage
+
+If you're looking to continuously classify trained images, use `python3 -m image_classifier --listen --port=7860 yaml="./data/config.yaml" --temp=1.0` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
 
 ## Known Issues
 
 * Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
-* The evaluation / validation routine doesn't quite work.
 * Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
+* The web server doesn't emit `content-type: application/json`, nor does it accept JSON `POST`s at the moment.
 
 ## Strawmen
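
(Editor's aside: for reference, a minimal client for the `GET` endpoint added to the README above might look like the sketch below. It is hypothetical and only assumes what the README states: the server was started with `--listen --port=7860`, and the `/` route returns a JSON body of the form `{"answer": ...}`.)

import base64
import json
import urllib.parse
import urllib.request

def classify(path, host="http://127.0.0.1:7860"):
    # read and base64-encode the image, since the endpoint expects a ?b64= query parameter
    with open(path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    url = host + "/?" + urllib.parse.urlencode({"b64": b64})
    # the route handler returns a dict of the form {"answer": "<label>"}
    with urllib.request.urlopen(url) as response:
        return json.loads(response.read().decode("utf-8"))["answer"]

if __name__ == "__main__":
    print(classify("./data/path-to-your-image.png"))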
diff --git a/data/config.yaml b/data/config.yaml
index 86ee1bb..5e51278 100755
--- a/data/config.yaml
+++ b/data/config.yaml
@@ -1,6 +1,6 @@
 dataset:
   training: [
-    "./data/captchas/"
+    "./data/images/"
   ]
 
   validation: []
@@ -12,17 +12,17 @@ dataset:
 models:
   _models:
-    - name: "captcha"
+    - name: "classifier"
       tokens: 0
       len: 6
 
 hyperparameters:
   batch_size: 256
-  gradient_accumulation_steps: 5
+  gradient_accumulation_steps: 64
   gradient_clipping: 100
 
   optimizer: Adamw
-  learning_rate: 5.0e-5
+  learning_rate: 1.0e-3
 
   scheduler_type: ""
   #scheduler_type: OneCycle
diff --git a/image_classifier/__main__.py b/image_classifier/__main__.py
index 6f7019b..d9907fe 100755
--- a/image_classifier/__main__.py
+++ b/image_classifier/__main__.py
@@ -3,13 +3,13 @@
 import base64
 from io import BytesIO
 from pathlib import Path
 
-from .inference import CAPTCHA
+from .inference import Classifier
 from PIL import Image
 from simple_http_server import route, server
 
 def main():
-	parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
+	parser = argparse.ArgumentParser(allow_abbrev=False)
 	parser.add_argument("--listen", action='store_true')
 	parser.add_argument("--port", type=int, default=9090)
 	parser.add_argument("--yaml", type=Path, default=None)
@@ -18,15 +18,15 @@ def main():
 	parser.add_argument("--device", default="cuda")
 	args, unknown = parser.parse_known_args()
 
-	captcha = CAPTCHA( config=args.yaml, ckpt=args.ckpt, device=args.device )
+	classifier = Classifier( config=args.yaml, ckpt=args.ckpt, device=args.device )
 
 	if args.listen:
 		@route("/")
 		def inference( b64, temperature=1.0 ):
 			image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
-			return { "answer": captcha.inference( image=image, temperature=args.temp ) }
+			return { "answer": classifier.inference( image=image, temperature=args.temp ) }
 		server.start(port=args.port)
 	else:
-		parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
+		parser = argparse.ArgumentParser(allow_abbrev=False)
 		parser.add_argument("--path", type=Path)
 		parser.add_argument("--base64", type=str)
 		parser.add_argument("--temp", type=float, default=1.0)
@@ -39,7 +39,7 @@ def main():
 	else:
 		raise "Specify a --path or --base64."
-	answer = captcha.inference( image=image, temperature=args.temp )
+	answer = classifier.inference( image=image, temperature=args.temp )
 	print("Answer:", answer)
 
 if __name__ == "__main__":
diff --git a/image_classifier/config.py b/image_classifier/config.py
index 1c94ac8..2f4cfce 100755
--- a/image_classifier/config.py
+++ b/image_classifier/config.py
@@ -111,6 +111,7 @@ class Dataset:
 	temp: list[Path] = field(default_factory=lambda: [])
 
+	# de-implemented, because the data isn't that large to facilitate HDF5
 	hdf5_name: str = "data.h5"
 	use_hdf5: bool = False
 
@@ -149,12 +150,12 @@ class Models:
 class Hyperparameters:
 	batch_size: int = 8
 	gradient_accumulation_steps: int = 32
-	gradient_clipping: int = 100
+	gradient_clipping: int = 100 # to be implemented in the local backend
 
 	optimizer: str = "Adamw"
 	learning_rate: float = 3.25e-4
 
-	scheduler_type: str = ""
+	scheduler_type: str = "" # to be implemented in the local backend
 	scheduler_params: dict = field(default_factory=lambda: {})
 
 @dataclass()
@@ -317,7 +318,7 @@ class Trainer:
 
 @dataclass()
 class Inference:
-	use_vocos: bool = True
+	use_vocos: bool = True # artifact from the VALL-E trainer
 
 @dataclass()
 class BitsAndBytes:
diff --git a/image_classifier/data.py b/image_classifier/data.py
index 02d6df1..23004f1 100755
--- a/image_classifier/data.py
+++ b/image_classifier/data.py
@@ -1,7 +1,7 @@
 # todo: clean this mess up
 
 import copy
-import h5py
+# import h5py
 import json
 import logging
 import numpy as np
@@ -60,7 +60,7 @@ class Dataset(_Dataset):
 		self.training = training
 
 		self.transform = transforms.Compose([
-			transforms.Resize((self.height, self.width)),
+			#transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
 			transforms.ToTensor(),
 			transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
 		])
@@ -73,13 +73,14 @@ class Dataset(_Dataset):
 	def __getitem__(self, index):
 		path = self.paths[index]
 
+		# stupid try/except when the original VALL-E training framework was able to insert foreign symbols into the symmap, but that functionality isn't really necessary here
 		try:
 			text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
 		except Exception as e:
 			print("Invalid symbol:", _get_symbols(path.stem), [*map(self.symmap.get, _get_symbols(path.stem))], path.stem)
 			raise e
 
-		image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype)
+		image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype) # resnet has to be RGB
 
 		return dict(
 			index=index,
@@ -192,7 +193,7 @@ def create_train_val_dataloader():
 	subtrain_dataset = copy.deepcopy(train_dataset)
 	subtrain_dataset.head_(cfg.evaluation.size)
-	#subtrain_dataset.training_(False)
+	subtrain_dataset.training_(False)
 
 	train_dl = _create_dataloader(train_dataset, training=True)
 	val_dl = _create_dataloader(val_dataset, training=False)
@@ -209,9 +210,11 @@ def create_train_val_dataloader():
 
 	return train_dl, subtrain_dl, val_dl
 
+"""
 if __name__ == "__main__":
 	create_dataset_hdf5()
 
 	train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
 	sample = train_dl.dataset[0]
 	print(sample)
+"""
diff --git a/image_classifier/engines/base.py b/image_classifier/engines/base.py
index 8b9dc04..db246bd 100755
--- a/image_classifier/engines/base.py
+++ b/image_classifier/engines/base.py
@@ -48,6 +48,7 @@ if not distributed_initialized() and cfg.trainer.backend == "local":
 	init_distributed(torch.distributed.init_process_group)
 
 # A very naive engine implementation using barebones PyTorch
+# to-do: implement lr_scheduling
 class Engine():
 	def __init__(self, *args, **kwargs):
 		self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
diff --git a/image_classifier/inference.py b/image_classifier/inference.py
index 43e4824..07949f4 100755
--- a/image_classifier/inference.py
+++ b/image_classifier/inference.py
@@ -6,7 +6,7 @@ from .config import cfg
 from .export import load_models
 from .data import get_symmap, _get_symbols
 
-class CAPTCHA():
+class Classifier():
 	def __init__( self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32" ):
 		self.loading = True
 		self.device = device
@@ -48,6 +48,6 @@ class CAPTCHA():
 		image = self.transform(image).to(self.device)
 
 		answer = self.model( image=[image], sampling_temperature=temperature )
-		answer = answer[0].replace('<s>', "").replace("</s>", "")
+		answer = answer[0].replace('<s>', "").replace("</s>", "") # it would be better to just slice between these, but I can't be assed
 		return answer
\ No newline at end of file
diff --git a/image_classifier/models/base.py b/image_classifier/models/base.py
index 4c7dd2d..2f884c6 100755
--- a/image_classifier/models/base.py
+++ b/image_classifier/models/base.py
@@ -16,28 +16,6 @@ from torchvision.models import resnet18
 from ..data import get_symmap
 
-def _create_mask(l, device):
-	"""1 is valid region and 0 is invalid."""
-	seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
-	stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
-	return (seq < stop).float() # (b t)
-
-def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
-	"""
-	Args:
-		x_list: [(t d)]
-	Returns:
-		x: (? ? ?)
-		m: (? ? ?), same as x
-	"""
-	l = list(map(len, x_list))
-	x = rearrange(pad_sequence(x_list), pattern)
-	m = _create_mask(l, x_list[0].device)
-	m = m.t().unsqueeze(-1) # (t b 1)
-	m = rearrange(m, pattern)
-	m = m.to(x)
-	return x, m
-
 class Model(nn.Module):
 	def __init__(
 		self,
@@ -61,9 +39,20 @@ class Model(nn.Module):
 		self.resnet = resnet18(pretrained=False)
 		self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len )
-
-		self.criterion = nn.CTCLoss(zero_infinity=True)
+		self.accuracy_metric = MulticlassAccuracy(
+			n_tokens,
+			#top_k=10,
+			average="micro",
+			multidim_average="global",
+		)
+
+		self.precision_metric = MulticlassPrecision(
+			n_tokens,
+			#top_k=10,
+			average="micro",
+			multidim_average="global",
+		)
 
 	def forward(
 		self,
@@ -77,6 +66,7 @@ class Model(nn.Module):
 		x = self.resnet( x_list )
 		y = x.view(x.size(0), self.n_len, self.n_tokens)
 
+		# either of these should do, but my VALL-E forward pass uses this, so might as well keep to it
 		# pred = y.argmax(dim=2)
 		pred = Categorical(logits=y / sampling_temperature).sample()
@@ -84,87 +74,20 @@ class Model(nn.Module):
 		if text is not None:
 			y_list = rearrange(pad_sequence(text), "t b -> b t")
-
+			
 			loss = 0
 			for i in range(self.n_len):
+				if i >= y_list.shape[1]:
+					break
 				loss += F.cross_entropy( y[:, i], y_list[:, i] )
 
 			self.loss = dict(
 				nll=loss
 			)
-	return answer
+			self.stats = dict(
+				acc = self.accuracy_metric( pred, y_list ),
+				precision = self.precision_metric( pred, y_list ),
+			)
 
-def example_usage():
-	from ..config import cfg
-	cfg.trainer.backend = "local"
-	cfg.trainer.check_for_oom = False
-
-	from functools import partial
-
-	from einops import repeat
-
-	from ..emb.qnt import decode_to_file
-	from ..engines import Engine, Engines
-	from tqdm import tqdm, trange
-
-	from .ar import AR
-	from .nar import NAR
-
-	device = "cpu"
-	x8 = partial(repeat, pattern="t -> t l", l=2)
-	symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
-	def tokenize(content, lang_marker="en"):
-		split = content.split(" ")
-		phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
-		return torch.tensor([*map(symmap.get, phones)]).to()
-
-	kwargs = {
-		'n_tokens': 1024,
-		'd_model': 1024,
-		'n_heads': 16,
-		'n_layers': 12,
-	}
-	models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }
-	engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
-
-	train = True
-
-
-
-
-	def sample( name, steps=400 ):
-		AR = None
-		NAR = None
-
-		engines.eval()
-		for name, engine in engines.items():
-			if name[:2] == "ar":
-				AR = engine
-			elif name[:3] == "nar":
-				NAR = engine
-
-		resps_list = AR(text_list, proms_list, max_steps=steps, sampling_temperature=1.0)
-		resps_list = [r.unsqueeze(-1) for r in resps_list]
-		codes = NAR( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
-
-		decode_to_file(resps_list[0], f"./data/ar.{name}.wav", device=device)
-		decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device)
-
-	if train:
-		sample("init", 15)
-
-		engines.train()
-		t = trange(60)
-		for i in t:
-			stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list}, device="cpu")
-			t.set_description(f"{stats}")
-	else:
-		for name, engine in engines.items():
-			engine.module.load_state_dict(torch.load(f"./data/{name}.pth"))
-
-		sample("final")
-
-
-if __name__ == "__main__":
-	example_usage()
+	return answer
\ No newline at end of file
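
(Editor's aside on the loss change in the hunk above: the forward pass now sums an independent cross-entropy over each of the `n_len` character positions, bailing out once the padded label runs short, instead of the old `nn.CTCLoss`. A toy, self-contained illustration of the shapes involved, with made-up sizes rather than the repo's real `n_tokens`/`n_len`:)

import torch
import torch.nn.functional as F

# made-up sizes for demonstration only; real values come from the model config
batch, n_len, n_tokens = 4, 6, 36
y = torch.randn(batch, n_len, n_tokens)              # logits reshaped from the fc head
y_list = torch.randint(0, n_tokens, (batch, n_len))  # padded target indices per position

loss = 0
for i in range(n_len):
    if i >= y_list.shape[1]:  # guard against labels shorter than n_len
        break
    loss += F.cross_entropy(y[:, i], y_list[:, i])   # one cross-entropy per character slot
print(loss)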
diff --git a/image_classifier/train.py b/image_classifier/train.py
index 075014c..70de575 100755
--- a/image_classifier/train.py
+++ b/image_classifier/train.py
@@ -23,11 +23,13 @@ def train_feeder(engine, batch):
 	engine( image=batch["image"], text=batch["text"] )
 
 	losses = engine.gather_attribute("loss")
+	stat = engine.gather_attribute("stats")
 
 	loss = torch.stack([*losses.values()]).sum()
 
 	stats = {}
 	stats |= {k: v.item() for k, v in losses.items()}
+	stats |= {k: v.item() for k, v in stat.items()}
 
 	return loss, stats
@@ -55,7 +57,6 @@ def run_eval(engines, eval_name, dl):
 	for batch in tqdm(dl):
 		batch: dict = to_device(batch, cfg.device)
 
-		# if we're training both models, provide output for both
 		res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
 
 		for path, ref, hyp in zip(batch["path"], batch["text"], res):
diff --git a/image_classifier/version.py b/image_classifier/version.py
deleted file mode 100755
index 2e57d81..0000000
--- a/image_classifier/version.py
+++ /dev/null
@@ -1 +0,0 @@
-__version__ = "0.0.1-dev20230804142130"
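
(Editor's aside: for anyone scripting against the renamed class rather than going through `__main__.py`, a rough sketch of driving the `Classifier` wrapper from `image_classifier/inference.py` directly, assuming a trained checkpoint referenced by `./data/config.yaml`; the config, device, and image paths here are illustrative.)

from pathlib import Path

from PIL import Image

from image_classifier.inference import Classifier

# mirrors what __main__.py does for the --path case
classifier = Classifier(config=Path("./data/config.yaml"), device="cuda")
image = Image.open("./data/path-to-your-image.png").convert("RGB")
print("Answer:", classifier.inference(image=image, temperature=1.0))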