culled artifacts left over from the valle trainer

This commit is contained in:
mrq 2023-08-05 04:03:59 +00:00
parent f3f29ddf06
commit b15c0fc33b
7 changed files with 33 additions and 236 deletions

View File

@ -2,6 +2,12 @@
This is a simple ResNet based image classifier for """specific images""", using a similar training framework I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/).
## Premise
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
This is by no ways state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
## Training
1. Throw the images you want to train under `./data/images/`.
@ -18,8 +24,32 @@ This is a simple ResNet based image classifier for """specific images""", using
Simply invoke the inferencer with the following command: `python3 -m image_classifier "./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
## Caveats
## Known Issues
This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
* Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
* The evaluation / validation routine doesn't quite work.
* Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
This is by no ways state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
## Strawmen
>\> UGH... Why *another* training framework!!! Just subjugate [DLAS](https://git.ecker.tech/mrq/DL-Art-School) even more!!!
I want my own code to own. The original VALL-E implementation had a rather nice and clean setup that *mostly* just made sense. DLAS was a nightmare to comb through for the gorillion amounts of models it attests.
>\> OK. But how do I use it for `[thing that isn't the specific usecase only I know/care about]`
Simply provide your own symmapping under `./image_classifier/data.py`, and, be sure to set the delimiter (where exactly is an exercise left to the reader).
Because this is for a ***very specific*** use-case. I don't really care right now to make this a *little* more generalized, despite most of the bits and bobs for it to generalize being there.
>\> ur `[a slur]` for using a ResNet... why not use `[CRNN / some other meme arch]`??
I don't care, I'd rather keep the copypasting from other people's code to a minimum. Lazily adapting my phoneme tokenizer from my VALL-E implementation into something practically fixed length by introducing start/stop tokens should be grounds for me to use a CRNN, or anything recurrent at the very least, but again, I don't care, it just works for my use case at the moment.
>\> UGH!!! What are you talking about """specific images"""???
[ひみつ](https://files.catbox.moe/csuh49.webm)
>\> NOOOO!!!! WHY AREN'T YOU USING `[cuck license]`???
:)

View File

@ -2,7 +2,6 @@
from .config import cfg
from .data import create_train_val_dataloader
from .emb import qnt
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
from .utils.trainer import load_engines

View File

@ -1,9 +0,0 @@
#!/bin/bash
# do not invoke directly in scripts
if [[ ${PWD##*/} == 'scripts' ]]; then
cd ..
fi
# download training data
git clone https://huggingface.co/datasets/ecker/libritts-small ./data/libritts-small

View File

@ -1,106 +0,0 @@
#!/usr/bin/env python3
import argparse
import json
import re
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
def plot(paths, args):
dfs = []
for path in paths:
with open(path, "r") as f:
text = f.read()
rows = []
pattern = r"(\{.+?\})"
for row in re.findall(pattern, text, re.DOTALL):
try:
row = json.loads(row)
except Exception as e:
continue
if "global_step" in row:
rows.append(row)
df = pd.DataFrame(rows)
if "name" in df:
df["name"] = df["name"].fillna("train")
else:
df["name"] = "train"
df["group"] = str(path.parents[args.group_level])
df["group"] = df["group"] + "/" + df["name"]
dfs.append(df)
df = pd.concat(dfs)
if args.max_y is not None:
df = df[df["global_step"] < args.max_x]
for gtag, gdf in sorted(
df.groupby("group"),
key=lambda p: (p[0].split("/")[-1], p[0]),
):
for y in args.ys:
gdf = gdf.sort_values("global_step")
if gdf[y].isna().all():
continue
if args.max_y is not None:
gdf = gdf[gdf[y] < args.max_y]
gdf[y] = gdf[y].ewm(10).mean()
gdf.plot(
x="global_step",
y=y,
label=f"{gtag}/{y}",
ax=plt.gca(),
marker="x" if len(gdf) < 100 else None,
alpha=0.7,
)
plt.gca().legend(
loc="center left",
fancybox=True,
shadow=True,
bbox_to_anchor=(1.04, 0.5),
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("ys", nargs="+")
parser.add_argument("--log-dir", default="logs", type=Path)
parser.add_argument("--out-dir", default="logs", type=Path)
parser.add_argument("--filename", default="log.txt")
parser.add_argument("--max-x", type=float, default=float("inf"))
parser.add_argument("--max-y", type=float, default=float("inf"))
parser.add_argument("--group-level", default=1)
parser.add_argument("--filter", default=None)
args = parser.parse_args()
paths = args.log_dir.rglob(f"**/{args.filename}")
if args.filter:
paths = filter(lambda p: re.match(".*" + args.filter + ".*", str(p)), paths)
plot(paths, args)
name = "-".join(args.ys)
out_path = (args.out_dir / name).with_suffix(".png")
plt.savefig(out_path, bbox_inches="tight")
if __name__ == "__main__":
main()

View File

@ -1,72 +0,0 @@
import os
import json
for f in os.listdir(f'./data/librispeech_finetuning/1h/'):
for j in os.listdir(f'./data/librispeech_finetuning/1h/{f}/clean'):
for z in os.listdir(f'./data/librispeech_finetuning/1h/{f}/clean/{j}'):
for i in os.listdir(f'./data/librispeech_finetuning/1h/{f}/clean/{j}/{z}'):
os.rename(f'./data/librispeech_finetuning/1h/{f}/clean/{j}/{z}/{i}', f'./data/librilight-tts/{i}')
for j in os.listdir('./data/librispeech_finetuning/9h/clean'):
for z in os.listdir(f'./data/librispeech_finetuning/9h/clean/{j}'):
for i in os.listdir(f'./data/librispeech_finetuning/9h/clean/{j}/{z}'):
os.rename(f'./data/librispeech_finetuning/9h/clean/{j}/{z}/{i}', f'./data/librilight-tts/{i}')
lst = []
for i in os.listdir('./data/librilight-tts/'):
try:
if 'trans' not in i:
continue
with open(f'./data/librilight-tts/{i}') as f:
for row in f:
z = row.split('-')
name = z[0]+'-'+z[1]+ '-' + z[2].split(' ')[0]
text = " ".join(z[2].split(' ')[1:])
lst.append([name, text])
except Exception as e:
pass
for i in lst:
try:
with open(f'./data/librilight-tts/{i[0]}.txt', 'x') as file:
file.write(i[1])
except:
with open(f'./data/librilight-tts/{i[0]}.txt', 'w+') as file:
file.write(i[1])
phoneme_map = {}
phoneme_transcript = {}
with open('./data/librispeech_finetuning/phones/phones_mapping.json', 'r') as f:
phoneme_map_rev = json.load(f)
for k, v in phoneme_map_rev.items():
phoneme_map[f'{v}'] = k
with open('./data/librispeech_finetuning/phones/10h_phones.txt', 'r') as f:
lines = f.readlines()
for line in lines:
split = line.strip().split(" ")
key = split[0]
tokens = split[1:]
phonemes = []
for token in tokens:
phoneme = phoneme_map[f'{token}']
phonemes.append( phoneme )
phoneme_transcript[key] = " ".join(phonemes)
for filename in sorted(os.listdir('./data/librilight-tts')):
split = filename.split('.')
key = split[0]
extension = split[1] # covers double duty of culling .normalized.txt and .phn.txt
if extension != 'txt':
continue
os.rename(f'./data/librilight-tts/{filename}', f'./data/librilight-tts/{key}.normalized.txt')
if key in phoneme_transcript:
with open(f'./data/librilight-tts/{key}.phn.txt', 'w', encoding='utf-8') as f:
f.write(phoneme_transcript[key])

View File

@ -1,27 +0,0 @@
#!/bin/bash
# do not invoke directly in scripts
if [[ ${PWD##*/} == 'scripts' ]]; then
cd ..
fi
# download training data
cd data
mkdir librilight-tts
if [ ! -e ./librispeech_finetuning.tgz ]; then
wget https://dl.fbaipublicfiles.com/librilight/data/librispeech_finetuning.tgz
fi
tar -xzf librispeech_finetuning.tgz
cd ..
# clean it up
python3 ./scripts/prepare_libri.py
# convert to wav
pip3 install AudioConverter
audioconvert convert ./data/librilight-tts/ ./data/librilight-tts --output-format .wav
# process data
ulimit -Sn `ulimit -Hn` # ROCm is a bitch
python3 -m vall_e.emb.g2p ./data/librilight-tts # phonemizes anything that might have been amiss in the phoneme transcription
python3 -m vall_e.emb.qnt ./data/librilight-tts

View File

@ -1,18 +0,0 @@
import os
import json
for f in os.listdir(f'./LibriTTS/'):
if not os.path.isdir(f'./LibriTTS/{f}/'):
continue
for j in os.listdir(f'./LibriTTS/{f}/'):
if not os.path.isdir(f'./LibriTTS/{f}/{j}'):
continue
for z in os.listdir(f'./LibriTTS/{f}/{j}'):
if not os.path.isdir(f'./LibriTTS/{f}/{j}/{z}'):
continue
for i in os.listdir(f'./LibriTTS/{f}/{j}/{z}'):
if i[-4:] != ".wav":
continue
os.makedirs(f'./LibriTTS-Train/{j}/', exist_ok=True)
os.rename(f'./LibriTTS/{f}/{j}/{z}/{i}', f'./LibriTTS-Train/{j}/{i}')