tweaks, fixes, cleanup, added reporting accuracy/precision from the VALL-E trainer (which indirectly revealed a grody bug in the VALL-E trainer), some other cr*p

This commit is contained in:
mrq 2023-08-05 22:42:05 +00:00
parent 77a9625e93
commit 93987ea5d6
11 changed files with 56 additions and 124 deletions

2
.gitignore vendored
View File

@ -5,6 +5,6 @@ __pycache__
/.cache
/config
/*.egg-info
/vall_e/version.py
/image_classifier/version.py
/build
/.cache

View File

@ -4,7 +4,7 @@ This is a simple ResNet based image classifier for """specific images""", using
## Premise
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
This is by no ways state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
@ -22,13 +22,17 @@ This is by no ways state of the art, as it just leverages an existing ResNet arc
## Inferencing
Simply invoke the inferencer with the following command: `python3 -m image_classifier "./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
### Continuous Usage
If you're looking to continuously classifier trained images, use `python3 -m image_classifier --listen --port=7860 yaml="./data/config.yaml" --temp=1.0` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
## Known Issues
* Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
* The evaluation / validation routine doesn't quite work.
* Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
* Web server doesn't emit `content-type: application/json`, nor accepts JSON `POST`s at the moment.
## Strawmen

View File

@ -1,6 +1,6 @@
dataset:
training: [
"./data/captchas/"
"./data/images/"
]
validation: []
@ -12,17 +12,17 @@ dataset:
models:
_models:
- name: "captcha"
- name: "classifier"
tokens: 0
len: 6
hyperparameters:
batch_size: 256
gradient_accumulation_steps: 5
gradient_accumulation_steps: 64
gradient_clipping: 100
optimizer: Adamw
learning_rate: 5.0e-5
learning_rate: 1.0e-3
scheduler_type: ""
#scheduler_type: OneCycle

View File

@ -3,13 +3,13 @@ import base64
from io import BytesIO
from pathlib import Path
from .inference import CAPTCHA
from .inference import Classifier
from PIL import Image
from simple_http_server import route, server
def main():
parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--listen", action='store_true')
parser.add_argument("--port", type=int, default=9090)
parser.add_argument("--yaml", type=Path, default=None)
@ -18,15 +18,15 @@ def main():
parser.add_argument("--device", default="cuda")
args, unknown = parser.parse_known_args()
captcha = CAPTCHA( config=args.yaml, ckpt=args.ckpt, device=args.device )
classifier = Classifier( config=args.yaml, ckpt=args.ckpt, device=args.device )
if args.listen:
@route("/")
def inference( b64, temperature=1.0 ):
image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
return { "answer": captcha.inference( image=image, temperature=args.temp ) }
return { "answer": classifier.inference( image=image, temperature=args.temp ) }
server.start(port=args.port)
else:
parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--path", type=Path)
parser.add_argument("--base64", type=str)
parser.add_argument("--temp", type=float, default=1.0)
@ -39,7 +39,7 @@ def main():
else:
raise "Specify a --path or --base64."
answer = captcha.inference( image=image, temperature=args.temp )
answer = classifier.inference( image=image, temperature=args.temp )
print("Answer:", answer)
if __name__ == "__main__":

View File

@ -111,6 +111,7 @@ class Dataset:
temp: list[Path] = field(default_factory=lambda: [])
# de-implemented, because the data isn't that large to facilitate HDF5
hdf5_name: str = "data.h5"
use_hdf5: bool = False
@ -149,12 +150,12 @@ class Models:
class Hyperparameters:
batch_size: int = 8
gradient_accumulation_steps: int = 32
gradient_clipping: int = 100
gradient_clipping: int = 100 # to be implemented in the local backend
optimizer: str = "Adamw"
learning_rate: float = 3.25e-4
scheduler_type: str = ""
scheduler_type: str = "" # to be implemented in the local backend
scheduler_params: dict = field(default_factory=lambda: {})
@dataclass()
@ -317,7 +318,7 @@ class Trainer:
@dataclass()
class Inference:
use_vocos: bool = True
use_vocos: bool = True # artifact from the VALL-E trainer
@dataclass()
class BitsAndBytes:

View File

@ -1,7 +1,7 @@
# todo: clean this mess up
import copy
import h5py
# import h5py
import json
import logging
import numpy as np
@ -60,7 +60,7 @@ class Dataset(_Dataset):
self.training = training
self.transform = transforms.Compose([
transforms.Resize((self.height, self.width)),
#transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
@ -73,13 +73,14 @@ class Dataset(_Dataset):
def __getitem__(self, index):
path = self.paths[index]
# stupid try/except when the original VALL-E training framework was able to insert foreign symbols into the symmap, but that functionality isn't really necessary here
try:
text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
except Exception as e:
print("Invalid symbol:", _get_symbols(path.stem), [*map(self.symmap.get, _get_symbols(path.stem))], path.stem)
raise e
image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype)
image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype) # resnet has to be RGB
return dict(
index=index,
@ -192,7 +193,7 @@ def create_train_val_dataloader():
subtrain_dataset = copy.deepcopy(train_dataset)
subtrain_dataset.head_(cfg.evaluation.size)
#subtrain_dataset.training_(False)
subtrain_dataset.training_(False)
train_dl = _create_dataloader(train_dataset, training=True)
val_dl = _create_dataloader(val_dataset, training=False)
@ -209,9 +210,11 @@ def create_train_val_dataloader():
return train_dl, subtrain_dl, val_dl
"""
if __name__ == "__main__":
create_dataset_hdf5()
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
sample = train_dl.dataset[0]
print(sample)
"""

View File

@ -48,6 +48,7 @@ if not distributed_initialized() and cfg.trainer.backend == "local":
init_distributed(torch.distributed.init_process_group)
# A very naive engine implementation using barebones PyTorch
# to-do: implement lr_sheduling
class Engine():
def __init__(self, *args, **kwargs):
self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)

View File

@ -6,7 +6,7 @@ from .config import cfg
from .export import load_models
from .data import get_symmap, _get_symbols
class CAPTCHA():
class Classifier():
def __init__( self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32" ):
self.loading = True
self.device = device
@ -48,6 +48,6 @@ class CAPTCHA():
image = self.transform(image).to(self.device)
answer = self.model( image=[image], sampling_temperature=temperature )
answer = answer[0].replace('<s>', "").replace("</s>", "")
answer = answer[0].replace('<s>', "").replace("</s>", "") # it would be better to just slice between these, but I can't be assed
return answer

View File

@ -16,28 +16,6 @@ from torchvision.models import resnet18
from ..data import get_symmap
def _create_mask(l, device):
"""1 is valid region and 0 is invalid."""
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
return (seq < stop).float() # (b t)
def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
"""
Args:
x_list: [(t d)]
Returns:
x: (? ? ?)
m: (? ? ?), same as x
"""
l = list(map(len, x_list))
x = rearrange(pad_sequence(x_list), pattern)
m = _create_mask(l, x_list[0].device)
m = m.t().unsqueeze(-1) # (t b 1)
m = rearrange(m, pattern)
m = m.to(x)
return x, m
class Model(nn.Module):
def __init__(
self,
@ -61,9 +39,20 @@ class Model(nn.Module):
self.resnet = resnet18(pretrained=False)
self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len )
self.criterion = nn.CTCLoss(zero_infinity=True)
self.accuracy_metric = MulticlassAccuracy(
n_tokens,
#top_k=10,
average="micro",
multidim_average="global",
)
self.precision_metric = MulticlassPrecision(
n_tokens,
#top_k=10,
average="micro",
multidim_average="global",
)
def forward(
self,
@ -77,6 +66,7 @@ class Model(nn.Module):
x = self.resnet( x_list )
y = x.view(x.size(0), self.n_len, self.n_tokens)
# either of these should do, but my VALL-E forward pass uses this, so might as well keep to it
# pred = y.argmax(dim=2)
pred = Categorical(logits=y / sampling_temperature).sample()
@ -84,87 +74,20 @@ class Model(nn.Module):
if text is not None:
y_list = rearrange(pad_sequence(text), "t b -> b t")
loss = 0
for i in range(self.n_len):
if i >= y_list.shape[1]:
break
loss += F.cross_entropy( y[:, i], y_list[:, i] )
self.loss = dict(
nll=loss
)
return answer
self.stats = dict(
acc = self.accuracy_metric( pred, y_list ),
precision = self.precision_metric( pred, y_list ),
)
def example_usage():
from ..config import cfg
cfg.trainer.backend = "local"
cfg.trainer.check_for_oom = False
from functools import partial
from einops import repeat
from ..emb.qnt import decode_to_file
from ..engines import Engine, Engines
from tqdm import tqdm, trange
from .ar import AR
from .nar import NAR
device = "cpu"
x8 = partial(repeat, pattern="t -> t l", l=2)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"):
split = content.split(" ")
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
return torch.tensor([*map(symmap.get, phones)]).to()
kwargs = {
'n_tokens': 1024,
'd_model': 1024,
'n_heads': 16,
'n_layers': 12,
}
models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }
engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
train = True
def sample( name, steps=400 ):
AR = None
NAR = None
engines.eval()
for name, engine in engines.items():
if name[:2] == "ar":
AR = engine
elif name[:3] == "nar":
NAR = engine
resps_list = AR(text_list, proms_list, max_steps=steps, sampling_temperature=1.0)
resps_list = [r.unsqueeze(-1) for r in resps_list]
codes = NAR( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
decode_to_file(resps_list[0], f"./data/ar.{name}.wav", device=device)
decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device)
if train:
sample("init", 15)
engines.train()
t = trange(60)
for i in t:
stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list}, device="cpu")
t.set_description(f"{stats}")
else:
for name, engine in engines.items():
engine.module.load_state_dict(torch.load(f"./data/{name}.pth"))
sample("final")
if __name__ == "__main__":
example_usage()
return answer

View File

@ -23,11 +23,13 @@ def train_feeder(engine, batch):
engine( image=batch["image"], text=batch["text"] )
losses = engine.gather_attribute("loss")
stat = engine.gather_attribute("stats")
loss = torch.stack([*losses.values()]).sum()
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= {k: v.item() for k, v in stat.items()}
return loss, stats
@ -55,7 +57,6 @@ def run_eval(engines, eval_name, dl):
for batch in tqdm(dl):
batch: dict = to_device(batch, cfg.device)
# if we're training both models, provide output for both
res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
for path, ref, hyp in zip(batch["path"], batch["text"], res):

View File

@ -1 +0,0 @@
__version__ = "0.0.1-dev20230804142130"