resnet-classifier/image_classifier/train.py

137 lines
3.8 KiB
Python
Executable File

# todo: clean this mess up
from .config import cfg
from .data import create_train_val_dataloader
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
from .utils.distributed import is_global_leader
import json
import logging
import random
import torch
import torch.nn.functional as F
import traceback
import shutil
from collections import defaultdict
from tqdm import tqdm
import argparse
from PIL import Image, ImageDraw
_logger = logging.getLogger(__name__)
def train_feeder(engine, batch):
    """Run one training forward pass and hand the loss back to the trainer.

    Args:
        engine: the engine being trained. It is called with the batch and then queried with
            ``gather_attribute("loss")`` and ``gather_attribute("stats")``; its
            ``current_batch_size`` and ``tokens_processed`` attributes are updated.
        batch: dict with at least ``"image"`` and ``"text"`` entries.

    Returns:
        ``(loss, stats)``: the sum of every gathered loss term (a tensor), and a flat dict of
        Python scalars holding each loss term followed by each gathered stat.
    """
    with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
        texts = batch["text"]
        engine.current_batch_size = len(texts)
        engine(image=batch["image"], text=texts)

        loss_terms = engine.gather_attribute("loss")
        extra_stats = engine.gather_attribute("stats")

        # keep the reduction inside autocast: stack/sum are autocast-sensitive ops
        total_loss = torch.stack(list(loss_terms.values())).sum()

        stats = {key: value.item() for key, value in loss_terms.items()}
        stats.update((key, value.item()) for key, value in extra_stats.items())

        # assumes each entry of batch["text"] has a leading token dimension — TODO confirm
        engine.tokens_processed += sum(t.shape[0] for t in texts)

    return total_loss, stats
def _hyp_filename(hyp):
    """Turn a decoded hypothesis string into a safe ``.png`` file name.

    The suffix is appended instead of using ``Path.with_suffix``, which would truncate a label
    such as ``"st. bernard"`` to ``"st.png"``. Path separators are replaced so that a
    hypothesis cannot create sub-directories or escape the output folder. An empty
    hypothesis maps to ``"_.png"`` rather than a hidden ``".png"`` file.
    """
    name = hyp.replace("/", "_").replace("\\", "_")
    return f"{name or '_'}.png"


@torch.inference_mode()
def run_eval(engines, eval_name, dl):
    """Evaluate every engine on up to ``cfg.evaluation.size`` samples of ``dl`` and log metrics.

    For every sample, the source image (``batch["path"]``) is re-saved as an RGB PNG under
    ``cfg.log_dir / <global_step> / <engine name> / <eval_name> / <hypothesis>.png``. The
    engine's summed loss is recorded once per sample and averaged per engine.

    Args:
        engines: mapping of engine name -> engine; must also expose ``global_step``.
        eval_name: label of the split being evaluated (used in output paths and log keys).
        dl: iterable of batch dicts containing at least "image", "text" and "path".
    """
    # one bucket per engine so that metrics from different engines are never blended
    stats = {name: defaultdict(list) for name in engines}

    def process(name, batch, res, loss):
        # Save each evaluated image under its decoded hypothesis and record the batch loss.
        out_dir = cfg.log_dir / str(engines.global_step) / name / eval_name
        for path, _ref, hyp in zip(batch["path"], batch["text"], res):
            hyp = hyp.replace("<s>", "").replace("</s>", "")
            hyp_path = out_dir / _hyp_filename(hyp)
            hyp_path.parent.mkdir(parents=True, exist_ok=True)
            # NOTE(review): samples sharing a hypothesis overwrite each other's image — confirm intended
            with Image.open(path) as image:  # close the file handle instead of leaking it
                image.convert("RGB").save(hyp_path)
            stats[name]["loss"].append(loss)

    processed = 0
    # a single iterator for the whole evaluation: calling iter(dl) every step restarts the
    # loader (re-reading the first batch when unshuffled, respawning workers each time)
    batches = iter(dl)
    while processed < cfg.evaluation.size:
        raw_batch = next(batches, None)
        if raw_batch is None:  # loader exhausted before reaching the requested size
            break
        batch = to_device(raw_batch, cfg.device)
        # limit to eval batch size in the event we somehow have a weird dataloader
        batch = {key: value[:cfg.evaluation.batch_size] for key, value in batch.items()}
        processed += len(batch["text"])

        for name in engines:
            engine = engines[name]
            res = engine(image=batch["image"], text=batch["text"], sampling_temperature=cfg.evaluation.temperature)
            losses = engine.gather_attribute("loss")
            loss = torch.stack(list(losses.values())).sum().item()
            process(name, batch, res, loss)

    # average each metric; empty buckets (nothing processed) are skipped instead of dividing by zero
    engines_stats = {
        f"{name}.{eval_name}": {key: sum(values) / len(values) for key, values in engine_stats.items() if values}
        for name, engine_stats in stats.items()
    }
    engines_stats["it"] = engines.global_step

    _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
def train():
    """Entry point: parse CLI flags, set up logging and data, then train (or only evaluate).

    With ``--eval``, engines are loaded via ``trainer.load_engines()``, evaluated once and the
    function returns without training. Unrecognised CLI arguments are tolerated via
    ``parse_known_args`` (presumably they are consumed by the config loader — TODO confirm).
    """
    parser = argparse.ArgumentParser("ResNet Image Classifier")
    parser.add_argument("--eval", action="store_true", default=None)
    args, _unknown = parser.parse_known_args()

    # create log folder
    setup_logging(cfg.log_dir)

    # keep a backup of the config yaml next to the logs; only the global leader writes it
    if cfg.yaml_path is not None and is_global_leader():
        shutil.copy(cfg.yaml_path, cfg.log_dir / "config.yaml")

    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

    def eval_fn(engines):
        """Evaluate ``engines`` on the subtrain and val splits, then switch back to train mode."""
        do_gc()
        engines.eval()
        # Eval is sometimes prone to breaking. Guard each split separately, so that a failure
        # on one split neither skips the remaining splits nor aborts training.
        for split_name, split_dl in (("subtrain", subtrain_dl), ("val", val_dl)):
            try:
                run_eval(engines, split_name, split_dl)
            except Exception as e:
                _logger.warning(f"Error occurred while performing eval: {str(e)}")
                _logger.warning(traceback.format_exc())
        engines.train()
        do_gc()

    if args.eval:
        return eval_fn(engines=trainer.load_engines())

    # NOTE: launching the web UI (`from .webui import start; start(lock=False)` when
    # `cfg.trainer.load_webui` is set) is currently disabled.

    trainer.train(
        train_dl=train_dl,
        train_feeder=train_feeder,
        eval_fn=eval_fn,
    )
if __name__ == "__main__":
    # to-do: for DDP, spawn multiprocess instead of requiring e.g.
    # `torchrun --nnodes=1 --nproc-per-node=4 -m image_classifier.train yaml="./data/config.yaml"`
    train()