From 21e5d250cc33ef187eb11f8b55c83f7389de8b7f Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 2 Sep 2023 13:31:04 -0500 Subject: [PATCH] fixed up plot script that I forgot about --- scripts/plot.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/scripts/plot.py b/scripts/plot.py index e46ffc2..c47fadc 100755 --- a/scripts/plot.py +++ b/scripts/plot.py @@ -18,7 +18,7 @@ def plot(paths, args): rows = [] - pattern = r"(\{.+?\})" + pattern = r"(\{.+?\})\.\n" for row in re.findall(pattern, text, re.DOTALL): try: @@ -26,7 +26,7 @@ def plot(paths, args): except Exception as e: continue - if "global_step" in row: + if f"{args.model_name}.engine_step" in row: rows.append(row) df = pd.DataFrame(rows) @@ -44,14 +44,14 @@ def plot(paths, args): df = pd.concat(dfs) if args.max_y is not None: - df = df[df["global_step"] < args.max_x] + df = df[df[f"{args.model_name}.engine_step"] < args.max_x] for gtag, gdf in sorted( df.groupby("group"), key=lambda p: (p[0].split("/")[-1], p[0]), ): for y in args.ys: - gdf = gdf.sort_values("global_step") + gdf = gdf.sort_values(f"{args.model_name}.engine_step") if gdf[y].isna().all(): continue @@ -62,9 +62,9 @@ def plot(paths, args): gdf[y] = gdf[y].ewm(10).mean() gdf.plot( - x="global_step", + x=f"{args.model_name}.engine_step", y=y, - label=f"{gtag}/{y}", + label=f"{y}", ax=plt.gca(), marker="x" if len(gdf) < 100 else None, alpha=0.7, @@ -87,6 +87,7 @@ def main(): parser.add_argument("--max-x", type=float, default=float("inf")) parser.add_argument("--max-y", type=float, default=float("inf")) parser.add_argument("--group-level", default=1) + parser.add_argument("--model-name", type=str, default="ar") parser.add_argument("--filter", default=None) args = parser.parse_args()