added min-x and min-y arguments to plot.py, helper script to download from my existing checkpoint
This commit is contained in:
parent
777ba43305
commit
153f8b293c
9
scripts/setup.sh
Normal file
9
scripts/setup.sh
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
python3 -m venv venv
|
||||||
|
pip3 install -e .
|
||||||
|
|
||||||
|
mkdir -p ./training/valle/ckpt/ar+nar-retnet-8/
|
||||||
|
wget -P ./training/valle/ckpt/ar+nar-retnet-8/fp32.pth "https://huggingface.co/ecker/vall-e/resolve/main/ckpt/ar%2Bnar-retnet-8/fp32.pth"
|
||||||
|
wget -P ./training/valle/data.h5 "https://huggingface.co/ecker/vall-e/resolve/main/data.h5"
|
||||||
|
wget -P ./training/valle/config.yaml "https://huggingface.co/ecker/vall-e/raw/main/config.yaml"
|
|
@ -1,28 +0,0 @@
|
||||||
import torch
|
|
||||||
|
|
||||||
action = None
|
|
||||||
# copies the resp_embs from a given AR and NAR into an AR as a base to convert into an AR+NAR monolithic odel
|
|
||||||
if action == "merge_resp_embs":
|
|
||||||
src_ar = torch.load("./data/source-ar.pth", map_location="cpu")
|
|
||||||
src_nar = torch.load("./data/source-nar.pth", map_location="cpu")
|
|
||||||
# copies all weights from the AR since the AR is usually "better", might need to experiment more with using a NAR as the base
|
|
||||||
dst = torch.load("./data/source-ar.pth", map_location="cpu")
|
|
||||||
|
|
||||||
# copy resps_emb to layer 0 from AR
|
|
||||||
dst['module']['resps_emb.weight'][:0, :, :] = src_ar['module']['resps_emb.weight']
|
|
||||||
# copy resps_emb to remaining layers from NAR
|
|
||||||
dst['module']['resps_emb.weight'][1:, :-1, :] = src_nar['module']['resps_emb.weight']
|
|
||||||
# copies an existing AR+NAR monolithic model's resp_emb onto an AR
|
|
||||||
elif action == "copy_resps_emb":
|
|
||||||
src = torch.load("./data/source.pth", map_location="cpu")
|
|
||||||
dst = torch.load("./data/destination.pth", map_location="cpu")
|
|
||||||
dst['module']['resps_emb.weight'] = src['module']['resps_emb.weight']
|
|
||||||
elif action == "extend_resps_emb":
|
|
||||||
dst = torch.load("./data/destination.pth", map_location="cpu")
|
|
||||||
dst['module']['resps_emb.weight'] = dst['module']['resps_emb.weight'].expand(4, -1, -1)
|
|
||||||
dst['module']['resps_emb.weight'][1:] = torch.randn(3, 1025, 1024)
|
|
||||||
|
|
||||||
else
|
|
||||||
raise Exception(f"invalid action: {action}")
|
|
||||||
|
|
||||||
torch.save(dst, './data/fp32.pth')
|
|
|
@ -121,7 +121,7 @@ class Dataset:
|
||||||
|
|
||||||
speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
|
speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
|
||||||
|
|
||||||
hdf5_name: str = "data.h5"
|
hdf5_name: str = "dataset.h5"
|
||||||
use_hdf5: bool = False
|
use_hdf5: bool = False
|
||||||
use_metadata: bool = False
|
use_metadata: bool = False
|
||||||
hdf5_flag: str = "a"
|
hdf5_flag: str = "a"
|
||||||
|
|
|
@ -47,7 +47,11 @@ def plot(paths, args):
|
||||||
|
|
||||||
df = pd.concat(dfs)
|
df = pd.concat(dfs)
|
||||||
|
|
||||||
if args.max_y is not None:
|
if args.min_x is not None:
|
||||||
|
for model in args.models:
|
||||||
|
df = df[args.min_x < df[f'{model.name}.{args.xs}']]
|
||||||
|
|
||||||
|
if args.max_x is not None:
|
||||||
for model in args.models:
|
for model in args.models:
|
||||||
df = df[df[f'{model.name}.{args.xs}'] < args.max_x]
|
df = df[df[f'{model.name}.{args.xs}'] < args.max_x]
|
||||||
|
|
||||||
|
@ -63,6 +67,8 @@ def plot(paths, args):
|
||||||
if gdf[y].isna().all():
|
if gdf[y].isna().all():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if args.min_y is not None:
|
||||||
|
gdf = gdf[args.min_y < gdf[y]]
|
||||||
if args.max_y is not None:
|
if args.max_y is not None:
|
||||||
gdf = gdf[gdf[y] < args.max_y]
|
gdf = gdf[gdf[y] < args.max_y]
|
||||||
|
|
||||||
|
@ -91,6 +97,8 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--ys", nargs="+", default="")
|
parser.add_argument("--ys", nargs="+", default="")
|
||||||
parser.add_argument("--model", nargs="+", default="*")
|
parser.add_argument("--model", nargs="+", default="*")
|
||||||
|
|
||||||
|
parser.add_argument("--min-x", type=float, default=-float("inf"))
|
||||||
|
parser.add_argument("--min-y", type=float, default=-float("inf"))
|
||||||
parser.add_argument("--max-x", type=float, default=float("inf"))
|
parser.add_argument("--max-x", type=float, default=float("inf"))
|
||||||
parser.add_argument("--max-y", type=float, default=float("inf"))
|
parser.add_argument("--max-y", type=float, default=float("inf"))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user