tweaks and fixes for lora stuffs
This commit is contained in:
parent
54203c059d
commit
31e8b7edb8
17
README.md
17
README.md
|
@ -264,7 +264,8 @@ So far, this only allows you to load a different model without needing to restar
|
|||
* [x] train and release a serviceable model for finetuning against.
|
||||
- LoRA tests shows it's already very capable, although there's room for higher quality (possibly in better NAR training).
|
||||
* [ ] train and release a ***good*** zero-shot model.
|
||||
- this should, hopefully, just simply requires another epoch or two for `ar+nar-llama-8`, as the foundation seems rather robust now.
|
||||
- ~~this should, hopefully, just simply requires another epoch or two for `ar+nar-llama-8`, as the foundation seems rather robust now.~~
|
||||
- this might need a better training paradigm with providing similar enough input prompts to a given output response.
|
||||
* [ ] well-integrated training through the Web UI (without the kludge from ai-voice-cloning)
|
||||
* [x] ~~explore alternative setups, like a NAR-only model~~
|
||||
- the current experiment of an AR length-predictor + NAR for the rest seems to fall apart...
|
||||
|
@ -272,7 +273,9 @@ So far, this only allows you to load a different model without needing to restar
|
|||
- the AR doesn't *need* exotic sampling techniques, as they're bandaids for a bad AR.
|
||||
- the NAR benefits from greedy sampling, and anything else just harms output quality.
|
||||
* [ ] clean up the README, and document, document, document onto the wiki.
|
||||
* [ ] extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ addditional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
|
||||
* [x] extend to multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)).
|
||||
- [ ] extend the reference model to include at least one other model
|
||||
* [ ] extend to addditional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
|
||||
- `stt` (Speech-to-Text) seems to be working fine for the most part.
|
||||
- other tasks seem to require a ton of VRAM......
|
||||
* [ ] extend using [VALL-E 2](https://arxiv.org/pdf/2406.05370)'s features (grouped code modeling + repetition aware sampling)
|
||||
|
@ -284,6 +287,16 @@ So far, this only allows you to load a different model without needing to restar
|
|||
- espeak is nice, but I can only really put my whole trust with phonemizing English.
|
||||
- a small model trained to handle converting text to phonemes might work, but has it's own problems (another model to carry around, as accurate as the dataset it was trained against, requires training for each language... etc).
|
||||
|
||||
## Caveats
|
||||
|
||||
Despite how lightweight it is in comparison to other TTS's I've meddled with, there are still some caveats, be it with the implementation or model weights:
|
||||
* the audio embeddings have some quirks to having the AR's RVQ level 0 separate from the NAR's RVQ level 0 (sharing them caused some problems in testing)
|
||||
* the trainer / dataloader assumes there are zero variations between a speaker's utterances, and thus it can extract the basics of a speaker's features rather than deeper features (like prosidy, tone, etc.) when performing inferences.
|
||||
+ however, trying to work around this would require training under `tts-c` (VALL-E continuous) mode or modifying an input prompt enough to where its quantized representation differs enough from the output response the prompt derives from.
|
||||
* the trainer's default RVQ level distribution prioritizes lower RVQ levels over higher RVQ levels, as the lower levels contribute to the final waveform more; however, this leaves some minor artifacting that rises in the higher RVQ levels due to inaccuracy issues.
|
||||
* speakers that aren't similar to an audiobook narrator voice has similarity issues due to the majority of training used `path`-based dataloader sampling instead of `speaker`-based (or `group`-based) dataloader sampling.
|
||||
+ although LoRAs help a ton for fixing results for a single voice.
|
||||
|
||||
## Notices and Citations
|
||||
|
||||
Unless otherwise credited/noted in this README or within the designated Python file, this repository is [licensed](LICENSE) under AGPLv3.
|
||||
|
|
|
@ -734,7 +734,7 @@ class Dataset(_Dataset):
|
|||
|
||||
@cached_property
|
||||
def sampler_state_dict_path(self):
|
||||
return cfg.ckpt_dir / cfg.model.full_name / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
|
||||
return cfg.ckpt_dir / (cfg.lora.full_name if cfg.lora is not None else cfg.model.full_name) / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
|
||||
|
||||
def get_speaker(self, path):
|
||||
if isinstance(path, str):
|
||||
|
@ -769,6 +769,9 @@ class Dataset(_Dataset):
|
|||
if path is None:
|
||||
path = self.sampler_state_dict_path
|
||||
|
||||
if not path.parent.exists():
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self.sampler_type == "path":
|
||||
state_dict = self.sampler.get_state()
|
||||
else:
|
||||
|
|
|
@ -382,6 +382,9 @@ class Engines(dict[str, Engine]):
|
|||
continue
|
||||
|
||||
save_dir = cfg.ckpt_dir / name
|
||||
if cfg.lora is not None:
|
||||
save_dir = cfg.ckpt_dir / cfg.lora.full_name
|
||||
|
||||
try:
|
||||
engine.save_checkpoint(save_dir, tag=tag)
|
||||
except Exception as e:
|
||||
|
|
|
@ -65,7 +65,7 @@ def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
|
|||
if dtype is None:
|
||||
dtype = cfg.inference.dtype
|
||||
|
||||
format = save_path.stem[1:]
|
||||
format = save_path.suffix[1:]
|
||||
|
||||
lora = state_dict["lora"] if "lora" in state_dict else None
|
||||
# should always be included, but just in case
|
||||
|
@ -136,7 +136,7 @@ def main():
|
|||
parser.add_argument("--moe-ify", action='store_true', default=None) # splits classifier heads
|
||||
parser.add_argument("--experts", type=int, default=8) # set target dtype to export to
|
||||
parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
|
||||
parser.add_argument("--format", type=str, default="pth") # set target format to export weights under
|
||||
parser.add_argument("--format", type=str, default=cfg.weights_format) # set target format to export weights under
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
if args.format.lower() not in ["sft", "safetensors", "pt", "pth"]:
|
||||
|
|
|
@ -1,7 +1,56 @@
|
|||
import logging
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
# to-do: implement automatically downloading model
|
||||
DEFAULT_MODEL_PATH = Path(__file__).parent.parent.parent / 'data/models'
|
||||
DEFAULT_MODEL_URLS = {
|
||||
'ar+nar-tts+stt-llama-8/fp32.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-tts%2Bstt-llama-8/fp32.sft',
|
||||
}
|
||||
|
||||
# kludge, probably better to use HF's model downloader function
|
||||
# to-do: write to a temp file then copy so downloads can be interrupted
|
||||
def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
|
||||
scale = 1
|
||||
if unit == "KiB":
|
||||
scale = (1024)
|
||||
elif unit == "MiB":
|
||||
scale = (1024 * 1024)
|
||||
elif unit == "MiB":
|
||||
scale = (1024 * 1024 * 1024)
|
||||
elif unit == "KB":
|
||||
scale = (1000)
|
||||
elif unit == "MB":
|
||||
scale = (1000 * 1000)
|
||||
elif unit == "MB":
|
||||
scale = (1000 * 1000 * 1000)
|
||||
|
||||
name = save_path.name
|
||||
url = DEFAULT_MODEL_URLS[name] if name in DEFAULT_MODEL_URLS else None
|
||||
if url is None:
|
||||
raise Exception(f'Model requested for download but not defined: {name}')
|
||||
|
||||
if not save_path.parent.exists():
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
r = requests.get(url, stream=True)
|
||||
content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length']) // scale
|
||||
|
||||
with open(save_path, 'wb') as f:
|
||||
bar = tqdm( unit=unit, total=content_length )
|
||||
for chunk in r.iter_content(chunk_size=chunkSize):
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
bar.update( len(chunk) / scale )
|
||||
f.write(chunk)
|
||||
bar.close()
|
||||
|
||||
|
||||
def get_model(config, training=True, **model_kwargs):
|
||||
name = config.name
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user