diff --git a/README.md b/README.md index 878c84f..5ed42ae 100755 --- a/README.md +++ b/README.md @@ -86,7 +86,9 @@ Two dataset formats are supported: ### Plotting Metrics -Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 ./scripts/plot.py --log-dir ./training/valle/logs/1693675364 --out-dir ./data/ --xs=ar.engine_step --ys=ar.loss` +Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot yaml="./training/valle/config.yaml"` + +You can specify what X and Y labels you want to plot against by passing `--xs tokens_processed --ys loss stats.acc` ### Notices diff --git a/scripts/plot.py b/scripts/plot.py deleted file mode 100755 index 4064d82..0000000 --- a/scripts/plot.py +++ /dev/null @@ -1,113 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import json -import re -from pathlib import Path - -import matplotlib.pyplot as plt -import pandas as pd - - -def plot(paths, args): - dfs = [] - - for path in paths: - with open(path, "r") as f: - text = f.read() - - rows = [] - - pattern = r"(\{.+?\})\.\n" - - for row in re.findall(pattern, text, re.DOTALL): - try: - row = json.loads(row) - except Exception as e: - continue - - for x in args.xs: - if x not in row: - continue - rows.append(row) - break - - df = pd.DataFrame(rows) - - if "name" in df: - df["name"] = df["name"].fillna("train") - else: - df["name"] = "train" - - df["group"] = str(path.parents[args.group_level]) - df["group"] = df["group"] + "/" + df["name"] - - dfs.append(df) - - df = pd.concat(dfs) - - if args.max_y is not None: - for x in args.xs: - df = df[df[x] < args.max_x] - - for gtag, gdf in sorted( - df.groupby("group"), - key=lambda p: (p[0].split("/")[-1], p[0]), - ): - for x in args.xs: - for y in args.ys: - gdf = gdf.sort_values(x) - - if gdf[y].isna().all(): - continue - - if args.max_y is not None: - gdf = gdf[gdf[y] < args.max_y] - - gdf[y] = gdf[y].ewm(10).mean() - - gdf.plot( - x=x, - y=y, - label=f"{y}", - ax=plt.gca(), - marker="x" if len(gdf) < 100 else None, - alpha=0.7, - ) - - plt.gca().legend( - loc="center left", - fancybox=True, - shadow=True, - bbox_to_anchor=(1.04, 0.5), - ) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--xs", nargs="+", default="ar.engine_step") - parser.add_argument("--ys", nargs="+", default="ar.loss") - parser.add_argument("--log-dir", default="logs", type=Path) - parser.add_argument("--out-dir", default="logs", type=Path) - parser.add_argument("--filename", default="log.txt") - parser.add_argument("--max-x", type=float, default=float("inf")) - parser.add_argument("--max-y", type=float, default=float("inf")) - parser.add_argument("--group-level", default=1) - parser.add_argument("--model-name", type=str, default="ar") - parser.add_argument("--filter", default=None) - args = parser.parse_args() - - paths = args.log_dir.rglob(f"**/{args.filename}") - - if args.filter: - paths = filter(lambda p: re.match(".*" + args.filter + ".*", str(p)), paths) - - plot(paths, args) - - name = "-".join(args.ys) - out_path = (args.out_dir / name).with_suffix(".png") - plt.savefig(out_path, bbox_inches="tight") - - -if __name__ == "__main__": - main() diff --git a/vall_e/data.py b/vall_e/data.py index 65043d4..84da276 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -46,6 +46,7 @@ def get_task_symmap(): "": start + 4, "": start + 5, "": start + 6, + "": start + 7, } return symmap @@ -58,6 +59,7 @@ def _get_quant_path(path): def _get_phone_path(path): return _replace_file_extension(path, ".phn.txt") +@cfg.diskcache() def _load_paths(dataset, type="training"): return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") } @@ -317,6 +319,8 @@ class Dataset(_Dataset): # demote if the target is too short if task == "tts-c" and trim_length * 2 >= resps.shape[0]: task = "tts" + + task = "tts" # VALL-E continuous # ignore if target utterance is shorter than prompt duration @@ -324,6 +328,8 @@ class Dataset(_Dataset): if task == "tts-c": proms = resps[:trim_length, :] resps = resps[trim_length:, :] + + proms = torch.cat( [self.get_task_token(task), proms] ) else: proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps # noise suppression || speech removal @@ -536,7 +542,6 @@ def _create_dataloader(dataset, training): sampler=sampler, ) -@cfg.diskcache() def create_datasets(): train_dataset = Dataset( training=True ) val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False ) diff --git a/vall_e/plot.py b/vall_e/plot.py new file mode 100644 index 0000000..17f0c40 --- /dev/null +++ b/vall_e/plot.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 + +import argparse +import json +import re +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd + +from .config import cfg + +def plot(paths, args): + dfs = [] + + for path in paths: + with open(path, "r") as f: + text = f.read() + + rows = [] + + pattern = r"(\{.+?\})\.\n" + + for row in re.findall(pattern, text, re.DOTALL): + try: + row = json.loads(row) + except Exception as e: + continue + + for model in args.models: + if f'{model.name}.{args.xs}' not in row: + continue + rows.append(row) + break + + df = pd.DataFrame(rows) + + if "name" in df: + df["name"] = df["name"].fillna("train") + else: + df["name"] = "train" + + df["group"] = str(path.parents[args.group_level]) + df["group"] = df["group"] + "/" + df["name"] + + dfs.append(df) + + df = pd.concat(dfs) + + if args.max_y is not None: + for model in args.models: + df = df[df[f'{model.name}.{args.xs}'] < args.max_x] + + for gtag, gdf in sorted( + df.groupby("group"), + key=lambda p: (p[0].split("/")[-1], p[0]), + ): + for model in args.models: + x = f'{model.name}.{args.xs}' + for ys in args.ys: + y = f'{model.name}.{ys}' + + if gdf[y].isna().all(): + continue + + if args.max_y is not None: + gdf = gdf[gdf[y] < args.max_y] + + gdf[y] = gdf[y].ewm(10).mean() + + gdf.plot( + x=x, + y=y, + label=f"{y}", + ax=plt.gca(), + marker="x" if len(gdf) < 100 else None, + alpha=0.7, + ) + + plt.gca().legend( + loc="center left", + fancybox=True, + shadow=True, + bbox_to_anchor=(1.04, 0.5), + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--xs", default="engine_step") + parser.add_argument("--ys", nargs="+", default="") + parser.add_argument("--model", nargs="+", default="*") + + parser.add_argument("--max-x", type=float, default=float("inf")) + parser.add_argument("--max-y", type=float, default=float("inf")) + + parser.add_argument("--filename", default="log.txt") + parser.add_argument("--group-level", default=1) + args = parser.parse_args() + + path = cfg.relpath / "logs" + paths = path.rglob(f"./*/{args.filename}") + + args.models = [ model for model in cfg.models.get() if model.training and (args.model == "*" or model.name in args.model) ] + + if args.ys == "": + args.ys = ["loss"] + + plot(paths, args) + + out_path = cfg.relpath / "metrics.png" + plt.savefig(out_path, bbox_inches="tight") \ No newline at end of file