From 8fffb94964cf5aa52debd6b337ab54f32c5ae182 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 25 Jun 2024 13:41:29 -0500 Subject: [PATCH] backport fix from tortoise_tts with local trainer + loading state when training lora --- vall_e/__main__.py | 5 +- vall_e/config.py | 11 ++++ vall_e/data.py | 107 +++++++++++++++++++++++++++++++++++++ vall_e/engines/__init__.py | 7 ++- vall_e/inference.py | 13 ++++- vall_e/utils/__init__.py | 1 + vall_e/utils/trainer.py | 4 -- vall_e/utils/utils.py | 11 ++++ 8 files changed, 151 insertions(+), 8 deletions(-) diff --git a/vall_e/__main__.py b/vall_e/__main__.py index 5054411..8e6a8a6 100755 --- a/vall_e/__main__.py +++ b/vall_e/__main__.py @@ -34,6 +34,8 @@ def main(): parser.add_argument("--mirostat-tau", type=float, default=0) parser.add_argument("--mirostat-eta", type=float, default=0) + + parser.add_argument("--seed", type=int, default=None) parser.add_argument("--device", type=str, default=None) parser.add_argument("--amp", action="store_true") @@ -55,7 +57,8 @@ def main(): repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty, beam_width=args.beam_width, - mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta + mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta, + seed=args.seed, ) if __name__ == "__main__": diff --git a/vall_e/config.py b/vall_e/config.py index 016029a..8d11bd9 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -8,8 +8,10 @@ import sys import time import argparse import yaml +import random import torch +import numpy as np from dataclasses import asdict, dataclass, field @@ -18,6 +20,15 @@ from pathlib import Path from .utils.distributed import world_size + +def set_seed(seed=None): + if not seed: + seed = time.time() + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + @dataclass() class BaseConfig: yaml_path: str | None = None diff --git a/vall_e/data.py b/vall_e/data.py index 4dbe6a7..1800180 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -1278,6 +1278,111 @@ def create_dataset_hdf5( skip_existing=True ): hf.create_dataset('symmap', data=json.dumps(symmap)) hf.close() +def transcribe_dataset(): + import os + import json + import torch + import torchaudio + import whisperx + + from tqdm.auto import tqdm + from pathlib import Path + + # to-do: use argparser + batch_size = 16 + device = "cuda" + dtype = "float16" + model_name = "large-v3" + + input_audio = "voices" + output_dataset = "training/metadata" + + skip_existing = True + diarize = False + + # + model = whisperx.load_model(model_name, device, compute_type=dtype) + align_model, align_model_metadata, align_model_language = (None, None, None) + if diarize: + diarize_model = whisperx.DiarizationPipeline(device=device) + else: + diarize_model = None + + def pad(num, zeroes): + return str(num).zfill(zeroes+1) + + for dataset_name in os.listdir(f'./{input_audio}/'): + if not os.path.isdir(f'./{input_audio}/{dataset_name}/'): + continue + + for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"): + if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'): + continue + + outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json') + + if outpath.exists(): + metadata = json.loads(open(outpath, 'r', encoding='utf-8').read()) + else: + os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True) + metadata = {} + + for filename in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/{speaker_id}/'), desc=f"Processing speaker: {speaker_id}"): + + if skip_existing and filename in metadata: + continue + + if ".json" in filename: + continue + + inpath = f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}' + + if os.path.isdir(inpath): + continue + + metadata[filename] = { + "segments": [], + "language": "", + "text": "", + "start": 0, + "end": 0, + } + + audio = whisperx.load_audio(inpath) + result = model.transcribe(audio, batch_size=batch_size) + language = result["language"] + + if language[:2] not in ["ja"]: + language = "en" + + if align_model_language != language: + tqdm.write(f'Loading language: {language}') + align_model, align_model_metadata = whisperx.load_align_model(language_code=language, device=device) + align_model_language = language + + result = whisperx.align(result["segments"], align_model, align_model_metadata, audio, device, return_char_alignments=False) + + metadata[filename]["segments"] = result["segments"] + metadata[filename]["language"] = language + + if diarize_model is not None: + diarize_segments = diarize_model(audio) + result = whisperx.assign_word_speakers(diarize_segments, result) + + text = [] + start = 0 + end = 0 + for segment in result["segments"]: + text.append( segment["text"] ) + start = min( start, segment["start"] ) + end = max( end, segment["end"] ) + + metadata[filename]["text"] = " ".join(text).strip() + metadata[filename]["start"] = start + metadata[filename]["end"] = end + + open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata)) + if __name__ == "__main__": import argparse @@ -1297,6 +1402,8 @@ if __name__ == "__main__": _logger = LoggerOveride() if args.action == "hdf5": + transcribe_dataset() + elif args.action == "hdf5": create_dataset_hdf5() elif args.action == "list-dataset": dataset = [] diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 707c61d..6e9059f 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -116,10 +116,15 @@ def load_engines(training=True): optimizer = None lr_scheduler = None + checkpoint_path = cfg.ckpt_dir / name / "latest" # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present load_path = cfg.ckpt_dir / name / "fp32.pth" - if not loads_state_dict and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists(): + # actually use the lora-specific checkpoint if available + if cfg.lora is not None: + checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest" + + if not loads_state_dict and not checkpoint_path.exists() and load_path.exists(): print("Checkpoint missing, but weights found.") loads_state_dict = True diff --git a/vall_e/inference.py b/vall_e/inference.py index a609dc2..9258986 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -1,6 +1,7 @@ import torch import torchaudio import soundfile +import time from torch import Tensor from einops import rearrange @@ -8,7 +9,7 @@ from pathlib import Path from .emb import g2p, qnt from .emb.qnt import trim, trim_random -from .utils import to_device +from .utils import to_device, set_seed, wrapper as ml from .config import cfg from .models import get_models @@ -133,6 +134,9 @@ class TTS(): beam_width=0, mirostat_tau=0, mirostat_eta=0.1, + + seed = None, + out_path=None ): lines = text.split("\n") @@ -151,10 +155,15 @@ class TTS(): model_len = engine.module if "nar" in engine.hyper_config.capabilities: model_nar = engine.module + + set_seed(seed) for line in lines: if out_path is None: - out_path = f"./data/{cfg.start_time}.wav" + output_dir = Path("./data/results/") + if not output_dir.exists(): + output_dir.mkdir(parents=True, exist_ok=True) + out_path = output_dir / f"{time.time()}.wav" prom = self.encode_audio( references, trim_length=input_prompt_length ) phns = self.encode_text( line, language=language ) diff --git a/vall_e/utils/__init__.py b/vall_e/utils/__init__.py index 96929f3..b2f2ef9 100755 --- a/vall_e/utils/__init__.py +++ b/vall_e/utils/__init__.py @@ -7,4 +7,5 @@ from .utils import ( to_device, tree_map, do_gc, + set_seed, ) \ No newline at end of file diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index 947d0e6..d190125 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -131,10 +131,6 @@ def train( _logger.info(cfg) """ - # Setup global engines - global _engines - _engines = engines - events = [] eval_fn = global_leader_only(eval_fn) diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py index e92239a..ee37edc 100755 --- a/vall_e/utils/utils.py +++ b/vall_e/utils/utils.py @@ -7,8 +7,11 @@ from .distributed import global_rank, local_rank, global_leader_only import gc import logging import pandas as pd +import numpy as np import re import torch +import random +import time from coloredlogs import ColoredFormatter from logging import StreamHandler @@ -35,6 +38,14 @@ def flatten_dict(d): return records[0] if records else {} +def set_seed(seed=None): + if not seed: + seed = int(time.time()) + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + def _get_named_modules(module, attrname): for name, module in module.named_modules(): if hasattr(module, attrname):