fixed up plot script that I forgot about

This commit is contained in:
mrq 2023-09-02 13:31:04 -05:00
parent 57db3ccfa8
commit 21e5d250cc

View File

@ -18,7 +18,7 @@ def plot(paths, args):
rows = [] rows = []
pattern = r"(\{.+?\})" pattern = r"(\{.+?\})\.\n"
for row in re.findall(pattern, text, re.DOTALL): for row in re.findall(pattern, text, re.DOTALL):
try: try:
@ -26,7 +26,7 @@ def plot(paths, args):
except Exception as e: except Exception as e:
continue continue
if "global_step" in row: if f"{args.model_name}.engine_step" in row:
rows.append(row) rows.append(row)
df = pd.DataFrame(rows) df = pd.DataFrame(rows)
@ -44,14 +44,14 @@ def plot(paths, args):
df = pd.concat(dfs) df = pd.concat(dfs)
if args.max_y is not None: if args.max_y is not None:
df = df[df["global_step"] < args.max_x] df = df[df[f"{args.model_name}.engine_step"] < args.max_x]
for gtag, gdf in sorted( for gtag, gdf in sorted(
df.groupby("group"), df.groupby("group"),
key=lambda p: (p[0].split("/")[-1], p[0]), key=lambda p: (p[0].split("/")[-1], p[0]),
): ):
for y in args.ys: for y in args.ys:
gdf = gdf.sort_values("global_step") gdf = gdf.sort_values(f"{args.model_name}.engine_step")
if gdf[y].isna().all(): if gdf[y].isna().all():
continue continue
@ -62,9 +62,9 @@ def plot(paths, args):
gdf[y] = gdf[y].ewm(10).mean() gdf[y] = gdf[y].ewm(10).mean()
gdf.plot( gdf.plot(
x="global_step", x=f"{args.model_name}.engine_step",
y=y, y=y,
label=f"{gtag}/{y}", label=f"{y}",
ax=plt.gca(), ax=plt.gca(),
marker="x" if len(gdf) < 100 else None, marker="x" if len(gdf) < 100 else None,
alpha=0.7, alpha=0.7,
@ -87,6 +87,7 @@ def main():
parser.add_argument("--max-x", type=float, default=float("inf")) parser.add_argument("--max-x", type=float, default=float("inf"))
parser.add_argument("--max-y", type=float, default=float("inf")) parser.add_argument("--max-y", type=float, default=float("inf"))
parser.add_argument("--group-level", default=1) parser.add_argument("--group-level", default=1)
parser.add_argument("--model-name", type=str, default="ar")
parser.add_argument("--filter", default=None) parser.add_argument("--filter", default=None)
args = parser.parse_args() args = parser.parse_args()