"""Training entry point: wires the dataloaders, the per-step train feeder and
the evaluation hook into ``trainer.train``."""
# todo: clean this mess up
import argparse
import json
import logging
import random
import shutil
import traceback
from collections import defaultdict

import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw
from tqdm import tqdm

from .config import cfg
from .data import create_train_val_dataloader
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
from .utils.distributed import is_global_leader

_logger = logging.getLogger(__name__)


def train_feeder(engine, batch):
    """Run one training forward pass and collect its loss and statistics.

    Args:
        engine: the engine being trained. It is called with the batch's
            ``image`` and ``text`` entries; its ``current_batch_size`` and
            ``tokens_processed`` attributes are updated here.
        batch: mapping with at least ``"image"`` and ``"text"`` entries.

    Returns:
        ``(loss, stats)``: the sum of all gathered ``loss`` tensors, and a flat
        dict of Python floats built from the engine's ``loss`` and ``stats``
        attributes (``stats`` keys win on collision).
    """
    with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
        batch_size = len(batch["text"])
        engine.current_batch_size = batch_size

        engine(
            image=batch["image"],
            text=batch["text"],
        )

        losses = engine.gather_attribute("loss")
        stat = engine.gather_attribute("stats")

        loss = torch.stack([*losses.values()]).sum()

    stats = {}
    stats |= {k: v.item() for k, v in losses.items()}
    stats |= {k: v.item() for k, v in stat.items()}

    # Each entry of batch["text"] is expected to be a tensor; count its first dim.
    engine.tokens_processed += sum(text.shape[0] for text in batch["text"])

    return loss, stats


@torch.inference_mode()
def run_eval(engines, eval_name, dl):
    """Evaluate every engine on up to ``cfg.evaluation.size`` samples of ``dl``.

    For each sample the source image (``batch["path"]``) is re-saved as
    ``<log_dir>/<global_step>/<engine name>/<eval_name>/<engine output>.png``
    so outputs can be inspected on disk. The mean loss of each engine is
    logged under the key ``"<engine name>.<eval_name>"``.

    Args:
        engines: mapping of engine name -> engine; also exposes ``global_step``.
        eval_name: label for this split (used in output paths and metric keys).
        dl: iterable of batches with ``"path"``, ``"image"`` and ``"text"``
            entries.
    """
    # One stats dict per engine, so each engine's loss is averaged and
    # reported under its own name instead of being pooled across engines and
    # labelled with whichever engine happened to be iterated last.
    stats_by_engine = defaultdict(lambda: defaultdict(list))

    def process(name, batch, res, loss):
        """Dump each sample's image under its engine output and record the loss."""
        for path, _ref, hyp in zip(batch["path"], batch["text"], res):
            # NOTE(review): both replace() calls replace "" with "" and are
            # no-ops; presumably they were meant to strip special tokens from
            # the output — confirm and fix the literals.
            hyp = hyp.replace('', "").replace("", "")
            # NOTE(review): hyp is used verbatim as a file name; an output
            # containing "/" (or starting with one) changes the target directory.
            hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / hyp).with_suffix(".png")
            hyp_path.parent.mkdir(parents=True, exist_ok=True)

            # Context manager closes the source file handle promptly.
            with Image.open(path) as source:
                source.convert('RGB').save(hyp_path)

        stats_by_engine[name]['loss'].append(loss)

    # Create the iterator once: calling iter(dl) on every pass restarts the
    # loader (re-spawning workers) and, for an unshuffled loader, evaluates the
    # same first batch over and over.
    batches = iter(dl)
    processed = 0
    while processed < cfg.evaluation.size:
        try:
            batch = next(batches)
        except StopIteration:
            # The loader ran out before reaching cfg.evaluation.size samples.
            break
        batch = to_device(batch, cfg.device)

        # limit to eval batch size in the event we somehow have a weird dataloader
        for key in batch.keys():
            batch[key] = batch[key][:cfg.evaluation.batch_size]

        batch_size = len(batch["text"])
        if batch_size == 0:
            # An empty batch would never advance `processed`; stop instead of spinning.
            break
        processed += batch_size

        for name in engines:
            engine = engines[name]

            res = engine(
                image=batch['image'],
                text=batch['text'],
                sampling_temperature=cfg.evaluation.temperature,
            )

            losses = engine.gather_attribute("loss")
            loss = torch.stack([*losses.values()]).sum().item()

            process(name, batch, res, loss)

    if not stats_by_engine:
        _logger.warning(f"No samples were evaluated for '{eval_name}'; skipping metrics.")
        return

    engines_stats = {
        f'{name}.{eval_name}': {key: sum(values) / len(values) for key, values in stats.items()}
        for name, stats in stats_by_engine.items()
    }
    engines_stats["it"] = engines.global_step

    _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")


def train():
    """Parse CLI flags, prepare logging/dataloaders, then train (or only evaluate with ``--eval``)."""
    parser = argparse.ArgumentParser("ResNet Image Classifier")
    parser.add_argument("--eval", action="store_true", default=None)
    # Unknown arguments are tolerated (and ignored here).
    args, _unknown = parser.parse_known_args()

    # create log folder
    setup_logging(cfg.log_dir)

    # copy config yaml to backup (only once when running distributed)
    if cfg.yaml_path is not None and is_global_leader():
        shutil.copy(cfg.yaml_path, cfg.log_dir / "config.yaml")

    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

    def eval_fn(engines):
        """Evaluate on the subtrain and val splits, then put engines back in train mode."""
        do_gc()
        engines.eval()
        # wrapped in a try block because it's sometimes prone to breaking;
        # evaluation failures are logged but never abort training
        try:
            run_eval(engines, "subtrain", subtrain_dl)
            run_eval(engines, "val", val_dl)
        except Exception as e:
            _logger.warning(f"Error occurred while performing eval: {str(e)}")
            _logger.warning(traceback.format_exc())

        engines.train()
        do_gc()

    if args.eval:
        return eval_fn(engines=trainer.load_engines())

    # NOTE: optional web UI hook, currently disabled:
    #     if cfg.trainer.load_webui:
    #         from .webui import start
    #         start(lock=False)

    trainer.train(
        train_dl=train_dl,
        train_feeder=train_feeder,
        eval_fn=eval_fn,
    )


if __name__ == "__main__":
    # to-do: for DDP, spawn multiprocess instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"`
    train()