modified plotting script to be more agnostic to X
This commit is contained in:
parent
71e68a8528
commit
f7e942ec99
|
@ -84,6 +84,10 @@ Two dataset formats are supported:
|
|||
- this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths)
|
||||
- be sure to also define `use_hdf5` in your config YAML.
|
||||
|
||||
### Plotting Metrics
|
||||
|
||||
Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 ./scripts/plot.py --log-dir ./training/valle/logs/1693675364 --out-dir ./data/ --xs=ar.engine_step --ys=ar.loss`
|
||||
|
||||
### Notices
|
||||
|
||||
#### Modifying `prom_levels`, `resp_levels`, Or `tasks` For A Model
|
||||
|
|
|
@ -26,8 +26,11 @@ def plot(paths, args):
|
|||
except Exception as e:
|
||||
continue
|
||||
|
||||
if f"{args.model_name}.engine_step" in row:
|
||||
for x in args.xs:
|
||||
if x not in row:
|
||||
continue
|
||||
rows.append(row)
|
||||
break
|
||||
|
||||
df = pd.DataFrame(rows)
|
||||
|
||||
|
@ -44,31 +47,33 @@ def plot(paths, args):
|
|||
df = pd.concat(dfs)
|
||||
|
||||
if args.max_y is not None:
|
||||
df = df[df[f"{args.model_name}.engine_step"] < args.max_x]
|
||||
for x in args.xs:
|
||||
df = df[df[x] < args.max_x]
|
||||
|
||||
for gtag, gdf in sorted(
|
||||
df.groupby("group"),
|
||||
key=lambda p: (p[0].split("/")[-1], p[0]),
|
||||
):
|
||||
for y in args.ys:
|
||||
gdf = gdf.sort_values(f"{args.model_name}.engine_step")
|
||||
for x in args.xs:
|
||||
for y in args.ys:
|
||||
gdf = gdf.sort_values(x)
|
||||
|
||||
if gdf[y].isna().all():
|
||||
continue
|
||||
if gdf[y].isna().all():
|
||||
continue
|
||||
|
||||
if args.max_y is not None:
|
||||
gdf = gdf[gdf[y] < args.max_y]
|
||||
if args.max_y is not None:
|
||||
gdf = gdf[gdf[y] < args.max_y]
|
||||
|
||||
gdf[y] = gdf[y].ewm(10).mean()
|
||||
gdf[y] = gdf[y].ewm(10).mean()
|
||||
|
||||
gdf.plot(
|
||||
x=f"{args.model_name}.engine_step",
|
||||
y=y,
|
||||
label=f"{y}",
|
||||
ax=plt.gca(),
|
||||
marker="x" if len(gdf) < 100 else None,
|
||||
alpha=0.7,
|
||||
)
|
||||
gdf.plot(
|
||||
x=x,
|
||||
y=y,
|
||||
label=f"{y}",
|
||||
ax=plt.gca(),
|
||||
marker="x" if len(gdf) < 100 else None,
|
||||
alpha=0.7,
|
||||
)
|
||||
|
||||
plt.gca().legend(
|
||||
loc="center left",
|
||||
|
@ -80,7 +85,8 @@ def plot(paths, args):
|
|||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("ys", nargs="+")
|
||||
parser.add_argument("--xs", nargs="+", default="ar.engine_step")
|
||||
parser.add_argument("--ys", nargs="+", default="ar.loss")
|
||||
parser.add_argument("--log-dir", default="logs", type=Path)
|
||||
parser.add_argument("--out-dir", default="logs", type=Path)
|
||||
parser.add_argument("--filename", default="log.txt")
|
||||
|
|
Loading…
Reference in New Issue
Block a user