culled artifacts left over from the valle trainer
This commit is contained in:
parent
f3f29ddf06
commit
b15c0fc33b
36
README.md
36
README.md
|
@ -2,6 +2,12 @@
|
|||
|
||||
This is a simple ResNet based image classifier for """specific images""", using a similar training framework I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/).
|
||||
|
||||
## Premise
|
||||
|
||||
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
|
||||
|
||||
This is by no ways state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
|
||||
|
||||
## Training
|
||||
|
||||
1. Throw the images you want to train under `./data/images/`.
|
||||
|
@ -18,8 +24,32 @@ This is a simple ResNet based image classifier for """specific images""", using
|
|||
|
||||
Simply invoke the inferencer with the following command: `python3 -m image_classifier "./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
|
||||
|
||||
## Caveats
|
||||
## Known Issues
|
||||
|
||||
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
|
||||
* Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
|
||||
* The evaluation / validation routine doesn't quite work.
|
||||
* Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
|
||||
|
||||
This is by no ways state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
|
||||
## Strawmen
|
||||
|
||||
>\> UGH... Why *another* training framework!!! Just subjugate [DLAS](https://git.ecker.tech/mrq/DL-Art-School) even more!!!
|
||||
|
||||
I want my own code to own. The original VALL-E implementation had a rather nice and clean setup that *mostly* just made sense. DLAS was a nightmare to comb through for the gorillion amounts of models it attests.
|
||||
|
||||
>\> OK. But how do I use it for `[thing that isn't the specific usecase only I know/care about]`
|
||||
|
||||
Simply provide your own symmapping under `./image_classifier/data.py`, and, be sure to set the delimiter (where exactly is an exercise left to the reader).
|
||||
|
||||
Because this is for a ***very specific*** use-case. I don't really care right now to make this a *little* more generalized, despite most of the bits and bobs for it to generalize being there.
|
||||
|
||||
>\> ur `[a slur]` for using a ResNet... why not use `[CRNN / some other meme arch]`??
|
||||
|
||||
I don't care, I'd rather keep the copypasting from other people's code to a minimum. Lazily adapting my phoneme tokenizer from my VALL-E implementation into something practically fixed length by introducing start/stop tokens should be grounds for me to use a CRNN, or anything recurrent at the very least, but again, I don't care, it just works for my use case at the moment.
|
||||
|
||||
>\> UGH!!! What are you talking about """specific images"""???
|
||||
|
||||
[ひみつ](https://files.catbox.moe/csuh49.webm)
|
||||
|
||||
>\> NOOOO!!!! WHY AREN'T YOU USING `[cuck license]`???
|
||||
|
||||
:)
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from .config import cfg
|
||||
from .data import create_train_val_dataloader
|
||||
from .emb import qnt
|
||||
|
||||
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
|
||||
from .utils.trainer import load_engines
|
||||
|
|
|
@ -1,9 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# do not invoke directly in scripts
|
||||
if [[ ${PWD##*/} == 'scripts' ]]; then
|
||||
cd ..
|
||||
fi
|
||||
|
||||
# download training data
|
||||
git clone https://huggingface.co/datasets/ecker/libritts-small ./data/libritts-small
|
106
scripts/plot.py
106
scripts/plot.py
|
@ -1,106 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def plot(paths, args):
|
||||
dfs = []
|
||||
|
||||
for path in paths:
|
||||
with open(path, "r") as f:
|
||||
text = f.read()
|
||||
|
||||
rows = []
|
||||
|
||||
pattern = r"(\{.+?\})"
|
||||
|
||||
for row in re.findall(pattern, text, re.DOTALL):
|
||||
try:
|
||||
row = json.loads(row)
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
if "global_step" in row:
|
||||
rows.append(row)
|
||||
|
||||
df = pd.DataFrame(rows)
|
||||
|
||||
if "name" in df:
|
||||
df["name"] = df["name"].fillna("train")
|
||||
else:
|
||||
df["name"] = "train"
|
||||
|
||||
df["group"] = str(path.parents[args.group_level])
|
||||
df["group"] = df["group"] + "/" + df["name"]
|
||||
|
||||
dfs.append(df)
|
||||
|
||||
df = pd.concat(dfs)
|
||||
|
||||
if args.max_y is not None:
|
||||
df = df[df["global_step"] < args.max_x]
|
||||
|
||||
for gtag, gdf in sorted(
|
||||
df.groupby("group"),
|
||||
key=lambda p: (p[0].split("/")[-1], p[0]),
|
||||
):
|
||||
for y in args.ys:
|
||||
gdf = gdf.sort_values("global_step")
|
||||
|
||||
if gdf[y].isna().all():
|
||||
continue
|
||||
|
||||
if args.max_y is not None:
|
||||
gdf = gdf[gdf[y] < args.max_y]
|
||||
|
||||
gdf[y] = gdf[y].ewm(10).mean()
|
||||
|
||||
gdf.plot(
|
||||
x="global_step",
|
||||
y=y,
|
||||
label=f"{gtag}/{y}",
|
||||
ax=plt.gca(),
|
||||
marker="x" if len(gdf) < 100 else None,
|
||||
alpha=0.7,
|
||||
)
|
||||
|
||||
plt.gca().legend(
|
||||
loc="center left",
|
||||
fancybox=True,
|
||||
shadow=True,
|
||||
bbox_to_anchor=(1.04, 0.5),
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("ys", nargs="+")
|
||||
parser.add_argument("--log-dir", default="logs", type=Path)
|
||||
parser.add_argument("--out-dir", default="logs", type=Path)
|
||||
parser.add_argument("--filename", default="log.txt")
|
||||
parser.add_argument("--max-x", type=float, default=float("inf"))
|
||||
parser.add_argument("--max-y", type=float, default=float("inf"))
|
||||
parser.add_argument("--group-level", default=1)
|
||||
parser.add_argument("--filter", default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
paths = args.log_dir.rglob(f"**/{args.filename}")
|
||||
|
||||
if args.filter:
|
||||
paths = filter(lambda p: re.match(".*" + args.filter + ".*", str(p)), paths)
|
||||
|
||||
plot(paths, args)
|
||||
|
||||
name = "-".join(args.ys)
|
||||
out_path = (args.out_dir / name).with_suffix(".png")
|
||||
plt.savefig(out_path, bbox_inches="tight")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,72 +0,0 @@
|
|||
import os
|
||||
import json
|
||||
|
||||
for f in os.listdir(f'./data/librispeech_finetuning/1h/'):
|
||||
for j in os.listdir(f'./data/librispeech_finetuning/1h/{f}/clean'):
|
||||
for z in os.listdir(f'./data/librispeech_finetuning/1h/{f}/clean/{j}'):
|
||||
for i in os.listdir(f'./data/librispeech_finetuning/1h/{f}/clean/{j}/{z}'):
|
||||
os.rename(f'./data/librispeech_finetuning/1h/{f}/clean/{j}/{z}/{i}', f'./data/librilight-tts/{i}')
|
||||
|
||||
for j in os.listdir('./data/librispeech_finetuning/9h/clean'):
|
||||
for z in os.listdir(f'./data/librispeech_finetuning/9h/clean/{j}'):
|
||||
for i in os.listdir(f'./data/librispeech_finetuning/9h/clean/{j}/{z}'):
|
||||
os.rename(f'./data/librispeech_finetuning/9h/clean/{j}/{z}/{i}', f'./data/librilight-tts/{i}')
|
||||
|
||||
lst = []
|
||||
for i in os.listdir('./data/librilight-tts/'):
|
||||
try:
|
||||
if 'trans' not in i:
|
||||
continue
|
||||
with open(f'./data/librilight-tts/{i}') as f:
|
||||
for row in f:
|
||||
z = row.split('-')
|
||||
name = z[0]+'-'+z[1]+ '-' + z[2].split(' ')[0]
|
||||
text = " ".join(z[2].split(' ')[1:])
|
||||
lst.append([name, text])
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
for i in lst:
|
||||
try:
|
||||
with open(f'./data/librilight-tts/{i[0]}.txt', 'x') as file:
|
||||
file.write(i[1])
|
||||
except:
|
||||
with open(f'./data/librilight-tts/{i[0]}.txt', 'w+') as file:
|
||||
file.write(i[1])
|
||||
|
||||
phoneme_map = {}
|
||||
phoneme_transcript = {}
|
||||
|
||||
with open('./data/librispeech_finetuning/phones/phones_mapping.json', 'r') as f:
|
||||
phoneme_map_rev = json.load(f)
|
||||
for k, v in phoneme_map_rev.items():
|
||||
phoneme_map[f'{v}'] = k
|
||||
|
||||
with open('./data/librispeech_finetuning/phones/10h_phones.txt', 'r') as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
split = line.strip().split(" ")
|
||||
key = split[0]
|
||||
tokens = split[1:]
|
||||
|
||||
phonemes = []
|
||||
for token in tokens:
|
||||
phoneme = phoneme_map[f'{token}']
|
||||
phonemes.append( phoneme )
|
||||
|
||||
phoneme_transcript[key] = " ".join(phonemes)
|
||||
|
||||
for filename in sorted(os.listdir('./data/librilight-tts')):
|
||||
split = filename.split('.')
|
||||
|
||||
key = split[0]
|
||||
extension = split[1] # covers double duty of culling .normalized.txt and .phn.txt
|
||||
|
||||
if extension != 'txt':
|
||||
continue
|
||||
|
||||
os.rename(f'./data/librilight-tts/{filename}', f'./data/librilight-tts/{key}.normalized.txt')
|
||||
|
||||
if key in phoneme_transcript:
|
||||
with open(f'./data/librilight-tts/{key}.phn.txt', 'w', encoding='utf-8') as f:
|
||||
f.write(phoneme_transcript[key])
|
|
@ -1,27 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# do not invoke directly in scripts
|
||||
if [[ ${PWD##*/} == 'scripts' ]]; then
|
||||
cd ..
|
||||
fi
|
||||
|
||||
# download training data
|
||||
cd data
|
||||
mkdir librilight-tts
|
||||
if [ ! -e ./librispeech_finetuning.tgz ]; then
|
||||
wget https://dl.fbaipublicfiles.com/librilight/data/librispeech_finetuning.tgz
|
||||
fi
|
||||
tar -xzf librispeech_finetuning.tgz
|
||||
cd ..
|
||||
|
||||
# clean it up
|
||||
python3 ./scripts/prepare_libri.py
|
||||
|
||||
# convert to wav
|
||||
pip3 install AudioConverter
|
||||
audioconvert convert ./data/librilight-tts/ ./data/librilight-tts --output-format .wav
|
||||
|
||||
# process data
|
||||
ulimit -Sn `ulimit -Hn` # ROCm is a bitch
|
||||
python3 -m vall_e.emb.g2p ./data/librilight-tts # phonemizes anything that might have been amiss in the phoneme transcription
|
||||
python3 -m vall_e.emb.qnt ./data/librilight-tts
|
|
@ -1,18 +0,0 @@
|
|||
import os
|
||||
import json
|
||||
|
||||
for f in os.listdir(f'./LibriTTS/'):
|
||||
if not os.path.isdir(f'./LibriTTS/{f}/'):
|
||||
continue
|
||||
for j in os.listdir(f'./LibriTTS/{f}/'):
|
||||
if not os.path.isdir(f'./LibriTTS/{f}/{j}'):
|
||||
continue
|
||||
for z in os.listdir(f'./LibriTTS/{f}/{j}'):
|
||||
if not os.path.isdir(f'./LibriTTS/{f}/{j}/{z}'):
|
||||
continue
|
||||
for i in os.listdir(f'./LibriTTS/{f}/{j}/{z}'):
|
||||
if i[-4:] != ".wav":
|
||||
continue
|
||||
|
||||
os.makedirs(f'./LibriTTS-Train/{j}/', exist_ok=True)
|
||||
os.rename(f'./LibriTTS/{f}/{j}/{z}/{i}', f'./LibriTTS-Train/{j}/{i}')
|
Loading…
Reference in New Issue
Block a user