resnet-classifier/image_classifier/models/base.py
2023-08-05 03:48:06 +00:00

171 lines
6.1 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
import torch
import torch.nn.functional as F
import traceback
from typing import Literal, overload
from functools import partial
from einops import rearrange
from torch import Tensor, einsum, nn
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
from torchvision.models import resnet18
from ..data import get_symmap
def _create_mask(l, device):
"""1 is valid region and 0 is invalid."""
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
return (seq < stop).float() # (b t)
def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
    """
    Pad variable-length tensors into one batch and build the matching mask.

    Args:
        x_list: [(t d)], tensors whose leading (time) length may differ
        pattern: einops pattern applied to the time-major padded batch and,
            identically, to the (t b 1) mask
    Returns:
        x: padded batch laid out according to `pattern`
        m: mask laid out according to `pattern`, converted to x's dtype and
            device; 1 marks real steps, 0 marks padding, and the feature
            axis has size 1
    """
    lengths = [len(item) for item in x_list]
    padded = pad_sequence(x_list)  # time-major: (t b ...)
    x = rearrange(padded, pattern)

    mask = _create_mask(lengths, x_list[0].device)  # (b t)
    mask = mask.transpose(0, 1).unsqueeze(-1)  # (t b 1)
    m = rearrange(mask, pattern).to(x)
    return x, m
class Model(nn.Module):
    """ResNet-18 backbone whose head predicts a fixed-length token sequence.

    The stock fully connected layer is replaced by one emitting
    `n_len * n_tokens` logits, which `forward` reshapes to
    (b, n_len, n_tokens) and decodes through the inverted symbol map.
    """

    def __init__(
        self,
        n_tokens: int = 0,  # number of token types
        n_len: int = 6,  # how long a sequence can be
        d_model: int = 512,
    ):
        """
        Args:
            n_tokens: number of token types; 0 means "use the number of
                entries in the symbol map returned by get_symmap()".
            n_len: how long a sequence can be; two extra positions are
                reserved for the start/stop tokens.
            d_model: stored as `self.d_model` for callers that read it. The
                head's input width is taken from the backbone itself, so a
                value that differs from the backbone's feature width no
                longer produces a layer that fails in `forward`.
        """
        super().__init__()

        _symmap = get_symmap()
        # Inverted map: stringified token id -> symbol. Id 0 decodes to "".
        self.symmap = {f'{v}': k for k, v in _symmap.items()}
        self.symmap['0'] = ""

        if n_tokens == 0:
            n_tokens = len(_symmap)

        self.n_tokens = n_tokens
        self.n_len = n_len + 2  # start/stop tokens
        self.d_model = d_model

        # Calling resnet18() without arguments gives a randomly initialised
        # network and avoids the deprecated `pretrained=` keyword.
        self.resnet = resnet18()
        self.resnet.fc = nn.Linear(
            self.resnet.fc.in_features,
            self.n_tokens * self.n_len,
        )

        # Not used by forward(), which relies on cross-entropy; kept as an
        # attribute so existing users of `model.criterion` keep working.
        self.criterion = nn.CTCLoss(zero_infinity=True)

    def forward(
        self,
        image,
        text=None,
        sampling_temperature: float = 1.0,
    ):
        """Sample one token string per image and optionally compute the loss.

        Args:
            image: sequence of image tensors; they are stacked along a new
                leading batch dimension before entering the backbone.
            text: optional sequence of 1-D target token tensors, one per
                image. When given, the per-position cross-entropy summed
                over all `n_len` positions is stored in `self.loss["nll"]`.
            sampling_temperature: divisor applied to the logits before
                sampling from the categorical distribution.

        Returns:
            List of decoded strings, one per image.
        """
        batch = torch.stack(image, dim=0)
        features = self.resnet(batch)
        y = features.view(features.size(0), self.n_len, self.n_tokens)

        # Greedy alternative: pred = y.argmax(dim=2)
        pred = Categorical(logits=y / sampling_temperature).sample()  # (b n_len)
        answer = [
            "".join(self.symmap[f'{token.item()}'] for token in row)
            for row in pred
        ]

        if text is not None:
            y_list = rearrange(pad_sequence(text), "t b -> b t")

            # The loop below reads one target column per output position.
            # Targets shorter than n_len used to raise IndexError; extend
            # them with 0, the same value pad_sequence pads with.
            missing = self.n_len - y_list.size(1)
            if missing > 0:
                y_list = F.pad(y_list, (0, missing))

            loss = sum(
                F.cross_entropy(y[:, i], y_list[:, i])
                for i in range(self.n_len)
            )
            self.loss = dict(
                nll=loss
            )

        return answer
def example_usage():
    """Demo routine: build an AR/NAR engine pair, sample, train briefly, sample again.

    Steps, as written in the body:
      1. force the local trainer backend and disable the OOM check,
      2. build `AR` and `NAR` models and wrap each in an `Engine` with AdamW,
      3. when `train` is true, write an "init" sample, then run 60 training
         steps; otherwise load weights from ./data/<name>.pth,
      4. write a "final" sample.

    NOTE(review): `text_list`, `proms_list` and `resps_list` are read below
    but are not defined anywhere in this function; as it stands the first
    `sample(...)` call raises NameError. The inputs have to be restored
    before this routine can run. `x8`, `symmap` and `tokenize` are currently
    unused and look like the scaffolding for building those inputs — confirm.
    """
    from ..config import cfg

    cfg.trainer.backend = "local"
    cfg.trainer.check_for_oom = False

    from functools import partial
    from einops import repeat
    from ..emb.qnt import decode_to_file
    from ..engines import Engine, Engines
    from tqdm import tqdm, trange
    from .ar import AR
    from .nar import NAR

    device = "cpu"
    x8 = partial(repeat, pattern="t -> t l", l=2)
    symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}

    def tokenize(content, lang_marker="en"):
        # Split on single spaces; an empty piece (from consecutive spaces)
        # stands for a literal space symbol.
        split = content.split(" ")
        phones = ["<s>"] + [" " if not p else p for p in split] + ["</s>"]
        return torch.tensor([*map(symmap.get, phones)]).to()

    kwargs = {
        'n_tokens': 1024,
        'd_model': 1024,
        'n_heads': 16,
        'n_layers': 12,
    }

    models = {"ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device)}
    engines = Engines({
        name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
        for name, model in models.items()
    })

    train = True

    def sample(name, steps=400):
        """Run AR then NAR inference and write both results under ./data, tagged with `name`."""
        ar_engine = None
        nar_engine = None

        engines.eval()
        # The loop variable must not be called `name`: doing so overwrote the
        # caller's label, so "init" and "final" outputs both ended up tagged
        # with the last engine's name and overwrote each other.
        for engine_name, engine in engines.items():
            if engine_name.startswith("ar"):
                ar_engine = engine
            elif engine_name.startswith("nar"):
                nar_engine = engine

        resps_list = ar_engine(text_list, proms_list, max_steps=steps, sampling_temperature=1.0)
        resps_list = [r.unsqueeze(-1) for r in resps_list]
        codes = nar_engine(text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2)

        decode_to_file(resps_list[0], f"./data/ar.{name}.wav", device=device)
        decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device)

    if train:
        sample("init", 15)

        engines.train()
        t = trange(60)
        for i in t:
            stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list}, device="cpu")
            t.set_description(f"{stats}")
    else:
        for name, engine in engines.items():
            engine.module.load_state_dict(torch.load(f"./data/{name}.pth"))

    sample("final")
# Script entry point: run the demo only when this file is executed directly,
# never on import.
# NOTE(review): the file uses package-relative imports (e.g.
# `from ..data import get_symmap`), so it presumably has to be launched as a
# module of its package (`python -m ...`) rather than by path — confirm.
if __name__ == "__main__":
    example_usage()