vall_e.plot tweaks
This commit is contained in:
parent
2266d34818
commit
e84d466261
|
@ -72,7 +72,10 @@ def plot(paths, args):
|
||||||
if args.max_y is not None:
|
if args.max_y is not None:
|
||||||
gdf = gdf[gdf[y] < args.max_y]
|
gdf = gdf[gdf[y] < args.max_y]
|
||||||
|
|
||||||
gdf[y] = gdf[y].ewm(10).mean()
|
if args.ewm:
|
||||||
|
gdf[y] = gdf[y].ewm(args.ewm).mean()
|
||||||
|
elif args.rolling:
|
||||||
|
gdf[y] = gdf[y].rolling(args.rolling).mean()
|
||||||
|
|
||||||
gdf.plot(
|
gdf.plot(
|
||||||
x=x,
|
x=x,
|
||||||
|
@ -84,10 +87,10 @@ def plot(paths, args):
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.gca().legend(
|
plt.gca().legend(
|
||||||
loc="center left",
|
#loc="center left",
|
||||||
fancybox=True,
|
fancybox=True,
|
||||||
shadow=True,
|
shadow=True,
|
||||||
bbox_to_anchor=(1.04, 0.5),
|
#bbox_to_anchor=(1.04, 0.5),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -101,6 +104,11 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--min-y", type=float, default=-float("inf"))
|
parser.add_argument("--min-y", type=float, default=-float("inf"))
|
||||||
parser.add_argument("--max-x", type=float, default=float("inf"))
|
parser.add_argument("--max-x", type=float, default=float("inf"))
|
||||||
parser.add_argument("--max-y", type=float, default=float("inf"))
|
parser.add_argument("--max-y", type=float, default=float("inf"))
|
||||||
|
|
||||||
|
parser.add_argument("--ewm", type=int, default=1024)
|
||||||
|
parser.add_argument("--rolling", type=int, default=None)
|
||||||
|
|
||||||
|
parser.add_argument("--size", type=str, default=None)
|
||||||
|
|
||||||
parser.add_argument("--filename", default="log.txt")
|
parser.add_argument("--filename", default="log.txt")
|
||||||
parser.add_argument("--group-level", default=1)
|
parser.add_argument("--group-level", default=1)
|
||||||
|
@ -112,7 +120,11 @@ if __name__ == "__main__":
|
||||||
args.models = [ model for model in cfg.model.get() if model.training and (args.model == "*" or model.name in args.model) ]
|
args.models = [ model for model in cfg.model.get() if model.training and (args.model == "*" or model.name in args.model) ]
|
||||||
|
|
||||||
if args.ys == "":
|
if args.ys == "":
|
||||||
args.ys = ["loss"]
|
args.ys = ["loss.nll"]
|
||||||
|
|
||||||
|
if args.size:
|
||||||
|
width, height = args.size.split("x")
|
||||||
|
plt.figure(figsize=(int(width), int(height)))
|
||||||
|
|
||||||
plot(paths, args)
|
plot(paths, args)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user