diff --git a/README.md b/README.md
index fa7b963..8e9647c 100755
--- a/README.md
+++ b/README.md
@@ -2,6 +2,12 @@
 
 This is a simple ResNet based image classifier for """specific images""", using a similar training framework I use to train [VALL-E](https://git.ecker.tech/mrq/vall-e/).
 
+## Premise
+
+This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a problem I recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, fixed-length image classification model) to test the framework and my knowledge? Thus, this """ambiguous""" project was born.
+
+This is by no means state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
+
 ## Training
 
 1. Throw the images you want to train under `./data/images/`.
@@ -18,8 +24,32 @@
 
 Simply invoke the inferencer with the following command: `python3 -m image_classifier "./data/path-to-your-image.png" yaml="./data/config.yaml" --temp=1.0`
 
-## Caveats
+## Known Issues
 
-This was cobbled together in a night, partly to test how well my training framework fares when not married to my VALL-E implementation, and partly to solve a problem I have recently faced. Since I've been balls deep in learning the ins and outs of making VALL-E work, why not do the exact opposite (a tiny, image classification model of fixed lengths) to test the framework and my knowledge? Thus, this """ambiguous""" project is born.
+* Setting `dataset.workers` higher than 0 will cause issues when using the local engine backend. Use DeepSpeed.
+* The evaluation / validation routine doesn't quite work.
+* Using `float16` with the local engine backend will cause instability in the losses. Use DeepSpeed.
 
-This is by no ways state of the art, as it just leverages an existing ResNet arch provided by `torchvision`.
\ No newline at end of file
+## Strawmen
+
+>\> UGH... Why *another* training framework!!! Just subjugate [DLAS](https://git.ecker.tech/mrq/DL-Art-School) even more!!!
+
+I want my own code to own. The original VALL-E implementation had a rather nice and clean setup that *mostly* just made sense. DLAS was a nightmare to comb through for the gorillion models it attests to.
+
+>\> OK. But how do I use it for `[thing that isn't the specific usecase only I know/care about]`?
+
+Simply provide your own symmapping under `./image_classifier/data.py` and be sure to set the delimiter (where exactly is an exercise left to the reader; see the sketch below).
+
+Because this is for a ***very specific*** use-case. I don't really care right now to make this a *little* more generalized, despite most of the bits and bobs needed for it to generalize being there.
+
+>\> ur `[a slur]` for using a ResNet... why not use `[CRNN / some other meme arch]`??
+
+I don't care; I'd rather keep the copypasting from other people's code to a minimum. Lazily adapting my phoneme tokenizer from my VALL-E implementation into something practically fixed length by introducing start/stop tokens should be grounds for me to use a CRNN, or anything recurrent at the very least, but again, I don't care; it just works for my use case at the moment.
+
+>\> UGH!!! What are you talking about """specific images"""???
+
+[ひみつ](https://files.catbox.moe/csuh49.webm)
+
+>\> NOOOO!!!! WHY AREN'T YOU USING `[cuck license]`???
+
+:)
\ No newline at end of file
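As a reference for the "thing that isn't the specific usecase" answer above: a minimal sketch of what a symmap plus fixed-length start/stop tokenization *might* look like. Every name here (`symmap`, `DELIMITER`, `tokenize`) is illustrative and assumed, not what `./image_classifier/data.py` actually defines.

```python
# Hypothetical sketch only -- the real ./image_classifier/data.py ships its own symmap.

# Map each symbol the classifier should predict to a token id,
# reserving ids for padding and the start/stop markers.
symmap = {
    "<pad>": 0,
    "<s>": 1,
    "</s>": 2,
}
for i, c in enumerate("0123456789abcdefghijklmnopqrstuvwxyz"):
    symmap[c] = i + 3

DELIMITER = ""  # "" means one symbol per character; set to e.g. " " for space-delimited labels

def tokenize(label: str, length: int = 12) -> list[int]:
    """Wrap a label in start/stop tokens and pad/truncate it to a fixed length."""
    symbols = list(label) if DELIMITER == "" else label.split(DELIMITER)
    tokens = [symmap["<s>"]] + [symmap[s] for s in symbols] + [symmap["</s>"]]
    tokens += [symmap["<pad>"]] * (length - len(tokens))
    return tokens[:length]

print(tokenize("abc123"))  # [1, 13, 14, 15, 4, 5, 6, 2, 0, 0, 0, 0]
```

Swap the symbol set, the delimiter, and the fixed length for whatever your labels actually look like; the rest of the dataloader shouldn't need to care.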
diff --git a/image_classifier/train.py b/image_classifier/train.py
index 1a02779..075014c 100755
--- a/image_classifier/train.py
+++ b/image_classifier/train.py
@@ -2,7 +2,6 @@
 
 from .config import cfg
 from .data import create_train_val_dataloader
-from .emb import qnt
 
 from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
 from .utils.trainer import load_engines
diff --git a/scripts/download_libritts-small.sh b/scripts/download_libritts-small.sh
deleted file mode 100755
index 1f5ca9a..0000000
--- a/scripts/download_libritts-small.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-#!/bin/bash
-
-# do not invoke directly in scripts
-if [[ ${PWD##*/} == 'scripts' ]]; then
-    cd ..
-fi
-
-# download training data
-git clone https://huggingface.co/datasets/ecker/libritts-small ./data/libritts-small
\ No newline at end of file
diff --git a/scripts/plot.py b/scripts/plot.py
deleted file mode 100755
index e46ffc2..0000000
--- a/scripts/plot.py
+++ /dev/null
@@ -1,106 +0,0 @@
-#!/usr/bin/env python3
-
-import argparse
-import json
-import re
-from pathlib import Path
-
-import matplotlib.pyplot as plt
-import pandas as pd
-
-
-def plot(paths, args):
-    dfs = []
-
-    for path in paths:
-        with open(path, "r") as f:
-            text = f.read()
-
-        rows = []
-
-        pattern = r"(\{.+?\})"
-
-        for row in re.findall(pattern, text, re.DOTALL):
-            try:
-                row = json.loads(row)
-            except Exception as e:
-                continue
-
-            if "global_step" in row:
-                rows.append(row)
-
-        df = pd.DataFrame(rows)
-
-        if "name" in df:
-            df["name"] = df["name"].fillna("train")
-        else:
-            df["name"] = "train"
-
-        df["group"] = str(path.parents[args.group_level])
-        df["group"] = df["group"] + "/" + df["name"]
-
-        dfs.append(df)
-
-    df = pd.concat(dfs)
-
-    if args.max_y is not None:
-        df = df[df["global_step"] < args.max_x]
-
-    for gtag, gdf in sorted(
-        df.groupby("group"),
-        key=lambda p: (p[0].split("/")[-1], p[0]),
-    ):
-        for y in args.ys:
-            gdf = gdf.sort_values("global_step")
-
-            if gdf[y].isna().all():
-                continue
-
-            if args.max_y is not None:
-                gdf = gdf[gdf[y] < args.max_y]
-
-            gdf[y] = gdf[y].ewm(10).mean()
-
-            gdf.plot(
-                x="global_step",
-                y=y,
-                label=f"{gtag}/{y}",
-                ax=plt.gca(),
-                marker="x" if len(gdf) < 100 else None,
-                alpha=0.7,
-            )
-
-    plt.gca().legend(
-        loc="center left",
-        fancybox=True,
-        shadow=True,
-        bbox_to_anchor=(1.04, 0.5),
-    )
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("ys", nargs="+")
-    parser.add_argument("--log-dir", default="logs", type=Path)
-    parser.add_argument("--out-dir", default="logs", type=Path)
-    parser.add_argument("--filename", default="log.txt")
-    parser.add_argument("--max-x", type=float, default=float("inf"))
-    parser.add_argument("--max-y", type=float, default=float("inf"))
-    parser.add_argument("--group-level", default=1)
-    parser.add_argument("--filter", default=None)
-    args = parser.parse_args()
-
-    paths = args.log_dir.rglob(f"**/{args.filename}")
-
-    if args.filter:
-        paths = filter(lambda p: re.match(".*" + args.filter + ".*", str(p)), paths)
-
-    plot(paths, args)
-
-    name = "-".join(args.ys)
-    out_path = (args.out_dir / name).with_suffix(".png")
-    plt.savefig(out_path, bbox_inches="tight")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/scripts/prepare_libri.py b/scripts/prepare_libri.py
deleted file mode 100755
index 0843308..0000000
--- a/scripts/prepare_libri.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import os
-import json
-
-for f in os.listdir(f'./data/librispeech_finetuning/1h/'):
-    for j in os.listdir(f'./data/librispeech_finetuning/1h/{f}/clean'):
-        for z in os.listdir(f'./data/librispeech_finetuning/1h/{f}/clean/{j}'):
-            for i in os.listdir(f'./data/librispeech_finetuning/1h/{f}/clean/{j}/{z}'):
-                os.rename(f'./data/librispeech_finetuning/1h/{f}/clean/{j}/{z}/{i}', f'./data/librilight-tts/{i}')
-
-for j in os.listdir('./data/librispeech_finetuning/9h/clean'):
-    for z in os.listdir(f'./data/librispeech_finetuning/9h/clean/{j}'):
-        for i in os.listdir(f'./data/librispeech_finetuning/9h/clean/{j}/{z}'):
-            os.rename(f'./data/librispeech_finetuning/9h/clean/{j}/{z}/{i}', f'./data/librilight-tts/{i}')
-
-lst = []
-for i in os.listdir('./data/librilight-tts/'):
-    try:
-        if 'trans' not in i:
-            continue
-        with open(f'./data/librilight-tts/{i}') as f:
-            for row in f:
-                z = row.split('-')
-                name = z[0]+'-'+z[1]+ '-' + z[2].split(' ')[0]
-                text = " ".join(z[2].split(' ')[1:])
-                lst.append([name, text])
-    except Exception as e:
-        pass
-
-for i in lst:
-    try:
-        with open(f'./data/librilight-tts/{i[0]}.txt', 'x') as file:
-            file.write(i[1])
-    except:
-        with open(f'./data/librilight-tts/{i[0]}.txt', 'w+') as file:
-            file.write(i[1])
-
-phoneme_map = {}
-phoneme_transcript = {}
-
-with open('./data/librispeech_finetuning/phones/phones_mapping.json', 'r') as f:
-    phoneme_map_rev = json.load(f)
-    for k, v in phoneme_map_rev.items():
-        phoneme_map[f'{v}'] = k
-
-with open('./data/librispeech_finetuning/phones/10h_phones.txt', 'r') as f:
-    lines = f.readlines()
-    for line in lines:
-        split = line.strip().split(" ")
-        key = split[0]
-        tokens = split[1:]
-
-        phonemes = []
-        for token in tokens:
-            phoneme = phoneme_map[f'{token}']
-            phonemes.append( phoneme )
-
-        phoneme_transcript[key] = " ".join(phonemes)
-
-for filename in sorted(os.listdir('./data/librilight-tts')):
-    split = filename.split('.')
-
-    key = split[0]
-    extension = split[1] # covers double duty of culling .normalized.txt and .phn.txt
-
-    if extension != 'txt':
-        continue
-
-    os.rename(f'./data/librilight-tts/{filename}', f'./data/librilight-tts/{key}.normalized.txt')
-
-    if key in phoneme_transcript:
-        with open(f'./data/librilight-tts/{key}.phn.txt', 'w', encoding='utf-8') as f:
-            f.write(phoneme_transcript[key])
\ No newline at end of file
diff --git a/scripts/prepare_libri.sh b/scripts/prepare_libri.sh
deleted file mode 100755
index d044278..0000000
--- a/scripts/prepare_libri.sh
+++ /dev/null
@@ -1,27 +0,0 @@
-#!/bin/bash
-
-# do not invoke directly in scripts
-if [[ ${PWD##*/} == 'scripts' ]]; then
-    cd ..
-fi
-
-# download training data
-cd data
-mkdir librilight-tts
-if [ ! -e ./librispeech_finetuning.tgz ]; then
-    wget https://dl.fbaipublicfiles.com/librilight/data/librispeech_finetuning.tgz
-fi
-tar -xzf librispeech_finetuning.tgz
-cd ..
-
-# clean it up
-python3 ./scripts/prepare_libri.py
-
-# convert to wav
-pip3 install AudioConverter
-audioconvert convert ./data/librilight-tts/ ./data/librilight-tts --output-format .wav
-
-# process data
-ulimit -Sn `ulimit -Hn` # ROCm is a bitch
-python3 -m vall_e.emb.g2p ./data/librilight-tts # phonemizes anything that might have been amiss in the phoneme transcription
-python3 -m vall_e.emb.qnt ./data/librilight-tts
\ No newline at end of file
diff --git a/scripts/prepare_libritts.py b/scripts/prepare_libritts.py
deleted file mode 100755
index c662fe3..0000000
--- a/scripts/prepare_libritts.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import os
-import json
-
-for f in os.listdir(f'./LibriTTS/'):
-    if not os.path.isdir(f'./LibriTTS/{f}/'):
-        continue
-    for j in os.listdir(f'./LibriTTS/{f}/'):
-        if not os.path.isdir(f'./LibriTTS/{f}/{j}'):
-            continue
-        for z in os.listdir(f'./LibriTTS/{f}/{j}'):
-            if not os.path.isdir(f'./LibriTTS/{f}/{j}/{z}'):
-                continue
-            for i in os.listdir(f'./LibriTTS/{f}/{j}/{z}'):
-                if i[-4:] != ".wav":
-                    continue
-
-                os.makedirs(f'./LibriTTS-Train/{j}/', exist_ok=True)
-                os.rename(f'./LibriTTS/{f}/{j}/{z}/{i}', f'./LibriTTS-Train/{j}/{i}')
\ No newline at end of file