fixed up plot script that I forgot about

This commit is contained in:
mrq 2023-09-02 13:31:04 -05:00
parent 57db3ccfa8
commit 21e5d250cc

View File

@ -18,7 +18,7 @@ def plot(paths, args):
rows = []
pattern = r"(\{.+?\})"
pattern = r"(\{.+?\})\.\n"
for row in re.findall(pattern, text, re.DOTALL):
try:
@ -26,7 +26,7 @@ def plot(paths, args):
except Exception as e:
continue
if "global_step" in row:
if f"{args.model_name}.engine_step" in row:
rows.append(row)
df = pd.DataFrame(rows)
@ -44,14 +44,14 @@ def plot(paths, args):
df = pd.concat(dfs)
if args.max_y is not None:
df = df[df["global_step"] < args.max_x]
df = df[df[f"{args.model_name}.engine_step"] < args.max_x]
for gtag, gdf in sorted(
df.groupby("group"),
key=lambda p: (p[0].split("/")[-1], p[0]),
):
for y in args.ys:
gdf = gdf.sort_values("global_step")
gdf = gdf.sort_values(f"{args.model_name}.engine_step")
if gdf[y].isna().all():
continue
@ -62,9 +62,9 @@ def plot(paths, args):
gdf[y] = gdf[y].ewm(10).mean()
gdf.plot(
x="global_step",
x=f"{args.model_name}.engine_step",
y=y,
label=f"{gtag}/{y}",
label=f"{y}",
ax=plt.gca(),
marker="x" if len(gdf) < 100 else None,
alpha=0.7,
@ -87,6 +87,7 @@ def main():
parser.add_argument("--max-x", type=float, default=float("inf"))
parser.add_argument("--max-y", type=float, default=float("inf"))
parser.add_argument("--group-level", default=1)
parser.add_argument("--model-name", type=str, default="ar")
parser.add_argument("--filter", default=None)
args = parser.parse_args()