fixed up plot script that I forgot about
This commit is contained in:
parent
57db3ccfa8
commit
21e5d250cc
|
@ -18,7 +18,7 @@ def plot(paths, args):
|
|||
|
||||
rows = []
|
||||
|
||||
pattern = r"(\{.+?\})"
|
||||
pattern = r"(\{.+?\})\.\n"
|
||||
|
||||
for row in re.findall(pattern, text, re.DOTALL):
|
||||
try:
|
||||
|
@ -26,7 +26,7 @@ def plot(paths, args):
|
|||
except Exception as e:
|
||||
continue
|
||||
|
||||
if "global_step" in row:
|
||||
if f"{args.model_name}.engine_step" in row:
|
||||
rows.append(row)
|
||||
|
||||
df = pd.DataFrame(rows)
|
||||
|
@ -44,14 +44,14 @@ def plot(paths, args):
|
|||
df = pd.concat(dfs)
|
||||
|
||||
if args.max_y is not None:
|
||||
df = df[df["global_step"] < args.max_x]
|
||||
df = df[df[f"{args.model_name}.engine_step"] < args.max_x]
|
||||
|
||||
for gtag, gdf in sorted(
|
||||
df.groupby("group"),
|
||||
key=lambda p: (p[0].split("/")[-1], p[0]),
|
||||
):
|
||||
for y in args.ys:
|
||||
gdf = gdf.sort_values("global_step")
|
||||
gdf = gdf.sort_values(f"{args.model_name}.engine_step")
|
||||
|
||||
if gdf[y].isna().all():
|
||||
continue
|
||||
|
@ -62,9 +62,9 @@ def plot(paths, args):
|
|||
gdf[y] = gdf[y].ewm(10).mean()
|
||||
|
||||
gdf.plot(
|
||||
x="global_step",
|
||||
x=f"{args.model_name}.engine_step",
|
||||
y=y,
|
||||
label=f"{gtag}/{y}",
|
||||
label=f"{y}",
|
||||
ax=plt.gca(),
|
||||
marker="x" if len(gdf) < 100 else None,
|
||||
alpha=0.7,
|
||||
|
@ -87,6 +87,7 @@ def main():
|
|||
parser.add_argument("--max-x", type=float, default=float("inf"))
|
||||
parser.add_argument("--max-y", type=float, default=float("inf"))
|
||||
parser.add_argument("--group-level", default=1)
|
||||
parser.add_argument("--model-name", type=str, default="ar")
|
||||
parser.add_argument("--filter", default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user