From f7e942ec99acf83c67440117164ce24b3ecf3397 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 2 Sep 2023 13:59:43 -0500 Subject: [PATCH] modified plotting script to be more agnostic to X --- README.md | 4 ++++ scripts/plot.py | 42 ++++++++++++++++++++++++------------------ 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 8ed2ed6..878c84f 100755 --- a/README.md +++ b/README.md @@ -84,6 +84,10 @@ Two dataset formats are supported: - this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths) - be sure to also define `use_hdf5` in your config YAML. +### Plotting Metrics + +Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 ./scripts/plot.py --log-dir ./training/valle/logs/1693675364 --out-dir ./data/ --xs=ar.engine_step --ys=ar.loss` + ### Notices #### Modifying `prom_levels`, `resp_levels`, Or `tasks` For A Model diff --git a/scripts/plot.py b/scripts/plot.py index c47fadc..4064d82 100755 --- a/scripts/plot.py +++ b/scripts/plot.py @@ -26,8 +26,11 @@ def plot(paths, args): except Exception as e: continue - if f"{args.model_name}.engine_step" in row: + for x in args.xs: + if x not in row: + continue rows.append(row) + break df = pd.DataFrame(rows) @@ -44,31 +47,33 @@ def plot(paths, args): df = pd.concat(dfs) if args.max_y is not None: - df = df[df[f"{args.model_name}.engine_step"] < args.max_x] + for x in args.xs: + df = df[df[x] < args.max_x] for gtag, gdf in sorted( df.groupby("group"), key=lambda p: (p[0].split("/")[-1], p[0]), ): - for y in args.ys: - gdf = gdf.sort_values(f"{args.model_name}.engine_step") + for x in args.xs: + for y in args.ys: + gdf = gdf.sort_values(x) - if gdf[y].isna().all(): - continue + if gdf[y].isna().all(): + continue - if args.max_y is not None: - gdf = gdf[gdf[y] < args.max_y] + if args.max_y is not None: + gdf = gdf[gdf[y] < args.max_y] - gdf[y] = gdf[y].ewm(10).mean() + gdf[y] = gdf[y].ewm(10).mean() - gdf.plot( - x=f"{args.model_name}.engine_step", - y=y, - label=f"{y}", - ax=plt.gca(), - marker="x" if len(gdf) < 100 else None, - alpha=0.7, - ) + gdf.plot( + x=x, + y=y, + label=f"{y}", + ax=plt.gca(), + marker="x" if len(gdf) < 100 else None, + alpha=0.7, + ) plt.gca().legend( loc="center left", @@ -80,7 +85,8 @@ def plot(paths, args): def main(): parser = argparse.ArgumentParser() - parser.add_argument("ys", nargs="+") + parser.add_argument("--xs", nargs="+", default="ar.engine_step") + parser.add_argument("--ys", nargs="+", default="ar.loss") parser.add_argument("--log-dir", default="logs", type=Path) parser.add_argument("--out-dir", default="logs", type=Path) parser.add_argument("--filename", default="log.txt")