tweaks, fixes, cleanup, added reporting accuracy/precision from the VALL-E trainer (which indirectly revealed a grody bug in the VALL-E trainer), some other cr*p
This commit is contained in:
parent
77a9625e93
commit
93987ea5d6
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -5,6 +5,6 @@ __pycache__
|
||||||
/.cache
|
/.cache
|
||||||
/config
|
/config
|
||||||
/*.egg-info
|
/*.egg-info
|
||||||
/vall_e/version.py
|
/image_classifier/version.py
|
||||||
/build
|
/build
|
||||||
/.cache
|
/.cache
|
10
README.md
10
README.md
|
@ -4,7 +4,7 @@ This is a simple ResNet based image classifier for """specific images""", using
|
||||||
|
|
||||||
## Premise
|
## Premise
|
||||||
|
|
||||||
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
|
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a minor problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
|
||||||
|
|
||||||
This is by no ways state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
|
This is by no ways state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
|
||||||
|
|
||||||
|
@ -22,13 +22,17 @@ This is by no ways state of the art, as it just leverages an existing ResNet arc
|
||||||
|
|
||||||
## Inferencing
|
## Inferencing
|
||||||
|
|
||||||
Simply invoke the inferencer with the following command: `python3 -m image_classifier "./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
|
Simply invoke the inferencer with the following command: `python3 -m image_classifier --path="./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
|
||||||
|
|
||||||
|
### Continuous Usage
|
||||||
|
|
||||||
|
If you're looking to continuously classifier trained images, use `python3 -m image_classifier --listen --port=7860 yaml="./data/config.yaml" --temp=1.0` instead to enable a light webserver using `simple_http_server`. Send a `GET` request to `http://127.0.0.1:7860/?b64={base64 encoded image string}` and a JSON response will be returned with the classified label.
|
||||||
|
|
||||||
## Known Issues
|
## Known Issues
|
||||||
|
|
||||||
* Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
|
* Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
|
||||||
* The evaluation / validation routine doesn't quite work.
|
|
||||||
* Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
|
* Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
|
||||||
|
* Web server doesn't emit `content-type: application/json`, nor accepts JSON `POST`s at the moment.
|
||||||
|
|
||||||
## Strawmen
|
## Strawmen
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
dataset:
|
dataset:
|
||||||
training: [
|
training: [
|
||||||
"./data/captchas/"
|
"./data/images/"
|
||||||
]
|
]
|
||||||
|
|
||||||
validation: []
|
validation: []
|
||||||
|
@ -12,17 +12,17 @@ dataset:
|
||||||
|
|
||||||
models:
|
models:
|
||||||
_models:
|
_models:
|
||||||
- name: "captcha"
|
- name: "classifier"
|
||||||
tokens: 0
|
tokens: 0
|
||||||
len: 6
|
len: 6
|
||||||
|
|
||||||
hyperparameters:
|
hyperparameters:
|
||||||
batch_size: 256
|
batch_size: 256
|
||||||
gradient_accumulation_steps: 5
|
gradient_accumulation_steps: 64
|
||||||
gradient_clipping: 100
|
gradient_clipping: 100
|
||||||
|
|
||||||
optimizer: Adamw
|
optimizer: Adamw
|
||||||
learning_rate: 5.0e-5
|
learning_rate: 1.0e-3
|
||||||
|
|
||||||
scheduler_type: ""
|
scheduler_type: ""
|
||||||
#scheduler_type: OneCycle
|
#scheduler_type: OneCycle
|
||||||
|
|
|
@ -3,13 +3,13 @@ import base64
|
||||||
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from .inference import CAPTCHA
|
from .inference import Classifier
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from simple_http_server import route, server
|
from simple_http_server import route, server
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
|
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||||
parser.add_argument("--listen", action='store_true')
|
parser.add_argument("--listen", action='store_true')
|
||||||
parser.add_argument("--port", type=int, default=9090)
|
parser.add_argument("--port", type=int, default=9090)
|
||||||
parser.add_argument("--yaml", type=Path, default=None)
|
parser.add_argument("--yaml", type=Path, default=None)
|
||||||
|
@ -18,15 +18,15 @@ def main():
|
||||||
parser.add_argument("--device", default="cuda")
|
parser.add_argument("--device", default="cuda")
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
|
|
||||||
captcha = CAPTCHA( config=args.yaml, ckpt=args.ckpt, device=args.device )
|
classifier = Classifier( config=args.yaml, ckpt=args.ckpt, device=args.device )
|
||||||
if args.listen:
|
if args.listen:
|
||||||
@route("/")
|
@route("/")
|
||||||
def inference( b64, temperature=1.0 ):
|
def inference( b64, temperature=1.0 ):
|
||||||
image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
|
image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
|
||||||
return { "answer": captcha.inference( image=image, temperature=args.temp ) }
|
return { "answer": classifier.inference( image=image, temperature=args.temp ) }
|
||||||
server.start(port=args.port)
|
server.start(port=args.port)
|
||||||
else:
|
else:
|
||||||
parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
|
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||||
parser.add_argument("--path", type=Path)
|
parser.add_argument("--path", type=Path)
|
||||||
parser.add_argument("--base64", type=str)
|
parser.add_argument("--base64", type=str)
|
||||||
parser.add_argument("--temp", type=float, default=1.0)
|
parser.add_argument("--temp", type=float, default=1.0)
|
||||||
|
@ -39,7 +39,7 @@ def main():
|
||||||
else:
|
else:
|
||||||
raise "Specify a --path or --base64."
|
raise "Specify a --path or --base64."
|
||||||
|
|
||||||
answer = captcha.inference( image=image, temperature=args.temp )
|
answer = classifier.inference( image=image, temperature=args.temp )
|
||||||
print("Answer:", answer)
|
print("Answer:", answer)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -111,6 +111,7 @@ class Dataset:
|
||||||
|
|
||||||
temp: list[Path] = field(default_factory=lambda: [])
|
temp: list[Path] = field(default_factory=lambda: [])
|
||||||
|
|
||||||
|
# de-implemented, because the data isn't that large to facilitate HDF5
|
||||||
hdf5_name: str = "data.h5"
|
hdf5_name: str = "data.h5"
|
||||||
use_hdf5: bool = False
|
use_hdf5: bool = False
|
||||||
|
|
||||||
|
@ -149,12 +150,12 @@ class Models:
|
||||||
class Hyperparameters:
|
class Hyperparameters:
|
||||||
batch_size: int = 8
|
batch_size: int = 8
|
||||||
gradient_accumulation_steps: int = 32
|
gradient_accumulation_steps: int = 32
|
||||||
gradient_clipping: int = 100
|
gradient_clipping: int = 100 # to be implemented in the local backend
|
||||||
|
|
||||||
optimizer: str = "Adamw"
|
optimizer: str = "Adamw"
|
||||||
learning_rate: float = 3.25e-4
|
learning_rate: float = 3.25e-4
|
||||||
|
|
||||||
scheduler_type: str = ""
|
scheduler_type: str = "" # to be implemented in the local backend
|
||||||
scheduler_params: dict = field(default_factory=lambda: {})
|
scheduler_params: dict = field(default_factory=lambda: {})
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
|
@ -317,7 +318,7 @@ class Trainer:
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class Inference:
|
class Inference:
|
||||||
use_vocos: bool = True
|
use_vocos: bool = True # artifact from the VALL-E trainer
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class BitsAndBytes:
|
class BitsAndBytes:
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# todo: clean this mess up
|
# todo: clean this mess up
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import h5py
|
# import h5py
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -60,7 +60,7 @@ class Dataset(_Dataset):
|
||||||
self.training = training
|
self.training = training
|
||||||
|
|
||||||
self.transform = transforms.Compose([
|
self.transform = transforms.Compose([
|
||||||
transforms.Resize((self.height, self.width)),
|
#transforms.Resize((self.height, self.width)), # for some reason, running the validation dataset breaks when this is set. all images *should* be normalized anyhow
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||||
])
|
])
|
||||||
|
@ -73,13 +73,14 @@ class Dataset(_Dataset):
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
path = self.paths[index]
|
path = self.paths[index]
|
||||||
|
|
||||||
|
# stupid try/except when the original VALL-E training framework was able to insert foreign symbols into the symmap, but that functionality isn't really necessary here
|
||||||
try:
|
try:
|
||||||
text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
|
text = torch.tensor([*map(self.symmap.get, _get_symbols(path.stem))]).to(torch.uint8)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Invalid symbol:", _get_symbols(path.stem), [*map(self.symmap.get, _get_symbols(path.stem))], path.stem)
|
print("Invalid symbol:", _get_symbols(path.stem), [*map(self.symmap.get, _get_symbols(path.stem))], path.stem)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype)
|
image = self.transform(Image.open(path).convert('RGB')).to(cfg.trainer.dtype) # resnet has to be RGB
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
index=index,
|
index=index,
|
||||||
|
@ -192,7 +193,7 @@ def create_train_val_dataloader():
|
||||||
|
|
||||||
subtrain_dataset = copy.deepcopy(train_dataset)
|
subtrain_dataset = copy.deepcopy(train_dataset)
|
||||||
subtrain_dataset.head_(cfg.evaluation.size)
|
subtrain_dataset.head_(cfg.evaluation.size)
|
||||||
#subtrain_dataset.training_(False)
|
subtrain_dataset.training_(False)
|
||||||
|
|
||||||
train_dl = _create_dataloader(train_dataset, training=True)
|
train_dl = _create_dataloader(train_dataset, training=True)
|
||||||
val_dl = _create_dataloader(val_dataset, training=False)
|
val_dl = _create_dataloader(val_dataset, training=False)
|
||||||
|
@ -209,9 +210,11 @@ def create_train_val_dataloader():
|
||||||
|
|
||||||
return train_dl, subtrain_dl, val_dl
|
return train_dl, subtrain_dl, val_dl
|
||||||
|
|
||||||
|
"""
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
create_dataset_hdf5()
|
create_dataset_hdf5()
|
||||||
|
|
||||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||||
sample = train_dl.dataset[0]
|
sample = train_dl.dataset[0]
|
||||||
print(sample)
|
print(sample)
|
||||||
|
"""
|
||||||
|
|
|
@ -48,6 +48,7 @@ if not distributed_initialized() and cfg.trainer.backend == "local":
|
||||||
init_distributed(torch.distributed.init_process_group)
|
init_distributed(torch.distributed.init_process_group)
|
||||||
|
|
||||||
# A very naive engine implementation using barebones PyTorch
|
# A very naive engine implementation using barebones PyTorch
|
||||||
|
# to-do: implement lr_sheduling
|
||||||
class Engine():
|
class Engine():
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
|
self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
|
||||||
|
|
|
@ -6,7 +6,7 @@ from .config import cfg
|
||||||
from .export import load_models
|
from .export import load_models
|
||||||
from .data import get_symmap, _get_symbols
|
from .data import get_symmap, _get_symbols
|
||||||
|
|
||||||
class CAPTCHA():
|
class Classifier():
|
||||||
def __init__( self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32" ):
|
def __init__( self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32" ):
|
||||||
self.loading = True
|
self.loading = True
|
||||||
self.device = device
|
self.device = device
|
||||||
|
@ -48,6 +48,6 @@ class CAPTCHA():
|
||||||
image = self.transform(image).to(self.device)
|
image = self.transform(image).to(self.device)
|
||||||
|
|
||||||
answer = self.model( image=[image], sampling_temperature=temperature )
|
answer = self.model( image=[image], sampling_temperature=temperature )
|
||||||
answer = answer[0].replace('<s>', "").replace("</s>", "")
|
answer = answer[0].replace('<s>', "").replace("</s>", "") # it would be better to just slice between these, but I can't be assed
|
||||||
|
|
||||||
return answer
|
return answer
|
|
@ -16,28 +16,6 @@ from torchvision.models import resnet18
|
||||||
|
|
||||||
from ..data import get_symmap
|
from ..data import get_symmap
|
||||||
|
|
||||||
def _create_mask(l, device):
|
|
||||||
"""1 is valid region and 0 is invalid."""
|
|
||||||
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
|
|
||||||
stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
|
|
||||||
return (seq < stop).float() # (b t)
|
|
||||||
|
|
||||||
def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x_list: [(t d)]
|
|
||||||
Returns:
|
|
||||||
x: (? ? ?)
|
|
||||||
m: (? ? ?), same as x
|
|
||||||
"""
|
|
||||||
l = list(map(len, x_list))
|
|
||||||
x = rearrange(pad_sequence(x_list), pattern)
|
|
||||||
m = _create_mask(l, x_list[0].device)
|
|
||||||
m = m.t().unsqueeze(-1) # (t b 1)
|
|
||||||
m = rearrange(m, pattern)
|
|
||||||
m = m.to(x)
|
|
||||||
return x, m
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -62,8 +40,19 @@ class Model(nn.Module):
|
||||||
self.resnet = resnet18(pretrained=False)
|
self.resnet = resnet18(pretrained=False)
|
||||||
self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len )
|
self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len )
|
||||||
|
|
||||||
self.criterion = nn.CTCLoss(zero_infinity=True)
|
self.accuracy_metric = MulticlassAccuracy(
|
||||||
|
n_tokens,
|
||||||
|
#top_k=10,
|
||||||
|
average="micro",
|
||||||
|
multidim_average="global",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.precision_metric = MulticlassPrecision(
|
||||||
|
n_tokens,
|
||||||
|
#top_k=10,
|
||||||
|
average="micro",
|
||||||
|
multidim_average="global",
|
||||||
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|
||||||
|
@ -77,6 +66,7 @@ class Model(nn.Module):
|
||||||
x = self.resnet( x_list )
|
x = self.resnet( x_list )
|
||||||
y = x.view(x.size(0), self.n_len, self.n_tokens)
|
y = x.view(x.size(0), self.n_len, self.n_tokens)
|
||||||
|
|
||||||
|
# either of these should do, but my VALL-E forward pass uses this, so might as well keep to it
|
||||||
# pred = y.argmax(dim=2)
|
# pred = y.argmax(dim=2)
|
||||||
pred = Categorical(logits=y / sampling_temperature).sample()
|
pred = Categorical(logits=y / sampling_temperature).sample()
|
||||||
|
|
||||||
|
@ -87,84 +77,17 @@ class Model(nn.Module):
|
||||||
|
|
||||||
loss = 0
|
loss = 0
|
||||||
for i in range(self.n_len):
|
for i in range(self.n_len):
|
||||||
|
if i >= y_list.shape[1]:
|
||||||
|
break
|
||||||
loss += F.cross_entropy( y[:, i], y_list[:, i] )
|
loss += F.cross_entropy( y[:, i], y_list[:, i] )
|
||||||
|
|
||||||
self.loss = dict(
|
self.loss = dict(
|
||||||
nll=loss
|
nll=loss
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.stats = dict(
|
||||||
|
acc = self.accuracy_metric( pred, y_list ),
|
||||||
|
precision = self.precision_metric( pred, y_list ),
|
||||||
|
)
|
||||||
|
|
||||||
return answer
|
return answer
|
||||||
|
|
||||||
def example_usage():
|
|
||||||
from ..config import cfg
|
|
||||||
cfg.trainer.backend = "local"
|
|
||||||
cfg.trainer.check_for_oom = False
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from einops import repeat
|
|
||||||
|
|
||||||
from ..emb.qnt import decode_to_file
|
|
||||||
from ..engines import Engine, Engines
|
|
||||||
from tqdm import tqdm, trange
|
|
||||||
|
|
||||||
from .ar import AR
|
|
||||||
from .nar import NAR
|
|
||||||
|
|
||||||
device = "cpu"
|
|
||||||
x8 = partial(repeat, pattern="t -> t l", l=2)
|
|
||||||
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
|
|
||||||
def tokenize(content, lang_marker="en"):
|
|
||||||
split = content.split(" ")
|
|
||||||
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
|
|
||||||
return torch.tensor([*map(symmap.get, phones)]).to()
|
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
'n_tokens': 1024,
|
|
||||||
'd_model': 1024,
|
|
||||||
'n_heads': 16,
|
|
||||||
'n_layers': 12,
|
|
||||||
}
|
|
||||||
models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }
|
|
||||||
engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
|
|
||||||
|
|
||||||
train = True
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def sample( name, steps=400 ):
|
|
||||||
AR = None
|
|
||||||
NAR = None
|
|
||||||
|
|
||||||
engines.eval()
|
|
||||||
for name, engine in engines.items():
|
|
||||||
if name[:2] == "ar":
|
|
||||||
AR = engine
|
|
||||||
elif name[:3] == "nar":
|
|
||||||
NAR = engine
|
|
||||||
|
|
||||||
resps_list = AR(text_list, proms_list, max_steps=steps, sampling_temperature=1.0)
|
|
||||||
resps_list = [r.unsqueeze(-1) for r in resps_list]
|
|
||||||
codes = NAR( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
|
|
||||||
|
|
||||||
decode_to_file(resps_list[0], f"./data/ar.{name}.wav", device=device)
|
|
||||||
decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device)
|
|
||||||
|
|
||||||
if train:
|
|
||||||
sample("init", 15)
|
|
||||||
|
|
||||||
engines.train()
|
|
||||||
t = trange(60)
|
|
||||||
for i in t:
|
|
||||||
stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list}, device="cpu")
|
|
||||||
t.set_description(f"{stats}")
|
|
||||||
else:
|
|
||||||
for name, engine in engines.items():
|
|
||||||
engine.module.load_state_dict(torch.load(f"./data/{name}.pth"))
|
|
||||||
|
|
||||||
sample("final")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
example_usage()
|
|
||||||
|
|
|
@ -23,11 +23,13 @@ def train_feeder(engine, batch):
|
||||||
engine( image=batch["image"], text=batch["text"] )
|
engine( image=batch["image"], text=batch["text"] )
|
||||||
|
|
||||||
losses = engine.gather_attribute("loss")
|
losses = engine.gather_attribute("loss")
|
||||||
|
stat = engine.gather_attribute("stats")
|
||||||
|
|
||||||
loss = torch.stack([*losses.values()]).sum()
|
loss = torch.stack([*losses.values()]).sum()
|
||||||
|
|
||||||
stats = {}
|
stats = {}
|
||||||
stats |= {k: v.item() for k, v in losses.items()}
|
stats |= {k: v.item() for k, v in losses.items()}
|
||||||
|
stats |= {k: v.item() for k, v in stat.items()}
|
||||||
|
|
||||||
return loss, stats
|
return loss, stats
|
||||||
|
|
||||||
|
@ -55,7 +57,6 @@ def run_eval(engines, eval_name, dl):
|
||||||
for batch in tqdm(dl):
|
for batch in tqdm(dl):
|
||||||
batch: dict = to_device(batch, cfg.device)
|
batch: dict = to_device(batch, cfg.device)
|
||||||
|
|
||||||
# if we're training both models, provide output for both
|
|
||||||
res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
|
res = model( image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature )
|
||||||
|
|
||||||
for path, ref, hyp in zip(batch["path"], batch["text"], res):
|
for path, ref, hyp in zip(batch["path"], batch["text"], res):
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
__version__ = "0.0.1-dev20230804142130"
|
|
Loading…
Reference in New Issue
Block a user