107 lines
2.6 KiB
Python
107 lines
2.6 KiB
Python
|
#!/usr/bin/env python3
|
||
|
|
||
|
import argparse
|
||
|
import json
|
||
|
import re
|
||
|
from pathlib import Path
|
||
|
|
||
|
import matplotlib.pyplot as plt
|
||
|
import pandas as pd
|
||
|
|
||
|
|
||
|
def plot(paths, args):
|
||
|
dfs = []
|
||
|
|
||
|
for path in paths:
|
||
|
with open(path, "r") as f:
|
||
|
text = f.read()
|
||
|
|
||
|
rows = []
|
||
|
|
||
|
pattern = r"(\{.+?\})"
|
||
|
|
||
|
for row in re.findall(pattern, text, re.DOTALL):
|
||
|
try:
|
||
|
row = json.loads(row)
|
||
|
except Exception as e:
|
||
|
continue
|
||
|
|
||
|
if "global_step" in row:
|
||
|
rows.append(row)
|
||
|
|
||
|
df = pd.DataFrame(rows)
|
||
|
|
||
|
if "name" in df:
|
||
|
df["name"] = df["name"].fillna("train")
|
||
|
else:
|
||
|
df["name"] = "train"
|
||
|
|
||
|
df["group"] = str(path.parents[args.group_level])
|
||
|
df["group"] = df["group"] + "/" + df["name"]
|
||
|
|
||
|
dfs.append(df)
|
||
|
|
||
|
df = pd.concat(dfs)
|
||
|
|
||
|
if args.max_y is not None:
|
||
|
df = df[df["global_step"] < args.max_x]
|
||
|
|
||
|
for gtag, gdf in sorted(
|
||
|
df.groupby("group"),
|
||
|
key=lambda p: (p[0].split("/")[-1], p[0]),
|
||
|
):
|
||
|
for y in args.ys:
|
||
|
gdf = gdf.sort_values("global_step")
|
||
|
|
||
|
if gdf[y].isna().all():
|
||
|
continue
|
||
|
|
||
|
if args.max_y is not None:
|
||
|
gdf = gdf[gdf[y] < args.max_y]
|
||
|
|
||
|
gdf[y] = gdf[y].ewm(10).mean()
|
||
|
|
||
|
gdf.plot(
|
||
|
x="global_step",
|
||
|
y=y,
|
||
|
label=f"{gtag}/{y}",
|
||
|
ax=plt.gca(),
|
||
|
marker="x" if len(gdf) < 100 else None,
|
||
|
alpha=0.7,
|
||
|
)
|
||
|
|
||
|
plt.gca().legend(
|
||
|
loc="center left",
|
||
|
fancybox=True,
|
||
|
shadow=True,
|
||
|
bbox_to_anchor=(1.04, 0.5),
|
||
|
)
|
||
|
|
||
|
|
||
|
def main():
|
||
|
parser = argparse.ArgumentParser()
|
||
|
parser.add_argument("ys", nargs="+")
|
||
|
parser.add_argument("--log-dir", default="logs", type=Path)
|
||
|
parser.add_argument("--out-dir", default="logs", type=Path)
|
||
|
parser.add_argument("--filename", default="log.txt")
|
||
|
parser.add_argument("--max-x", type=float, default=float("inf"))
|
||
|
parser.add_argument("--max-y", type=float, default=float("inf"))
|
||
|
parser.add_argument("--group-level", default=1)
|
||
|
parser.add_argument("--filter", default=None)
|
||
|
args = parser.parse_args()
|
||
|
|
||
|
paths = args.log_dir.rglob(f"**/{args.filename}")
|
||
|
|
||
|
if args.filter:
|
||
|
paths = filter(lambda p: re.match(".*" + args.filter + ".*", str(p)), paths)
|
||
|
|
||
|
plot(paths, args)
|
||
|
|
||
|
name = "-".join(args.ys)
|
||
|
out_path = (args.out_dir / name).with_suffix(".png")
|
||
|
plt.savefig(out_path, bbox_inches="tight")
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
main()
|