vall_e.plot tweaks

This commit is contained in:
mrq 2024-09-24 20:05:10 -05:00
parent 2266d34818
commit e84d466261

View File

@ -72,7 +72,10 @@ def plot(paths, args):
if args.max_y is not None:
gdf = gdf[gdf[y] < args.max_y]
gdf[y] = gdf[y].ewm(10).mean()
if args.ewm:
gdf[y] = gdf[y].ewm(args.ewm).mean()
elif args.rolling:
gdf[y] = gdf[y].rolling(args.rolling).mean()
gdf.plot(
x=x,
@ -84,10 +87,10 @@ def plot(paths, args):
)
plt.gca().legend(
loc="center left",
#loc="center left",
fancybox=True,
shadow=True,
bbox_to_anchor=(1.04, 0.5),
#bbox_to_anchor=(1.04, 0.5),
)
@ -102,6 +105,11 @@ if __name__ == "__main__":
parser.add_argument("--max-x", type=float, default=float("inf"))
parser.add_argument("--max-y", type=float, default=float("inf"))
parser.add_argument("--ewm", type=int, default=1024)
parser.add_argument("--rolling", type=int, default=None)
parser.add_argument("--size", type=str, default=None)
parser.add_argument("--filename", default="log.txt")
parser.add_argument("--group-level", default=1)
args, unknown = parser.parse_known_args()
@ -112,7 +120,11 @@ if __name__ == "__main__":
args.models = [ model for model in cfg.model.get() if model.training and (args.model == "*" or model.name in args.model) ]
if args.ys == "":
args.ys = ["loss"]
args.ys = ["loss.nll"]
if args.size:
width, height = args.size.split("x")
plt.figure(figsize=(int(width), int(height)))
plot(paths, args)