modified plotting script to be more agnostic to X

This commit is contained in:
mrq 2023-09-02 13:59:43 -05:00
parent 71e68a8528
commit f7e942ec99
2 changed files with 28 additions and 18 deletions

View File

@ -84,6 +84,10 @@ Two dataset formats are supported:
- this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths)
- be sure to also define `use_hdf5` in your config YAML.
### Plotting Metrics
Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 ./scripts/plot.py --log-dir ./training/valle/logs/1693675364 --out-dir ./data/ --xs=ar.engine_step --ys=ar.loss`
### Notices
#### Modifying `prom_levels`, `resp_levels`, Or `tasks` For A Model

View File

@ -26,8 +26,11 @@ def plot(paths, args):
except Exception as e:
continue
if f"{args.model_name}.engine_step" in row:
for x in args.xs:
if x not in row:
continue
rows.append(row)
break
df = pd.DataFrame(rows)
@ -44,14 +47,16 @@ def plot(paths, args):
df = pd.concat(dfs)
if args.max_y is not None:
df = df[df[f"{args.model_name}.engine_step"] < args.max_x]
for x in args.xs:
df = df[df[x] < args.max_x]
for gtag, gdf in sorted(
df.groupby("group"),
key=lambda p: (p[0].split("/")[-1], p[0]),
):
for x in args.xs:
for y in args.ys:
gdf = gdf.sort_values(f"{args.model_name}.engine_step")
gdf = gdf.sort_values(x)
if gdf[y].isna().all():
continue
@ -62,7 +67,7 @@ def plot(paths, args):
gdf[y] = gdf[y].ewm(10).mean()
gdf.plot(
x=f"{args.model_name}.engine_step",
x=x,
y=y,
label=f"{y}",
ax=plt.gca(),
@ -80,7 +85,8 @@ def plot(paths, args):
def main():
parser = argparse.ArgumentParser()
parser.add_argument("ys", nargs="+")
parser.add_argument("--xs", nargs="+", default="ar.engine_step")
parser.add_argument("--ys", nargs="+", default="ar.loss")
parser.add_argument("--log-dir", default="logs", type=Path)
parser.add_argument("--out-dir", default="logs", type=Path)
parser.add_argument("--filename", default="log.txt")