integrated plot script, added tts-c task token to help the model be able to mix between normal VALL-E and VALL-E continuous
This commit is contained in:
parent
f7e942ec99
commit
4613781e23
|
@ -86,7 +86,9 @@ Two dataset formats are supported:
|
||||||
|
|
||||||
### Plotting Metrics
|
### Plotting Metrics
|
||||||
|
|
||||||
Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 ./scripts/plot.py --log-dir ./training/valle/logs/1693675364 --out-dir ./data/ --xs=ar.engine_step --ys=ar.loss`
|
Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot yaml="./training/valle/config.yaml"`
|
||||||
|
|
||||||
|
You can specify what X and Y labels you want to plot against by passing `--xs tokens_processed --ys loss stats.acc`
|
||||||
|
|
||||||
### Notices
|
### Notices
|
||||||
|
|
||||||
|
|
113
scripts/plot.py
113
scripts/plot.py
|
@ -1,113 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
|
|
||||||
def plot(paths, args):
|
|
||||||
dfs = []
|
|
||||||
|
|
||||||
for path in paths:
|
|
||||||
with open(path, "r") as f:
|
|
||||||
text = f.read()
|
|
||||||
|
|
||||||
rows = []
|
|
||||||
|
|
||||||
pattern = r"(\{.+?\})\.\n"
|
|
||||||
|
|
||||||
for row in re.findall(pattern, text, re.DOTALL):
|
|
||||||
try:
|
|
||||||
row = json.loads(row)
|
|
||||||
except Exception as e:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for x in args.xs:
|
|
||||||
if x not in row:
|
|
||||||
continue
|
|
||||||
rows.append(row)
|
|
||||||
break
|
|
||||||
|
|
||||||
df = pd.DataFrame(rows)
|
|
||||||
|
|
||||||
if "name" in df:
|
|
||||||
df["name"] = df["name"].fillna("train")
|
|
||||||
else:
|
|
||||||
df["name"] = "train"
|
|
||||||
|
|
||||||
df["group"] = str(path.parents[args.group_level])
|
|
||||||
df["group"] = df["group"] + "/" + df["name"]
|
|
||||||
|
|
||||||
dfs.append(df)
|
|
||||||
|
|
||||||
df = pd.concat(dfs)
|
|
||||||
|
|
||||||
if args.max_y is not None:
|
|
||||||
for x in args.xs:
|
|
||||||
df = df[df[x] < args.max_x]
|
|
||||||
|
|
||||||
for gtag, gdf in sorted(
|
|
||||||
df.groupby("group"),
|
|
||||||
key=lambda p: (p[0].split("/")[-1], p[0]),
|
|
||||||
):
|
|
||||||
for x in args.xs:
|
|
||||||
for y in args.ys:
|
|
||||||
gdf = gdf.sort_values(x)
|
|
||||||
|
|
||||||
if gdf[y].isna().all():
|
|
||||||
continue
|
|
||||||
|
|
||||||
if args.max_y is not None:
|
|
||||||
gdf = gdf[gdf[y] < args.max_y]
|
|
||||||
|
|
||||||
gdf[y] = gdf[y].ewm(10).mean()
|
|
||||||
|
|
||||||
gdf.plot(
|
|
||||||
x=x,
|
|
||||||
y=y,
|
|
||||||
label=f"{y}",
|
|
||||||
ax=plt.gca(),
|
|
||||||
marker="x" if len(gdf) < 100 else None,
|
|
||||||
alpha=0.7,
|
|
||||||
)
|
|
||||||
|
|
||||||
plt.gca().legend(
|
|
||||||
loc="center left",
|
|
||||||
fancybox=True,
|
|
||||||
shadow=True,
|
|
||||||
bbox_to_anchor=(1.04, 0.5),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--xs", nargs="+", default="ar.engine_step")
|
|
||||||
parser.add_argument("--ys", nargs="+", default="ar.loss")
|
|
||||||
parser.add_argument("--log-dir", default="logs", type=Path)
|
|
||||||
parser.add_argument("--out-dir", default="logs", type=Path)
|
|
||||||
parser.add_argument("--filename", default="log.txt")
|
|
||||||
parser.add_argument("--max-x", type=float, default=float("inf"))
|
|
||||||
parser.add_argument("--max-y", type=float, default=float("inf"))
|
|
||||||
parser.add_argument("--group-level", default=1)
|
|
||||||
parser.add_argument("--model-name", type=str, default="ar")
|
|
||||||
parser.add_argument("--filter", default=None)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
paths = args.log_dir.rglob(f"**/{args.filename}")
|
|
||||||
|
|
||||||
if args.filter:
|
|
||||||
paths = filter(lambda p: re.match(".*" + args.filter + ".*", str(p)), paths)
|
|
||||||
|
|
||||||
plot(paths, args)
|
|
||||||
|
|
||||||
name = "-".join(args.ys)
|
|
||||||
out_path = (args.out_dir / name).with_suffix(".png")
|
|
||||||
plt.savefig(out_path, bbox_inches="tight")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -46,6 +46,7 @@ def get_task_symmap():
|
||||||
"<mask>": start + 4,
|
"<mask>": start + 4,
|
||||||
"<eoe>": start + 5,
|
"<eoe>": start + 5,
|
||||||
"<svc>": start + 6,
|
"<svc>": start + 6,
|
||||||
|
"<tts-c>": start + 7,
|
||||||
}
|
}
|
||||||
return symmap
|
return symmap
|
||||||
|
|
||||||
|
@ -58,6 +59,7 @@ def _get_quant_path(path):
|
||||||
def _get_phone_path(path):
|
def _get_phone_path(path):
|
||||||
return _replace_file_extension(path, ".phn.txt")
|
return _replace_file_extension(path, ".phn.txt")
|
||||||
|
|
||||||
|
@cfg.diskcache()
|
||||||
def _load_paths(dataset, type="training"):
|
def _load_paths(dataset, type="training"):
|
||||||
return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
|
return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
|
||||||
|
|
||||||
|
@ -317,6 +319,8 @@ class Dataset(_Dataset):
|
||||||
# demote if the target is too short
|
# demote if the target is too short
|
||||||
if task == "tts-c" and trim_length * 2 >= resps.shape[0]:
|
if task == "tts-c" and trim_length * 2 >= resps.shape[0]:
|
||||||
task = "tts"
|
task = "tts"
|
||||||
|
|
||||||
|
task = "tts"
|
||||||
|
|
||||||
# VALL-E continuous
|
# VALL-E continuous
|
||||||
# ignore if target utterance is shorter than prompt duration
|
# ignore if target utterance is shorter than prompt duration
|
||||||
|
@ -324,6 +328,8 @@ class Dataset(_Dataset):
|
||||||
if task == "tts-c":
|
if task == "tts-c":
|
||||||
proms = resps[:trim_length, :]
|
proms = resps[:trim_length, :]
|
||||||
resps = resps[trim_length:, :]
|
resps = resps[trim_length:, :]
|
||||||
|
|
||||||
|
proms = torch.cat( [self.get_task_token(task), proms] )
|
||||||
else:
|
else:
|
||||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||||
# noise suppression || speech removal
|
# noise suppression || speech removal
|
||||||
|
@ -536,7 +542,6 @@ def _create_dataloader(dataset, training):
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
)
|
)
|
||||||
|
|
||||||
@cfg.diskcache()
|
|
||||||
def create_datasets():
|
def create_datasets():
|
||||||
train_dataset = Dataset( training=True )
|
train_dataset = Dataset( training=True )
|
||||||
val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
|
val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
|
||||||
|
|
112
vall_e/plot.py
Normal file
112
vall_e/plot.py
Normal file
|
@ -0,0 +1,112 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from .config import cfg
|
||||||
|
|
||||||
|
def plot(paths, args):
|
||||||
|
dfs = []
|
||||||
|
|
||||||
|
for path in paths:
|
||||||
|
with open(path, "r") as f:
|
||||||
|
text = f.read()
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
|
||||||
|
pattern = r"(\{.+?\})\.\n"
|
||||||
|
|
||||||
|
for row in re.findall(pattern, text, re.DOTALL):
|
||||||
|
try:
|
||||||
|
row = json.loads(row)
|
||||||
|
except Exception as e:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for model in args.models:
|
||||||
|
if f'{model.name}.{args.xs}' not in row:
|
||||||
|
continue
|
||||||
|
rows.append(row)
|
||||||
|
break
|
||||||
|
|
||||||
|
df = pd.DataFrame(rows)
|
||||||
|
|
||||||
|
if "name" in df:
|
||||||
|
df["name"] = df["name"].fillna("train")
|
||||||
|
else:
|
||||||
|
df["name"] = "train"
|
||||||
|
|
||||||
|
df["group"] = str(path.parents[args.group_level])
|
||||||
|
df["group"] = df["group"] + "/" + df["name"]
|
||||||
|
|
||||||
|
dfs.append(df)
|
||||||
|
|
||||||
|
df = pd.concat(dfs)
|
||||||
|
|
||||||
|
if args.max_y is not None:
|
||||||
|
for model in args.models:
|
||||||
|
df = df[df[f'{model.name}.{args.xs}'] < args.max_x]
|
||||||
|
|
||||||
|
for gtag, gdf in sorted(
|
||||||
|
df.groupby("group"),
|
||||||
|
key=lambda p: (p[0].split("/")[-1], p[0]),
|
||||||
|
):
|
||||||
|
for model in args.models:
|
||||||
|
x = f'{model.name}.{args.xs}'
|
||||||
|
for ys in args.ys:
|
||||||
|
y = f'{model.name}.{ys}'
|
||||||
|
|
||||||
|
if gdf[y].isna().all():
|
||||||
|
continue
|
||||||
|
|
||||||
|
if args.max_y is not None:
|
||||||
|
gdf = gdf[gdf[y] < args.max_y]
|
||||||
|
|
||||||
|
gdf[y] = gdf[y].ewm(10).mean()
|
||||||
|
|
||||||
|
gdf.plot(
|
||||||
|
x=x,
|
||||||
|
y=y,
|
||||||
|
label=f"{y}",
|
||||||
|
ax=plt.gca(),
|
||||||
|
marker="x" if len(gdf) < 100 else None,
|
||||||
|
alpha=0.7,
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.gca().legend(
|
||||||
|
loc="center left",
|
||||||
|
fancybox=True,
|
||||||
|
shadow=True,
|
||||||
|
bbox_to_anchor=(1.04, 0.5),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--xs", default="engine_step")
|
||||||
|
parser.add_argument("--ys", nargs="+", default="")
|
||||||
|
parser.add_argument("--model", nargs="+", default="*")
|
||||||
|
|
||||||
|
parser.add_argument("--max-x", type=float, default=float("inf"))
|
||||||
|
parser.add_argument("--max-y", type=float, default=float("inf"))
|
||||||
|
|
||||||
|
parser.add_argument("--filename", default="log.txt")
|
||||||
|
parser.add_argument("--group-level", default=1)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
path = cfg.relpath / "logs"
|
||||||
|
paths = path.rglob(f"./*/{args.filename}")
|
||||||
|
|
||||||
|
args.models = [ model for model in cfg.models.get() if model.training and (args.model == "*" or model.name in args.model) ]
|
||||||
|
|
||||||
|
if args.ys == "":
|
||||||
|
args.ys = ["loss"]
|
||||||
|
|
||||||
|
plot(paths, args)
|
||||||
|
|
||||||
|
out_path = cfg.relpath / "metrics.png"
|
||||||
|
plt.savefig(out_path, bbox_inches="tight")
|
Loading…
Reference in New Issue
Block a user