added min-x and min-y arguments to plot.py, helper script to download from my existing checkpoint
This commit is contained in:
parent
777ba43305
commit
153f8b293c
9
scripts/setup.sh
Normal file
9
scripts/setup.sh
Normal file
|
@ -0,0 +1,9 @@
|
|||
#!/bin/bash
|
||||
|
||||
python3 -m venv venv
|
||||
pip3 install -e .
|
||||
|
||||
mkdir -p ./training/valle/ckpt/ar+nar-retnet-8/
|
||||
wget -P ./training/valle/ckpt/ar+nar-retnet-8/fp32.pth "https://huggingface.co/ecker/vall-e/resolve/main/ckpt/ar%2Bnar-retnet-8/fp32.pth"
|
||||
wget -P ./training/valle/data.h5 "https://huggingface.co/ecker/vall-e/resolve/main/data.h5"
|
||||
wget -P ./training/valle/config.yaml "https://huggingface.co/ecker/vall-e/raw/main/config.yaml"
|
|
@ -1,28 +0,0 @@
|
|||
import torch
|
||||
|
||||
action = None
|
||||
# copies the resp_embs from a given AR and NAR into an AR as a base to convert into an AR+NAR monolithic odel
|
||||
if action == "merge_resp_embs":
|
||||
src_ar = torch.load("./data/source-ar.pth", map_location="cpu")
|
||||
src_nar = torch.load("./data/source-nar.pth", map_location="cpu")
|
||||
# copies all weights from the AR since the AR is usually "better", might need to experiment more with using a NAR as the base
|
||||
dst = torch.load("./data/source-ar.pth", map_location="cpu")
|
||||
|
||||
# copy resps_emb to layer 0 from AR
|
||||
dst['module']['resps_emb.weight'][:0, :, :] = src_ar['module']['resps_emb.weight']
|
||||
# copy resps_emb to remaining layers from NAR
|
||||
dst['module']['resps_emb.weight'][1:, :-1, :] = src_nar['module']['resps_emb.weight']
|
||||
# copies an existing AR+NAR monolithic model's resp_emb onto an AR
|
||||
elif action == "copy_resps_emb":
|
||||
src = torch.load("./data/source.pth", map_location="cpu")
|
||||
dst = torch.load("./data/destination.pth", map_location="cpu")
|
||||
dst['module']['resps_emb.weight'] = src['module']['resps_emb.weight']
|
||||
elif action == "extend_resps_emb":
|
||||
dst = torch.load("./data/destination.pth", map_location="cpu")
|
||||
dst['module']['resps_emb.weight'] = dst['module']['resps_emb.weight'].expand(4, -1, -1)
|
||||
dst['module']['resps_emb.weight'][1:] = torch.randn(3, 1025, 1024)
|
||||
|
||||
else
|
||||
raise Exception(f"invalid action: {action}")
|
||||
|
||||
torch.save(dst, './data/fp32.pth')
|
|
@ -121,7 +121,7 @@ class Dataset:
|
|||
|
||||
speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
|
||||
|
||||
hdf5_name: str = "data.h5"
|
||||
hdf5_name: str = "dataset.h5"
|
||||
use_hdf5: bool = False
|
||||
use_metadata: bool = False
|
||||
hdf5_flag: str = "a"
|
||||
|
|
|
@ -47,7 +47,11 @@ def plot(paths, args):
|
|||
|
||||
df = pd.concat(dfs)
|
||||
|
||||
if args.max_y is not None:
|
||||
if args.min_x is not None:
|
||||
for model in args.models:
|
||||
df = df[args.min_x < df[f'{model.name}.{args.xs}']]
|
||||
|
||||
if args.max_x is not None:
|
||||
for model in args.models:
|
||||
df = df[df[f'{model.name}.{args.xs}'] < args.max_x]
|
||||
|
||||
|
@ -63,6 +67,8 @@ def plot(paths, args):
|
|||
if gdf[y].isna().all():
|
||||
continue
|
||||
|
||||
if args.min_y is not None:
|
||||
gdf = gdf[args.min_y < gdf[y]]
|
||||
if args.max_y is not None:
|
||||
gdf = gdf[gdf[y] < args.max_y]
|
||||
|
||||
|
@ -91,6 +97,8 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--ys", nargs="+", default="")
|
||||
parser.add_argument("--model", nargs="+", default="*")
|
||||
|
||||
parser.add_argument("--min-x", type=float, default=-float("inf"))
|
||||
parser.add_argument("--min-y", type=float, default=-float("inf"))
|
||||
parser.add_argument("--max-x", type=float, default=float("inf"))
|
||||
parser.add_argument("--max-y", type=float, default=float("inf"))
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user