modified plotting script to be more agnostic to X

This commit is contained in:
mrq 2023-09-02 13:59:43 -05:00
parent 71e68a8528
commit f7e942ec99
2 changed files with 28 additions and 18 deletions

View File

@ -84,6 +84,10 @@ Two dataset formats are supported:
- this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths) - this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths)
- be sure to also define `use_hdf5` in your config YAML. - be sure to also define `use_hdf5` in your config YAML.
### Plotting Metrics
Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 ./scripts/plot.py --log-dir ./training/valle/logs/1693675364 --out-dir ./data/ --xs=ar.engine_step --ys=ar.loss`
### Notices ### Notices
#### Modifying `prom_levels`, `resp_levels`, Or `tasks` For A Model #### Modifying `prom_levels`, `resp_levels`, Or `tasks` For A Model

View File

@ -26,8 +26,11 @@ def plot(paths, args):
except Exception as e: except Exception as e:
continue continue
if f"{args.model_name}.engine_step" in row: for x in args.xs:
if x not in row:
continue
rows.append(row) rows.append(row)
break
df = pd.DataFrame(rows) df = pd.DataFrame(rows)
@ -44,31 +47,33 @@ def plot(paths, args):
df = pd.concat(dfs) df = pd.concat(dfs)
if args.max_y is not None: if args.max_y is not None:
df = df[df[f"{args.model_name}.engine_step"] < args.max_x] for x in args.xs:
df = df[df[x] < args.max_x]
for gtag, gdf in sorted( for gtag, gdf in sorted(
df.groupby("group"), df.groupby("group"),
key=lambda p: (p[0].split("/")[-1], p[0]), key=lambda p: (p[0].split("/")[-1], p[0]),
): ):
for y in args.ys: for x in args.xs:
gdf = gdf.sort_values(f"{args.model_name}.engine_step") for y in args.ys:
gdf = gdf.sort_values(x)
if gdf[y].isna().all(): if gdf[y].isna().all():
continue continue
if args.max_y is not None: if args.max_y is not None:
gdf = gdf[gdf[y] < args.max_y] gdf = gdf[gdf[y] < args.max_y]
gdf[y] = gdf[y].ewm(10).mean() gdf[y] = gdf[y].ewm(10).mean()
gdf.plot( gdf.plot(
x=f"{args.model_name}.engine_step", x=x,
y=y, y=y,
label=f"{y}", label=f"{y}",
ax=plt.gca(), ax=plt.gca(),
marker="x" if len(gdf) < 100 else None, marker="x" if len(gdf) < 100 else None,
alpha=0.7, alpha=0.7,
) )
plt.gca().legend( plt.gca().legend(
loc="center left", loc="center left",
@ -80,7 +85,8 @@ def plot(paths, args):
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("ys", nargs="+") parser.add_argument("--xs", nargs="+", default="ar.engine_step")
parser.add_argument("--ys", nargs="+", default="ar.loss")
parser.add_argument("--log-dir", default="logs", type=Path) parser.add_argument("--log-dir", default="logs", type=Path)
parser.add_argument("--out-dir", default="logs", type=Path) parser.add_argument("--out-dir", default="logs", type=Path)
parser.add_argument("--filename", default="log.txt") parser.add_argument("--filename", default="log.txt")