modified plotting script to be more agnostic to X
This commit is contained in:
parent
71e68a8528
commit
f7e942ec99
|
@ -84,6 +84,10 @@ Two dataset formats are supported:
|
||||||
- this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths)
|
- this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths)
|
||||||
- be sure to also define `use_hdf5` in your config YAML.
|
- be sure to also define `use_hdf5` in your config YAML.
|
||||||
|
|
||||||
|
### Plotting Metrics
|
||||||
|
|
||||||
|
Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 ./scripts/plot.py --log-dir ./training/valle/logs/1693675364 --out-dir ./data/ --xs=ar.engine_step --ys=ar.loss`
|
||||||
|
|
||||||
### Notices
|
### Notices
|
||||||
|
|
||||||
#### Modifying `prom_levels`, `resp_levels`, Or `tasks` For A Model
|
#### Modifying `prom_levels`, `resp_levels`, Or `tasks` For A Model
|
||||||
|
|
|
@ -26,8 +26,11 @@ def plot(paths, args):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if f"{args.model_name}.engine_step" in row:
|
for x in args.xs:
|
||||||
|
if x not in row:
|
||||||
|
continue
|
||||||
rows.append(row)
|
rows.append(row)
|
||||||
|
break
|
||||||
|
|
||||||
df = pd.DataFrame(rows)
|
df = pd.DataFrame(rows)
|
||||||
|
|
||||||
|
@ -44,31 +47,33 @@ def plot(paths, args):
|
||||||
df = pd.concat(dfs)
|
df = pd.concat(dfs)
|
||||||
|
|
||||||
if args.max_y is not None:
|
if args.max_y is not None:
|
||||||
df = df[df[f"{args.model_name}.engine_step"] < args.max_x]
|
for x in args.xs:
|
||||||
|
df = df[df[x] < args.max_x]
|
||||||
|
|
||||||
for gtag, gdf in sorted(
|
for gtag, gdf in sorted(
|
||||||
df.groupby("group"),
|
df.groupby("group"),
|
||||||
key=lambda p: (p[0].split("/")[-1], p[0]),
|
key=lambda p: (p[0].split("/")[-1], p[0]),
|
||||||
):
|
):
|
||||||
for y in args.ys:
|
for x in args.xs:
|
||||||
gdf = gdf.sort_values(f"{args.model_name}.engine_step")
|
for y in args.ys:
|
||||||
|
gdf = gdf.sort_values(x)
|
||||||
|
|
||||||
if gdf[y].isna().all():
|
if gdf[y].isna().all():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if args.max_y is not None:
|
if args.max_y is not None:
|
||||||
gdf = gdf[gdf[y] < args.max_y]
|
gdf = gdf[gdf[y] < args.max_y]
|
||||||
|
|
||||||
gdf[y] = gdf[y].ewm(10).mean()
|
gdf[y] = gdf[y].ewm(10).mean()
|
||||||
|
|
||||||
gdf.plot(
|
gdf.plot(
|
||||||
x=f"{args.model_name}.engine_step",
|
x=x,
|
||||||
y=y,
|
y=y,
|
||||||
label=f"{y}",
|
label=f"{y}",
|
||||||
ax=plt.gca(),
|
ax=plt.gca(),
|
||||||
marker="x" if len(gdf) < 100 else None,
|
marker="x" if len(gdf) < 100 else None,
|
||||||
alpha=0.7,
|
alpha=0.7,
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.gca().legend(
|
plt.gca().legend(
|
||||||
loc="center left",
|
loc="center left",
|
||||||
|
@ -80,7 +85,8 @@ def plot(paths, args):
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("ys", nargs="+")
|
parser.add_argument("--xs", nargs="+", default="ar.engine_step")
|
||||||
|
parser.add_argument("--ys", nargs="+", default="ar.loss")
|
||||||
parser.add_argument("--log-dir", default="logs", type=Path)
|
parser.add_argument("--log-dir", default="logs", type=Path)
|
||||||
parser.add_argument("--out-dir", default="logs", type=Path)
|
parser.add_argument("--out-dir", default="logs", type=Path)
|
||||||
parser.add_argument("--filename", default="log.txt")
|
parser.add_argument("--filename", default="log.txt")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user