tweaks, fixes, cleanup, added reporting accuracy/precision from the VALL-E trainer (which indirectly revealed a grody bug in the VALL-E trainer), some other cr*p

This commit is contained in:
mrq 2023-08-05 22:42:05 +00:00
parent 77a9625e93
commit 93987ea5d6
11 changed files with 56 additions and 124 deletions

.gitignore vendored
View File

@@ -5,6 +5,6 @@ __pycache__
 /.cache
 /config
 /*.egg-info
-/vall_e/version.py
+/image_classifier/version.py
 /build
 /.cache

View File

@@ -4,7 +4,7 @@ This is a simple ResNet based image classifier for """specific images""", using
 ## Premise
-This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
+This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
 This is by no ways state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
@@ -22,13 +22,17 @@ This is by no ways state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
 ## Inferencing
-Simply invoke the inferencer with the following command: `python3 -m image_classifier "./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
+Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
+### Continuous Usage
+If you're looking to continuously classify images, use `python3 -m image_classifier --listen --port=7860 yaml="./data/config.yaml" --temp=1.0` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
 ## Known Issues
 * Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
+* The evaluation / validation routine doesn't quite work.
 * Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
+* Web server doesn't emit `content-type: application/json`, nor accept JSON `POST`s at the moment.
 ## Strawmen
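As a quick illustration of the continuous-usage mode described in the README hunk above, here is a minimal client sketch, assuming the default port of 7860, the third-party `requests` package, and a hypothetical image path; it is not part of the repo:

# Sketch: query the classifier webserver with a base64-encoded image.
import base64
import requests

with open("./data/path-to-your-image.png", "rb") as f:  # hypothetical image path
    b64 = base64.b64encode(f.read()).decode("utf-8")

# GET / with a `b64` query parameter; the body is JSON like {"answer": "..."} even though
# the content-type header isn't set yet (see Known Issues above).
response = requests.get("http://127.0.0.1:7860/", params={"b64": b64})
print(response.json()["answer"])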

View File

@@ -1,6 +1,6 @@
 dataset:
   training: [
-    "./data/captchas/"
+    "./data/images/"
   ]
   validation: []
@@ -12,17 +12,17 @@ dataset:
 models:
   _models:
-  - name: "captcha"
+  - name: "classifier"
     tokens: 0
     len: 6
 hyperparameters:
   batch_size: 256
-  gradient_accumulation_steps: 5
+  gradient_accumulation_steps: 64
   gradient_clipping: 100
   optimizer: Adamw
-  learning_rate: 5.0e-5
+  learning_rate: 1.0e-3
   scheduler_type: ""
   #scheduler_type: OneCycle
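A note on the hyperparameter bump above: assuming the trainer multiplies the two values in the usual way, going from gradient_accumulation_steps: 5 to 64 with batch_size: 256 raises the effective batch per optimizer step from 256 × 5 = 1,280 to 256 × 64 = 16,384 samples, which is presumably also why the learning rate was raised from 5.0e-5 to 1.0e-3.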

View File

@@ -3,13 +3,13 @@ import base64
 from io import BytesIO
 from pathlib import Path
-from .inference import CAPTCHA
+from .inference import Classifier
 from PIL import Image
 from simple_http_server import route, server
 def main():
-    parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
+    parser = argparse.ArgumentParser(allow_abbrev=False)
     parser.add_argument("--listen", action='store_true')
     parser.add_argument("--port", type=int, default=9090)
     parser.add_argument("--yaml", type=Path, default=None)
@@ -18,15 +18,15 @@ def main():
     parser.add_argument("--device", default="cuda")
     args, unknown = parser.parse_known_args()
-    captcha = CAPTCHA( config=args.yaml, ckpt=args.ckpt, device=args.device )
+    classifier = Classifier( config=args.yaml, ckpt=args.ckpt, device=args.device )
     if args.listen:
        @route("/")
        def inference( b64, temperature=1.0 ):
            image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
-           return { "answer": captcha.inference( image=image, temperature=args.temp ) }
+           return { "answer": classifier.inference( image=image, temperature=args.temp ) }
        server.start(port=args.port)
     else:
-       parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
+       parser = argparse.ArgumentParser(allow_abbrev=False)
        parser.add_argument("--path", type=Path)
        parser.add_argument("--base64", type=str)
        parser.add_argument("--temp", type=float, default=1.0)
@@ -39,7 +39,7 @@ def main():
        else:
            raise "Specify a --path or --base64."
-       answer = captcha.inference( image=image, temperature=args.temp )
+       answer = classifier.inference( image=image, temperature=args.temp )
        print("Answer:", answer)
 if __name__ == "__main__":

View File

@@ -111,6 +111,7 @@ class Dataset:
     temp: list[Path] = field(default_factory=lambda: [])
+    # de-implemented, because the data isn't that large to facilitate HDF5
     hdf5_name: str = "data.h5"
     use_hdf5: bool = False
@@ -149,12 +150,12 @@ class Models:
 class Hyperparameters:
     batch_size: int = 8
     gradient_accumulation_steps: int = 32
-    gradient_clipping: int = 100
+    gradient_clipping: int = 100 # to be implemented in the local backend
     optimizer: str = "Adamw"
     learning_rate: float = 3.25e-4
-    scheduler_type: str = ""
+    scheduler_type: str = "" # to be implemented in the local backend
     scheduler_params: dict = field(default_factory=lambda: {})
 @dataclass()
@@ -317,7 +318,7 @@ class Trainer:
 @dataclass()
 class Inference:
-    use_vocos: bool = True
+    use_vocos: bool = True # artifact from the VALL-E trainer
 @dataclass()
 class BitsAndBytes:

View File

@@ -1,7 +1,7 @@
 # todo: clean this mess up
 import copy
-import h5py
+# import h5py
 import json
 import logging
 import numpy as np
@@ -60,7 +60,7 @@ class Dataset(_Dataset):
         self.training = training
         self.transform = transforms.Compose([
-            transforms.Resize((self.height, self.width)),
+            #transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
             transforms.ToTensor(),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
         ])
@@ -73,13 +73,14 @@ class Dataset(_Dataset):
     def __getitem__(self, index):
         path = self.paths[index]
+        # stupid try/except when the original VALL-E training framework was able to insert foreign symbols into the symmap, but that functionality isn't really necessary here
         try:
             text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
         except Exception as e:
             print("Invalid symbol:", _get_symbols(path.stem), [*map(self.symmap.get, _get_symbols(path.stem))], path.stem)
             raise e
-        image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype)
+        image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype) # resnet has to be RGB
         return dict(
             index=index,
@@ -192,7 +193,7 @@ def create_train_val_dataloader():
     subtrain_dataset = copy.deepcopy(train_dataset)
     subtrain_dataset.head_(cfg.evaluation.size)
-    #subtrain_dataset.training_(False)
+    subtrain_dataset.training_(False)
     train_dl = _create_dataloader(train_dataset, training=True)
     val_dl = _create_dataloader(val_dataset, training=False)
@@ -209,9 +210,11 @@ def create_train_val_dataloader():
     return train_dl, subtrain_dl, val_dl
+"""
 if __name__ == "__main__":
     create_dataset_hdf5()
     train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
     sample = train_dl.dataset[0]
     print(sample)
+"""

View File

@@ -48,6 +48,7 @@ if not distributed_initialized() and cfg.trainer.backend == "local":
     init_distributed(torch.distributed.init_process_group)
 # A very naive engine implementation using barebones PyTorch
+# to-do: implement lr_sheduling
 class Engine():
     def __init__(self, *args, **kwargs):
         self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
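For readers unfamiliar with what the "very naive engine" amounts to, here is a hedged sketch of the general pattern (not the repo's actual Engine class): wrap the model and optimizer, run plain forward/backward/step, no LR scheduling, which is exactly what the new to-do refers to.

# Sketch of a barebones local engine; gradient clipping and LR scheduling are
# the pieces the config comments mark as "to be implemented".
import torch

class NaiveEngine:
    def __init__(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer):
        self.module = model
        self.optimizer = optimizer

    def step(self, loss: torch.Tensor):
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()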

View File

@@ -6,7 +6,7 @@ from .config import cfg
 from .export import load_models
 from .data import get_symmap, _get_symbols
-class CAPTCHA():
+class Classifier():
     def __init__( self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32" ):
         self.loading = True
         self.device = device
@@ -48,6 +48,6 @@ class CAPTCHA():
         image = self.transform(image).to(self.device)
         answer = self.model( image=[image], sampling_temperature=temperature )
-        answer = answer[0].replace('<s>', "").replace("</s>", "")
+        answer = answer[0].replace('<s>', "").replace("</s>", "") # it would be better to just slice between these, but I can't be assed
         return answer
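On the new comment about slicing between the markers: a sketch of what that would look like, assuming `<s>` and `</s>` each appear once in the decoded answer:

# Sketch: take only the text between <s> and </s> instead of stripping them.
def slice_answer(raw: str) -> str:
    start = raw.find("<s>") + len("<s>")
    end = raw.find("</s>", start)
    return raw[start:end] if end != -1 else raw[start:]

print(slice_answer("<s>a3f9qz</s>"))  # "a3f9qz"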

View File

@@ -16,28 +16,6 @@ from torchvision.models import resnet18
 from ..data import get_symmap
-def _create_mask(l, device):
-    """1 is valid region and 0 is invalid."""
-    seq = torch.arange(max(l), device=device).unsqueeze(0)  # (1 t)
-    stop = torch.tensor(l, device=device).unsqueeze(1)  # (b 1)
-    return (seq < stop).float()  # (b t)
-def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
-    """
-    Args:
-        x_list: [(t d)]
-    Returns:
-        x: (? ? ?)
-        m: (? ? ?), same as x
-    """
-    l = list(map(len, x_list))
-    x = rearrange(pad_sequence(x_list), pattern)
-    m = _create_mask(l, x_list[0].device)
-    m = m.t().unsqueeze(-1)  # (t b 1)
-    m = rearrange(m, pattern)
-    m = m.to(x)
-    return x, m
 class Model(nn.Module):
     def __init__(
         self,
@@ -61,9 +39,20 @@ class Model(nn.Module):
         self.resnet = resnet18(pretrained=False)
         self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len )
-        self.criterion = nn.CTCLoss(zero_infinity=True)
+        self.accuracy_metric = MulticlassAccuracy(
+            n_tokens,
+            #top_k=10,
+            average="micro",
+            multidim_average="global",
+        )
+        self.precision_metric = MulticlassPrecision(
+            n_tokens,
+            #top_k=10,
+            average="micro",
+            multidim_average="global",
+        )
     def forward(
         self,
@@ -77,6 +66,7 @@ class Model(nn.Module):
         x = self.resnet( x_list )
         y = x.view(x.size(0), self.n_len, self.n_tokens)
+        # either of these should do, but my VALL-E forward pass uses this, so might as well keep to it
         # pred = y.argmax(dim=2)
         pred = Categorical(logits=y / sampling_temperature).sample()
@@ -84,87 +74,20 @@ class Model(nn.Module):
         if text is not None:
             y_list = rearrange(pad_sequence(text), "t b -> b t")
             loss = 0
             for i in range(self.n_len):
+                if i >= y_list.shape[1]:
+                    break
                 loss += F.cross_entropy( y[:, i], y_list[:, i] )
             self.loss = dict(
                 nll=loss
             )
+            self.stats = dict(
+                acc = self.accuracy_metric( pred, y_list ),
+                precision = self.precision_metric( pred, y_list ),
+            )
         return answer
-def example_usage():
-    from ..config import cfg
-    cfg.trainer.backend = "local"
-    cfg.trainer.check_for_oom = False
-    from functools import partial
-    from einops import repeat
-    from ..emb.qnt import decode_to_file
-    from ..engines import Engine, Engines
-    from tqdm import tqdm, trange
-    from .ar import AR
-    from .nar import NAR
-    device = "cpu"
-    x8 = partial(repeat, pattern="t -> t l", l=2)
-    symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
-    def tokenize(content, lang_marker="en"):
-        split = content.split(" ")
-        phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
-        return torch.tensor([*map(symmap.get, phones)]).to()
-    kwargs = {
-        'n_tokens': 1024,
-        'd_model': 1024,
-        'n_heads': 16,
-        'n_layers': 12,
-    }
-    models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }
-    engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
-    train = True
-    def sample( name, steps=400 ):
-        AR = None
-        NAR = None
-        engines.eval()
-        for name, engine in engines.items():
-            if name[:2] == "ar":
-                AR = engine
-            elif name[:3] == "nar":
-                NAR = engine
-        resps_list = AR(text_list, proms_list, max_steps=steps, sampling_temperature=1.0)
-        resps_list = [r.unsqueeze(-1) for r in resps_list]
-        codes = NAR( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
-        decode_to_file(resps_list[0], f"./data/ar.{name}.wav", device=device)
-        decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device)
-    if train:
-        sample("init", 15)
-        engines.train()
-        t = trange(60)
-        for i in t:
-            stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list}, device="cpu")
-            t.set_description(f"{stats}")
-    else:
-        for name, engine in engines.items():
-            engine.module.load_state_dict(torch.load(f"./data/{name}.pth"))
-    sample("final")
-if __name__ == "__main__":
-    example_usage()
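To make the shape handling in the forward pass above concrete: the ResNet head emits `n_len * n_tokens` logits, which get reshaped to (batch, n_len, n_tokens), sampled per position, and scored with per-position cross entropy plus the newly added accuracy/precision reporting. The following is a self-contained sketch under assumed sizes (n_len=6, n_tokens=36), with accuracy computed by hand rather than through the metric classes:

# Sketch of fixed-length, per-position classification on a ResNet trunk.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torchvision.models import resnet18

n_len, n_tokens = 6, 36                               # assumed label length and alphabet size
model = resnet18()                                    # no pretrained weights
model.fc = nn.Linear(model.fc.in_features, n_tokens * n_len)

images = torch.randn(4, 3, 80, 300)                   # (batch, rgb, height, width)
targets = torch.randint(0, n_tokens, (4, n_len))

logits = model(images).view(-1, n_len, n_tokens)      # (batch, n_len, n_tokens)
pred = Categorical(logits=logits / 1.0).sample()      # temperature-scaled sampling, as in the diff

loss = sum(F.cross_entropy(logits[:, i], targets[:, i]) for i in range(n_len))
acc = (pred == targets).float().mean()                # rough stand-in for MulticlassAccuracy(average="micro")
print(loss.item(), acc.item())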

View File

@@ -23,11 +23,13 @@ def train_feeder(engine, batch):
     engine( image=batch["image"], text=batch["text"] )
     losses = engine.gather_attribute("loss")
+    stat = engine.gather_attribute("stats")
     loss = torch.stack([*losses.values()]).sum()
     stats = {}
     stats |= {k: v.item() for k, v in losses.items()}
+    stats |= {k: v.item() for k, v in stat.items()}
     return loss, stats
@@ -55,7 +57,6 @@ def run_eval(engines, eval_name, dl):
     for batch in tqdm(dl):
         batch: dict = to_device(batch, cfg.device)
-        # if we're training both models, provide output for both
         res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
         for path, ref, hyp in zip(batch["path"], batch["text"], res):

View File

@@ -1 +0,0 @@
-__version__ = "0.0.1-dev20230804142130"