diff --git a/vall_e/plot.py b/vall_e/plot.py index 4eef13f..72f1080 100644 --- a/vall_e/plot.py +++ b/vall_e/plot.py @@ -72,7 +72,10 @@ def plot(paths, args): if args.max_y is not None: gdf = gdf[gdf[y] < args.max_y] - gdf[y] = gdf[y].ewm(10).mean() + if args.ewm: + gdf[y] = gdf[y].ewm(args.ewm).mean() + elif args.rolling: + gdf[y] = gdf[y].rolling(args.rolling).mean() gdf.plot( x=x, @@ -84,10 +87,10 @@ def plot(paths, args): ) plt.gca().legend( - loc="center left", + #loc="center left", fancybox=True, shadow=True, - bbox_to_anchor=(1.04, 0.5), + #bbox_to_anchor=(1.04, 0.5), ) @@ -101,6 +104,11 @@ if __name__ == "__main__": parser.add_argument("--min-y", type=float, default=-float("inf")) parser.add_argument("--max-x", type=float, default=float("inf")) parser.add_argument("--max-y", type=float, default=float("inf")) + + parser.add_argument("--ewm", type=int, default=1024) + parser.add_argument("--rolling", type=int, default=None) + + parser.add_argument("--size", type=str, default=None) parser.add_argument("--filename", default="log.txt") parser.add_argument("--group-level", default=1) @@ -112,7 +120,11 @@ if __name__ == "__main__": args.models = [ model for model in cfg.model.get() if model.training and (args.model == "*" or model.name in args.model) ] if args.ys == "": - args.ys = ["loss"] + args.ys = ["loss.nll"] + + if args.size: + width, height = args.size.split("x") + plt.figure(figsize=(int(width), int(height))) plot(paths, args)