"""Training entry point: feeds batches to the engines and runs periodic evaluation."""

from .config import cfg
from .data import create_train_val_dataloader
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc

import json
import logging
import torch
import traceback

from collections import defaultdict

from PIL import Image
from tqdm import tqdm

_logger = logging.getLogger(__name__)


def train_feeder(engine, batch):
    # forward pass; each module attaches its own `loss` attribute
    engine(image=batch["image"], text=batch["text"])

    # collect the per-module losses and sum them into a single scalar
    losses = engine.gather_attribute("loss")
    loss = torch.stack([*losses.values()]).sum()

    # report each individual loss term alongside the total
    stats = {k: v.item() for k, v in losses.items()}

    return loss, stats


@torch.inference_mode()
def run_eval(engines, eval_name, dl):
    engines_stats = {'eval': eval_name}

    # evaluate with the first engine only
    name, model = next(iter(engines.items()))

    stats = defaultdict(list)

    for batch in tqdm(dl):
        batch: dict = to_device(batch, cfg.device)

        # sample a hypothesis for each item in the batch
        res = model(
            image=batch['image'],
            text=batch['text'],
            sampling_temperature=cfg.evaluation.temperature,
        )

        for path, ref, hyp in zip(batch["path"], batch["text"], res):
            # strip special tokens from the hypothesis (the token literals are elided here)
            hyp = hyp.replace('', "").replace("", "")
            # save the source image under the hypothesis text for manual inspection,
            # keyed by global step / engine name / eval split
            hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / hyp).with_suffix(".png")
            hyp_path.parent.mkdir(parents=True, exist_ok=True)
            image = Image.open(path).convert('RGB')
            image.save(hyp_path)

        losses = model.gather_attribute("loss")
        loss = torch.stack([*losses.values()]).sum().item()
        stats['loss'].append(loss)

    # average each metric over the evaluation set
    stats = {k: sum(v) / len(v) for k, v in stats.items()}
    engines_stats.update(flatten_dict({name: stats}))

    iteration = engines.global_step
    engines_stats['it'] = iteration
    # approximate epochs from optimizer steps, accounting for gradient accumulation
    engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)

    _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")


def main():
    setup_logging(cfg.log_dir)

    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

    def eval_fn(engines):
        # evaluation failures should not kill the training run
        try:
            run_eval(engines, "subtrain", subtrain_dl)
            run_eval(engines, "val", val_dl)
        except Exception as e:
            _logger.error(f"Error occurred while performing eval: {e}")
            _logger.error(traceback.format_exc())

        do_gc()

    trainer.train(
        train_dl=train_dl,
        train_feeder=train_feeder,
        eval_fn=eval_fn,
    )


if __name__ == "__main__":
    main()