added min-x and min-y arguments to plot.py, helper script to download from my existing checkpoint

This commit is contained in:
mrq 2023-10-04 19:41:37 -05:00
parent 777ba43305
commit 153f8b293c
4 changed files with 19 additions and 30 deletions

9
scripts/setup.sh Normal file
View File

@ -0,0 +1,9 @@
#!/bin/bash
python3 -m venv venv
pip3 install -e .
mkdir -p ./training/valle/ckpt/ar+nar-retnet-8/
wget -P ./training/valle/ckpt/ar+nar-retnet-8/fp32.pth "https://huggingface.co/ecker/vall-e/resolve/main/ckpt/ar%2Bnar-retnet-8/fp32.pth"
wget -P ./training/valle/data.h5 "https://huggingface.co/ecker/vall-e/resolve/main/data.h5"
wget -P ./training/valle/config.yaml "https://huggingface.co/ecker/vall-e/raw/main/config.yaml"

View File

@ -1,28 +0,0 @@
import torch
action = None
# copies the resp_embs from a given AR and NAR into an AR as a base to convert into an AR+NAR monolithic odel
if action == "merge_resp_embs":
src_ar = torch.load("./data/source-ar.pth", map_location="cpu")
src_nar = torch.load("./data/source-nar.pth", map_location="cpu")
# copies all weights from the AR since the AR is usually "better", might need to experiment more with using a NAR as the base
dst = torch.load("./data/source-ar.pth", map_location="cpu")
# copy resps_emb to layer 0 from AR
dst['module']['resps_emb.weight'][:0, :, :] = src_ar['module']['resps_emb.weight']
# copy resps_emb to remaining layers from NAR
dst['module']['resps_emb.weight'][1:, :-1, :] = src_nar['module']['resps_emb.weight']
# copies an existing AR+NAR monolithic model's resp_emb onto an AR
elif action == "copy_resps_emb":
src = torch.load("./data/source.pth", map_location="cpu")
dst = torch.load("./data/destination.pth", map_location="cpu")
dst['module']['resps_emb.weight'] = src['module']['resps_emb.weight']
elif action == "extend_resps_emb":
dst = torch.load("./data/destination.pth", map_location="cpu")
dst['module']['resps_emb.weight'] = dst['module']['resps_emb.weight'].expand(4, -1, -1)
dst['module']['resps_emb.weight'][1:] = torch.randn(3, 1025, 1024)
else
raise Exception(f"invalid action: {action}")
torch.save(dst, './data/fp32.pth')

View File

@ -121,7 +121,7 @@ class Dataset:
speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
hdf5_name: str = "data.h5"
hdf5_name: str = "dataset.h5"
use_hdf5: bool = False
use_metadata: bool = False
hdf5_flag: str = "a"

View File

@ -47,7 +47,11 @@ def plot(paths, args):
df = pd.concat(dfs)
if args.max_y is not None:
if args.min_x is not None:
for model in args.models:
df = df[args.min_x < df[f'{model.name}.{args.xs}']]
if args.max_x is not None:
for model in args.models:
df = df[df[f'{model.name}.{args.xs}'] < args.max_x]
@ -63,6 +67,8 @@ def plot(paths, args):
if gdf[y].isna().all():
continue
if args.min_y is not None:
gdf = gdf[args.min_y < gdf[y]]
if args.max_y is not None:
gdf = gdf[gdf[y] < args.max_y]
@ -91,6 +97,8 @@ if __name__ == "__main__":
parser.add_argument("--ys", nargs="+", default="")
parser.add_argument("--model", nargs="+", default="*")
parser.add_argument("--min-x", type=float, default=-float("inf"))
parser.add_argument("--min-y", type=float, default=-float("inf"))
parser.add_argument("--max-x", type=float, default=float("inf"))
parser.add_argument("--max-y", type=float, default=float("inf"))