tweaks and fixes for LoRA stuff

This commit is contained in:
mrq 2024-09-08 18:05:21 -05:00
parent 54203c059d
commit 31e8b7edb8
5 changed files with 73 additions and 5 deletions

View File

@@ -264,7 +264,8 @@ So far, this only allows you to load a different model without needing to restar
* [x] train and release a serviceable model for finetuning against.
- LoRA tests show it's already very capable, although there's room for higher quality (possibly through better NAR training).
* [ ] train and release a ***good*** zero-shot model.
- this should, hopefully, simply require another epoch or two for `ar+nar-llama-8`, as the foundation seems rather robust now.
- ~~this should, hopefully, simply require another epoch or two for `ar+nar-llama-8`, as the foundation seems rather robust now.~~
- this might need a better training paradigm, one that provides input prompts similar enough to a given output response.
* [ ] well-integrated training through the Web UI (without the kludge from ai-voice-cloning)
* [x] ~~explore alternative setups, like a NAR-only model~~
- the current experiment of an AR length-predictor + NAR for the rest seems to fall apart...
@@ -272,7 +273,9 @@ So far, this only allows you to load a different model without needing to restar
- the AR doesn't *need* exotic sampling techniques, as they're bandaids for a bad AR.
- the NAR benefits from greedy sampling, and anything else just harms output quality.
* [ ] clean up the README, and document, document, document onto the wiki.
* [ ] extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
* [x] extend to multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)).
- [ ] extend the reference model to include at least one other model
* [ ] extend to additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
- `stt` (Speech-to-Text) seems to be working fine for the most part.
- other tasks seem to require a ton of VRAM...
* [ ] extend using [VALL-E 2](https://arxiv.org/pdf/2406.05370)'s features (grouped code modeling + repetition aware sampling)
@@ -284,6 +287,16 @@ So far, this only allows you to load a different model without needing to restar
- espeak is nice, but I can only really put my full trust in it for phonemizing English (a sketch of espeak-backed phonemizing follows below).
- a small model trained to handle converting text to phonemes might work, but it has its own problems (another model to carry around, only as accurate as the dataset it was trained against, requires training for each language, etc.).
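A minimal sketch of espeak-backed phonemizing through the `phonemizer` package, assuming it and `espeak-ng` are installed; this is an illustration, not necessarily how this repo wires it up:

```python
# minimal sketch, not project code: phonemize English text via the espeak backend
from phonemizer import phonemize

ipa = phonemize( "Hello world.", language="en-us", backend="espeak", strip=True )
print( ipa ) # roughly: həloʊ wɜːld
```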
## Caveats
Despite how lightweight it is in comparison to other TTSes I've meddled with, there are still some caveats, be it with the implementation or the model weights:
* the audio embeddings have some quirks from keeping the AR's RVQ level 0 separate from the NAR's RVQ level 0 (sharing them caused some problems in testing)
* the trainer / dataloader assumes there is zero variation between a speaker's utterances, and thus it only extracts the basics of a speaker's features, rather than deeper features (like prosody, tone, etc.), when performing inferences.
- however, working around this would require training under `tts-c` (VALL-E continuous) mode, or modifying an input prompt to the point where its quantized representation differs enough from the output response it derives from.
* the trainer's default RVQ level distribution prioritizes lower RVQ levels over higher ones, as the lower levels contribute more to the final waveform; however, this leaves some minor artifacting that arises in the higher RVQ levels due to inaccuracy issues.
* speakers that aren't similar to an audiobook narrator voice have similarity issues, as the majority of training used `path`-based dataloader sampling instead of `speaker`-based (or `group`-based) dataloader sampling (see the sketch below).
- although LoRAs help a ton with fixing results for a single voice.
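To illustrate the difference between the two sampling strategies, a minimal sketch (the toy dataset and function names here are illustrative, not the trainer's actual dataloader API):

```python
import random

# toy dataset for illustration: speaker => list of utterance paths
dataset = {
	"narrator": ["n_0.wav", "n_1.wav", "n_2.wav", "n_3.wav"],
	"rare_voice": ["r_0.wav"],
}

def sample_by_path( dataset ):
	# every utterance is equally likely, so prolific speakers dominate training
	paths = [ path for utterances in dataset.values() for path in utterances ]
	return random.choice( paths )

def sample_by_speaker( dataset ):
	# every speaker is equally likely, regardless of how many utterances they have
	speaker = random.choice( list( dataset.keys() ) )
	return random.choice( dataset[speaker] )
```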
## Notices and Citations
Unless otherwise credited/noted in this README or within the designated Python file, this repository is [licensed](LICENSE) under AGPLv3.

View File

@@ -734,7 +734,7 @@ class Dataset(_Dataset):
	@cached_property
	def sampler_state_dict_path(self):
		return cfg.ckpt_dir / cfg.model.full_name / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
		return cfg.ckpt_dir / (cfg.lora.full_name if cfg.lora is not None else cfg.model.full_name) / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"

	def get_speaker(self, path):
		if isinstance(path, str):
@@ -769,6 +769,9 @@ class Dataset(_Dataset):
		if path is None:
			path = self.sampler_state_dict_path

		if not path.parent.exists():
			path.parent.mkdir(parents=True, exist_ok=True)

		if self.sampler_type == "path":
			state_dict = self.sampler.get_state()
		else:
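For reference, the effect of the sampler path change above, sketched with stand-in values (the directory and names below are hypothetical):

```python
from pathlib import Path

ckpt_dir = Path("./training/ckpt") # stand-in for cfg.ckpt_dir
model_full_name = "ar+nar-llama-8" # stand-in for cfg.model.full_name
lora_full_name = "lora-glados"     # stand-in for cfg.lora.full_name; None when no LoRA is loaded

# the sampler state now follows whichever checkpoint directory is active:
#   no LoRA:   ./training/ckpt/ar+nar-llama-8/sampler.path.rank0.pt
#   with LoRA: ./training/ckpt/lora-glados/sampler.path.rank0.pt
name = lora_full_name if lora_full_name is not None else model_full_name
print( ckpt_dir / name / "sampler.path.rank0.pt" )
```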

View File

@@ -382,6 +382,9 @@ class Engines(dict[str, Engine]):
				continue

			save_dir = cfg.ckpt_dir / name
			if cfg.lora is not None:
				save_dir = cfg.ckpt_dir / cfg.lora.full_name

			try:
				engine.save_checkpoint(save_dir, tag=tag)
			except Exception as e:

View File

@@ -65,7 +65,7 @@ def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
	if dtype is None:
		dtype = cfg.inference.dtype

	format = save_path.stem[1:]
	format = save_path.suffix[1:]

	lora = state_dict["lora"] if "lora" in state_dict else None
	# should always be included, but just in case
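For reference, the `pathlib` behavior the fix above hinges on (the filename is a hypothetical export path):

```python
from pathlib import Path

p = Path("fp32.sft")  # hypothetical export path
print( p.stem )       # "fp32"  -> stem[1:] yields "p32", not a weights format
print( p.suffix )     # ".sft"  -> suffix[1:] yields "sft", the intended format
```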
@@ -136,7 +136,7 @@ def main():
	parser.add_argument("--moe-ify", action='store_true', default=None) # splits classifier heads
	parser.add_argument("--experts", type=int, default=8) # number of experts to split the classifier heads into
	parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
	parser.add_argument("--format", type=str, default="pth") # set target format to export weights under
	parser.add_argument("--format", type=str, default=cfg.weights_format) # set target format to export weights under
	args, unknown = parser.parse_known_args()

	if args.format.lower() not in ["sft", "safetensors", "pt", "pth"]:

View File

@@ -1,7 +1,56 @@
import logging
import requests

from tqdm import tqdm
from pathlib import Path

_logger = logging.getLogger(__name__)

# to-do: implement automatically downloading model
DEFAULT_MODEL_PATH = Path(__file__).parent.parent.parent / 'data/models'
DEFAULT_MODEL_URLS = {
	'ar+nar-tts+stt-llama-8/fp32.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-tts%2Bstt-llama-8/fp32.sft',
}

# kludge, probably better to use HF's model downloader function
# to-do: write to a temp file then copy so downloads can be interrupted
def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
	scale = 1
	if unit == "KiB":
		scale = (1024)
	elif unit == "MiB":
		scale = (1024 * 1024)
	elif unit == "GiB":
		scale = (1024 * 1024 * 1024)
	elif unit == "KB":
		scale = (1000)
	elif unit == "MB":
		scale = (1000 * 1000)
	elif unit == "GB":
		scale = (1000 * 1000 * 1000)

	# keys in DEFAULT_MODEL_URLS include the parent directory, so match against it
	name = f'{save_path.parent.name}/{save_path.name}'
	url = DEFAULT_MODEL_URLS[name] if name in DEFAULT_MODEL_URLS else None
	if url is None:
		raise Exception(f'Model requested for download but not defined: {name}')

	if not save_path.parent.exists():
		save_path.parent.mkdir(parents=True, exist_ok=True)

	r = requests.get(url, stream=True)
	content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length']) // scale

	with open(save_path, 'wb') as f:
		bar = tqdm( unit=unit, total=content_length )
		for chunk in r.iter_content(chunk_size=chunkSize):
			if not chunk:
				continue

			bar.update( len(chunk) / scale )
			f.write(chunk)

		bar.close()
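A hypothetical usage sketch for the downloader above, fetching the default weights when they're missing:

```python
# hypothetical usage: download the default model weights if they aren't already present
save_path = DEFAULT_MODEL_PATH / 'ar+nar-tts+stt-llama-8/fp32.sft'
if not save_path.exists():
	download_model( save_path )
```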
def get_model(config, training=True, **model_kwargs):
	name = config.name