# todo: clean this mess up

from .config import cfg
from .data import create_train_val_dataloader
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc

import json
import logging
import torch
import traceback

from collections import defaultdict
from PIL import Image
from tqdm import tqdm

_logger = logging.getLogger(__name__)


def train_feeder(engine, batch):
    # Forward pass; the engine is expected to record its loss terms as
    # attributes that gather_attribute() collects below.
    engine(image=batch["image"], text=batch["text"])

    losses = engine.gather_attribute("loss")

    # Sum every gathered loss term into one scalar for backpropagation.
    loss = torch.stack(list(losses.values())).sum()

    # Report each individual loss term as a plain float for logging.
    stats = {k: v.item() for k, v in losses.items()}

    return loss, stats

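
# A minimal sketch of how trainer.train() presumably drives the feeder (the
# real loop lives in .utils.trainer; the engine method names below are
# assumptions, not part of this file):
#
#   for batch in train_dl:
#       loss, stats = train_feeder(engine, to_device(batch, cfg.device))
#       engine.backward(loss)   # hypothetical engine API
#       engine.step()           # hypothetical engine API
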
@torch.inference_mode()
def run_eval(engines, eval_name, dl):
    engines_stats = {
        'eval': eval_name,
    }

    # Evaluate with the first (and typically only) engine.
    model = None
    names = []
    for name, engine in engines.items():
        names.append(name)
        model = engine
        break

    stats = defaultdict(list)

    # note: currently unused; kept as a stub for per-sample processing
    def process(name, batch, resps_list):
        for path, ref, hyp in zip(batch["path"], batch["text"], resps_list):
            continue

    for batch in tqdm(dl):
        batch: dict = to_device(batch, cfg.device)

        # if we're training both models, provide output for both
        res = model(image=batch['image'], text=batch['text'], sampling_temperature=cfg.evaluation.temperature)

        for path, ref, hyp in zip(batch["path"], batch["text"], res):
            hyp = hyp.replace('<s>', "").replace("</s>", "")
            # Save each source image under a path named after its hypothesis,
            # so predictions can be inspected visually per step and model.
            hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / hyp).with_suffix(".png")
            hyp_path.parent.mkdir(parents=True, exist_ok=True)

            image = Image.open(path).convert('RGB')
            image.save(hyp_path)

        losses = model.gather_attribute("loss")
        loss = torch.stack(list(losses.values())).sum().item()

        stats['loss'].append(loss)

    # Average each metric over the evaluation set.
    stats = {k: sum(v) / len(v) for k, v in stats.items()}
    engines_stats.update(flatten_dict({name: stats}))

    iteration = engines.global_step
    engines_stats['it'] = iteration
    engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)

    _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")

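
# run_eval() is invoked from eval_fn() below, e.g. run_eval(engines, "val", val_dl);
# `engines` is assumed to be a dict-like container of named engines that also
# exposes a `global_step` attribute (an assumption inferred from its use above).
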
def main():
    setup_logging(cfg.log_dir)

    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

    def eval_fn(engines):
        try:
            run_eval(engines, "subtrain", subtrain_dl)
            run_eval(engines, "val", val_dl)
        except Exception as e:
            _logger.error("Error occurred while performing eval: %s", e)
            _logger.error(traceback.format_exc())

        # Free any memory the evaluation pass may still be holding on to.
        do_gc()

    trainer.train(
        train_dl=train_dl,
        train_feeder=train_feeder,
        eval_fn=eval_fn,
    )


if __name__ == "__main__":
    main()
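
# Usage sketch (the module path is illustrative; it depends on the package
# name this file ships under):
#
#   python -m <package>.train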