tweaks, fixes, cleanup, added reporting accuracy/precision from the VALL-E trainer (which indirectly revealed a grody bug in the VALL-E trainer), some other cr*p
parent 77a9625e93
commit 93987ea5d6

.gitignore (vendored) | 2
@@ -5,6 +5,6 @@ __pycache__
 /.cache
 /config
 /*.egg-info
-/vall_e/version.py
+/image_classifier/version.py
 /build
 /.cache

README.md | 10
@@ -4,7 +4,7 @@ This is a simple ResNet based image classifier for """specific images""", using
 
 ## Premise
 
-This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
+This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
 
 This is by no ways state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
 
@@ -22,13 +22,17 @@ This is by no ways state of the art, as it just leverages an existing ResNet arc
 
 ## Inferencing
 
-Simply invoke the inferencer with the following command: `python3 -m image_classifier "./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
+Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
 
+### Continuous Usage
+
+If you're looking to continuously classifier trained images, use `python3 -m image_classifier --listen --port=7860 yaml="./data/config.yaml" --temp=1.0` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
+
 ## Known Issues
 
 * Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
 * The evaluation / validation routine doesn't quite work.
 * Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
 * Web server doesn't emit `content-type: application/json`, nor accepts JSON `POST`s at the moment.
 
 ## Strawmen
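A rough client sketch for that endpoint (the image path is a placeholder, `requests` is assumed to be available, and the `b64` query parameter and `answer` key come from the README and `__main__.py` below):

```python
# minimal client for the classifier's web server, assuming it was started with:
#   python3 -m image_classifier --listen --port=7860 yaml="./data/config.yaml"
import base64
import requests

with open("./data/path-to-your-image.png", "rb") as f:  # placeholder image path
    b64 = base64.b64encode(f.read()).decode("utf-8")

# requests URL-encodes the base64 payload inside the `b64` query parameter
response = requests.get("http://127.0.0.1:7860/", params={"b64": b64})

# the README notes the server doesn't set `content-type: application/json` yet,
# so parse the body explicitly rather than relying on the header
print(response.json()["answer"])
```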

@@ -1,6 +1,6 @@
 dataset:
   training: [
-    "./data/captchas/"
+    "./data/images/"
   ]
 
   validation: []
@@ -12,17 +12,17 @@ dataset:
 
 models:
   _models:
-  - name: "captcha"
+  - name: "classifier"
     tokens: 0
     len: 6
 
 hyperparameters:
   batch_size: 256
-  gradient_accumulation_steps: 5
+  gradient_accumulation_steps: 64
   gradient_clipping: 100
 
   optimizer: Adamw
-  learning_rate: 5.0e-5
+  learning_rate: 1.0e-3
 
   scheduler_type: ""
   #scheduler_type: OneCycle
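For scale, assuming these two knobs multiply in the usual way: an optimizer step under the new values covers batch_size × gradient_accumulation_steps = 256 × 64 = 16,384 samples, versus 256 × 5 = 1,280 before.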

@@ -3,13 +3,13 @@ import base64
 
 from io import BytesIO
 from pathlib import Path
-from .inference import CAPTCHA
+from .inference import Classifier
 
 from PIL import Image
 from simple_http_server import route, server
 
 def main():
-    parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
+    parser = argparse.ArgumentParser(allow_abbrev=False)
     parser.add_argument("--listen", action='store_true')
     parser.add_argument("--port", type=int, default=9090)
     parser.add_argument("--yaml", type=Path, default=None)
@@ -18,15 +18,15 @@ def main():
     parser.add_argument("--device", default="cuda")
     args, unknown = parser.parse_known_args()
 
-    captcha = CAPTCHA( config=args.yaml, ckpt=args.ckpt, device=args.device )
+    classifier = Classifier( config=args.yaml, ckpt=args.ckpt, device=args.device )
     if args.listen:
         @route("/")
         def inference( b64, temperature=1.0 ):
             image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
-            return { "answer": captcha.inference( image=image, temperature=args.temp ) }
+            return { "answer": classifier.inference( image=image, temperature=args.temp ) }
         server.start(port=args.port)
     else:
-        parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
+        parser = argparse.ArgumentParser(allow_abbrev=False)
         parser.add_argument("--path", type=Path)
         parser.add_argument("--base64", type=str)
         parser.add_argument("--temp", type=float, default=1.0)
@@ -39,7 +39,7 @@ def main():
         else:
             raise "Specify a --path or --base64."
 
-    answer = captcha.inference( image=image, temperature=args.temp )
+    answer = classifier.inference( image=image, temperature=args.temp )
     print("Answer:", answer)
 
 if __name__ == "__main__":

@@ -111,6 +111,7 @@ class Dataset:
 
     temp: list[Path] = field(default_factory=lambda: [])
 
+    # de-implemented, because the data isn't that large to facilitate HDF5
    hdf5_name: str = "data.h5"
     use_hdf5: bool = False
 
@@ -149,12 +150,12 @@ class Models:
 class Hyperparameters:
     batch_size: int = 8
     gradient_accumulation_steps: int = 32
-    gradient_clipping: int = 100
+    gradient_clipping: int = 100 # to be implemented in the local backend
 
     optimizer: str = "Adamw"
     learning_rate: float = 3.25e-4
 
-    scheduler_type: str = ""
+    scheduler_type: str = "" # to be implemented in the local backend
     scheduler_params: dict = field(default_factory=lambda: {})
 
 @dataclass()
@@ -317,7 +318,7 @@ class Trainer:
 
 @dataclass()
 class Inference:
-    use_vocos: bool = True
+    use_vocos: bool = True # artifact from the VALL-E trainer
 
 @dataclass()
 class BitsAndBytes:

@@ -1,7 +1,7 @@
 # todo: clean this mess up
 
 import copy
-import h5py
+# import h5py
 import json
 import logging
 import numpy as np
@@ -60,7 +60,7 @@ class Dataset(_Dataset):
         self.training = training
 
         self.transform = transforms.Compose([
-            transforms.Resize((self.height, self.width)),
+            #transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
             transforms.ToTensor(),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
         ])
@@ -73,13 +73,14 @@ class Dataset(_Dataset):
     def __getitem__(self, index):
         path = self.paths[index]
 
+        # stupid try/except when the original VALL-E training framework was able to insert foreign symbols into the symmap, but that functionality isn't really necessary here
         try:
             text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
         except Exception as e:
             print("Invalid symbol:", _get_symbols(path.stem), [*map(self.symmap.get, _get_symbols(path.stem))], path.stem)
             raise e
 
-        image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype)
+        image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype) # resnet has to be RGB
 
         return dict(
             index=index,
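To illustrate what that try/except guards against (an illustrative sketch with a made-up symbol map; the real one comes from `get_symmap()` and the real character splitting from `_get_symbols()`):

```python
import torch

# hypothetical symbol map: character -> token id (stand-in for get_symmap())
symmap = {"a": 1, "b": 2, "c": 3, "1": 4, "2": 5}

stem = "ab12c"  # pretend filename stem, e.g. ./data/images/ab12c.png
ids = [symmap.get(ch) for ch in stem]

# a character missing from the map comes back as None, and torch.tensor() on a list
# containing None raises - which is what the print/raise above surfaces
text = torch.tensor(ids).to(torch.uint8)
print(text)  # tensor([1, 2, 4, 5, 3], dtype=torch.uint8)
```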
@@ -192,7 +193,7 @@ def create_train_val_dataloader():
 
     subtrain_dataset = copy.deepcopy(train_dataset)
     subtrain_dataset.head_(cfg.evaluation.size)
-    #subtrain_dataset.training_(False)
+    subtrain_dataset.training_(False)
 
     train_dl = _create_dataloader(train_dataset, training=True)
     val_dl = _create_dataloader(val_dataset, training=False)
@@ -209,9 +210,11 @@ def create_train_val_dataloader():
 
     return train_dl, subtrain_dl, val_dl
 
+"""
 if __name__ == "__main__":
     create_dataset_hdf5()
 
     train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
     sample = train_dl.dataset[0]
     print(sample)
+"""

@@ -48,6 +48,7 @@ if not distributed_initialized() and cfg.trainer.backend == "local":
     init_distributed(torch.distributed.init_process_group)
 
 # A very naive engine implementation using barebones PyTorch
+# to-do: implement lr_sheduling
 class Engine():
     def __init__(self, *args, **kwargs):
         self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
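For orientation, the "very naive engine" amounts to a plain PyTorch train step. A rough sketch of that pattern, not the repo's actual Engine, and without the lr scheduling the to-do refers to:

```python
import torch

class NaiveEngine:
    """Bare-bones stand-in for a DeepSpeed-style engine: it owns a model and an optimizer."""
    def __init__(self, model, optimizer, device="cuda", dtype=torch.float32):
        self.module = model.to(device).to(dtype)
        self.optimizer = optimizer

    def step(self, loss, gradient_clipping=None):
        self.optimizer.zero_grad()
        loss.backward()
        if gradient_clipping is not None:  # where the config's gradient_clipping would hook in
            torch.nn.utils.clip_grad_norm_(self.module.parameters(), gradient_clipping)
        self.optimizer.step()
        # to-do, as the comment above says: step an lr scheduler here once one exists
```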

@@ -6,7 +6,7 @@ from .config import cfg
 from .export import load_models
 from .data import get_symmap, _get_symbols
 
-class CAPTCHA():
+class Classifier():
     def __init__( self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32" ):
         self.loading = True
         self.device = device
@@ -48,6 +48,6 @@ class CAPTCHA():
         image = self.transform(image).to(self.device)
 
         answer = self.model( image=[image], sampling_temperature=temperature )
-        answer = answer[0].replace('<s>', "").replace("</s>", "")
+        answer = answer[0].replace('<s>', "").replace("</s>", "") # it would be better to just slice between these, but I can't be assed
 
         return answer
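Given the constructor and `inference()` signatures above, using the renamed class outside of `__main__.py` would presumably look like this (a sketch; the paths are placeholders):

```python
from pathlib import Path
from PIL import Image

from image_classifier.inference import Classifier

# config/ckpt/device mirror the keyword arguments __main__.py passes in
classifier = Classifier( config=Path("./data/config.yaml"), device="cuda" )

image = Image.open("./data/path-to-your-image.png").convert("RGB")
answer = classifier.inference( image=image, temperature=1.0 )
print("Answer:", answer)
```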

@@ -16,28 +16,6 @@ from torchvision.models import resnet18
 
 from ..data import get_symmap
 
-def _create_mask(l, device):
-    """1 is valid region and 0 is invalid."""
-    seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
-    stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
-    return (seq < stop).float() # (b t)
-
-def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
-    """
-    Args:
-        x_list: [(t d)]
-    Returns:
-        x: (? ? ?)
-        m: (? ? ?), same as x
-    """
-    l = list(map(len, x_list))
-    x = rearrange(pad_sequence(x_list), pattern)
-    m = _create_mask(l, x_list[0].device)
-    m = m.t().unsqueeze(-1) # (t b 1)
-    m = rearrange(m, pattern)
-    m = m.to(x)
-    return x, m
-
 class Model(nn.Module):
     def __init__(
         self,
@@ -61,9 +39,20 @@ class Model(nn.Module):
 
         self.resnet = resnet18(pretrained=False)
         self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len )
 
         self.criterion = nn.CTCLoss(zero_infinity=True)
 
+        self.accuracy_metric = MulticlassAccuracy(
+            n_tokens,
+            #top_k=10,
+            average="micro",
+            multidim_average="global",
+        )
+
+        self.precision_metric = MulticlassPrecision(
+            n_tokens,
+            #top_k=10,
+            average="micro",
+            multidim_average="global",
+        )
     def forward(
         self,
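The constructor arguments of the two metrics added above (a positional class count, `average`, `multidim_average`) line up with `torchmetrics.classification`; assuming that is the provider, their behaviour on (batch, position)-shaped predictions looks like this (standalone sketch, not the repo's code):

```python
import torch
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision

n_tokens = 10  # illustrative vocabulary size

accuracy_metric = MulticlassAccuracy(n_tokens, average="micro", multidim_average="global")
precision_metric = MulticlassPrecision(n_tokens, average="micro", multidim_average="global")

# pred / target shaped (batch, n_len): one class id per character position,
# which matches what the forward pass samples
pred   = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])
target = torch.tensor([[1, 2, 3, 4], [1, 2, 0, 4]])

print(accuracy_metric(pred, target))   # 7 of 8 positions correct -> 0.8750
print(precision_metric(pred, target))  # micro-averaged precision matches micro accuracy here
```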
@@ -77,6 +66,7 @@ class Model(nn.Module):
         x = self.resnet( x_list )
         y = x.view(x.size(0), self.n_len, self.n_tokens)
 
+        # either of these should do, but my VALL-E forward pass uses this, so might as well keep to it
         # pred = y.argmax(dim=2)
         pred = Categorical(logits=y / sampling_temperature).sample()
 
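On the argmax-versus-`Categorical` choice noted above: dividing the logits by a temperature before sampling recovers greedy argmax as the temperature approaches 0 and becomes increasingly random as it grows. A small standalone illustration with made-up logits:

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
logits = torch.tensor([[2.0, 1.0, 0.1]])  # made-up logits for a single position

greedy = logits.argmax(dim=-1)                      # always picks class 0
cool   = Categorical(logits=logits / 0.1).sample()  # near-deterministic, ~always class 0
hot    = Categorical(logits=logits / 10.0).sample() # close to uniform over the 3 classes

print(greedy, cool, hot)
```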
@@ -84,87 +74,20 @@ class Model(nn.Module):
 
         if text is not None:
             y_list = rearrange(pad_sequence(text), "t b -> b t")
 
 
             loss = 0
             for i in range(self.n_len):
                 if i >= y_list.shape[1]:
                     break
                 loss += F.cross_entropy( y[:, i], y_list[:, i] )
 
             self.loss = dict(
                 nll=loss
             )
 
-        return answer
+            self.stats = dict(
+                acc = self.accuracy_metric( pred, y_list ),
+                precision = self.precision_metric( pred, y_list ),
+            )
 
-def example_usage():
-    from ..config import cfg
-    cfg.trainer.backend = "local"
-    cfg.trainer.check_for_oom = False
-
-    from functools import partial
-
-    from einops import repeat
-
-    from ..emb.qnt import decode_to_file
-    from ..engines import Engine, Engines
-    from tqdm import tqdm, trange
-
-    from .ar import AR
-    from .nar import NAR
-
-    device = "cpu"
-    x8 = partial(repeat, pattern="t -> t l", l=2)
-    symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
-    def tokenize(content, lang_marker="en"):
-        split = content.split(" ")
-        phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
-        return torch.tensor([*map(symmap.get, phones)]).to()
-
-    kwargs = {
-        'n_tokens': 1024,
-        'd_model': 1024,
-        'n_heads': 16,
-        'n_layers': 12,
-    }
-    models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }
-    engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
-
-    train = True
-
-
-    def sample( name, steps=400 ):
-        AR = None
-        NAR = None
-
-        engines.eval()
-        for name, engine in engines.items():
-            if name[:2] == "ar":
-                AR = engine
-            elif name[:3] == "nar":
-                NAR = engine
-
-        resps_list = AR(text_list, proms_list, max_steps=steps, sampling_temperature=1.0)
-        resps_list = [r.unsqueeze(-1) for r in resps_list]
-        codes = NAR( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
-
-        decode_to_file(resps_list[0], f"./data/ar.{name}.wav", device=device)
-        decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device)
-
-    if train:
-        sample("init", 15)
-
-        engines.train()
-        t = trange(60)
-        for i in t:
-            stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list}, device="cpu")
-            t.set_description(f"{stats}")
-    else:
-        for name, engine in engines.items():
-            engine.module.load_state_dict(torch.load(f"./data/{name}.pth"))
-
-        sample("final")
-
-
-if __name__ == "__main__":
-    example_usage()
+        return answer

@@ -23,11 +23,13 @@ def train_feeder(engine, batch):
     engine( image=batch["image"], text=batch["text"] )
 
     losses = engine.gather_attribute("loss")
+    stat = engine.gather_attribute("stats")
 
     loss = torch.stack([*losses.values()]).sum()
 
     stats = {}
     stats |= {k: v.item() for k, v in losses.items()}
+    stats |= {k: v.item() for k, v in stat.items()}
 
     return loss, stats
 
@@ -55,7 +57,6 @@ def run_eval(engines, eval_name, dl):
     for batch in tqdm(dl):
         batch: dict = to_device(batch, cfg.device)
 
-        # if we're training both models, provide output for both
         res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
 
         for path, ref, hyp in zip(batch["path"], batch["text"], res):

@@ -1 +0,0 @@
-__version__ = "0.0.1-dev20230804142130"