backport fix from tortoise_tts with local trainer + loading state when training lora

This commit is contained in:
mrq 2024-06-25 13:41:29 -05:00
parent 62a53eed64
commit 8fffb94964
8 changed files with 151 additions and 8 deletions

View File

@ -35,6 +35,8 @@ def main():
parser.add_argument("--mirostat-tau", type=float, default=0) parser.add_argument("--mirostat-tau", type=float, default=0)
parser.add_argument("--mirostat-eta", type=float, default=0) parser.add_argument("--mirostat-eta", type=float, default=0)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--device", type=str, default=None) parser.add_argument("--device", type=str, default=None)
parser.add_argument("--amp", action="store_true") parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default=None) parser.add_argument("--dtype", type=str, default=None)
@ -55,7 +57,8 @@ def main():
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty, length_penalty=args.length_penalty,
beam_width=args.beam_width, beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
seed=args.seed,
) )
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -8,8 +8,10 @@ import sys
import time import time
import argparse import argparse
import yaml import yaml
import random
import torch import torch
import numpy as np
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
@ -18,6 +20,15 @@ from pathlib import Path
from .utils.distributed import world_size from .utils.distributed import world_size
def set_seed(seed=None):
if not seed:
seed = time.time()
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
@dataclass() @dataclass()
class BaseConfig: class BaseConfig:
yaml_path: str | None = None yaml_path: str | None = None

View File

@ -1278,6 +1278,111 @@ def create_dataset_hdf5( skip_existing=True ):
hf.create_dataset('symmap', data=json.dumps(symmap)) hf.create_dataset('symmap', data=json.dumps(symmap))
hf.close() hf.close()
def transcribe_dataset():
import os
import json
import torch
import torchaudio
import whisperx
from tqdm.auto import tqdm
from pathlib import Path
# to-do: use argparser
batch_size = 16
device = "cuda"
dtype = "float16"
model_name = "large-v3"
input_audio = "voices"
output_dataset = "training/metadata"
skip_existing = True
diarize = False
#
model = whisperx.load_model(model_name, device, compute_type=dtype)
align_model, align_model_metadata, align_model_language = (None, None, None)
if diarize:
diarize_model = whisperx.DiarizationPipeline(device=device)
else:
diarize_model = None
def pad(num, zeroes):
return str(num).zfill(zeroes+1)
for dataset_name in os.listdir(f'./{input_audio}/'):
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
continue
for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"):
if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
continue
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
if outpath.exists():
metadata = json.loads(open(outpath, 'r', encoding='utf-8').read())
else:
os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
metadata = {}
for filename in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/{speaker_id}/'), desc=f"Processing speaker: {speaker_id}"):
if skip_existing and filename in metadata:
continue
if ".json" in filename:
continue
inpath = f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}'
if os.path.isdir(inpath):
continue
metadata[filename] = {
"segments": [],
"language": "",
"text": "",
"start": 0,
"end": 0,
}
audio = whisperx.load_audio(inpath)
result = model.transcribe(audio, batch_size=batch_size)
language = result["language"]
if language[:2] not in ["ja"]:
language = "en"
if align_model_language != language:
tqdm.write(f'Loading language: {language}')
align_model, align_model_metadata = whisperx.load_align_model(language_code=language, device=device)
align_model_language = language
result = whisperx.align(result["segments"], align_model, align_model_metadata, audio, device, return_char_alignments=False)
metadata[filename]["segments"] = result["segments"]
metadata[filename]["language"] = language
if diarize_model is not None:
diarize_segments = diarize_model(audio)
result = whisperx.assign_word_speakers(diarize_segments, result)
text = []
start = 0
end = 0
for segment in result["segments"]:
text.append( segment["text"] )
start = min( start, segment["start"] )
end = max( end, segment["end"] )
metadata[filename]["text"] = " ".join(text).strip()
metadata[filename]["start"] = start
metadata[filename]["end"] = end
open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata))
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
@ -1297,6 +1402,8 @@ if __name__ == "__main__":
_logger = LoggerOveride() _logger = LoggerOveride()
if args.action == "hdf5": if args.action == "hdf5":
transcribe_dataset()
elif args.action == "hdf5":
create_dataset_hdf5() create_dataset_hdf5()
elif args.action == "list-dataset": elif args.action == "list-dataset":
dataset = [] dataset = []

View File

@ -116,10 +116,15 @@ def load_engines(training=True):
optimizer = None optimizer = None
lr_scheduler = None lr_scheduler = None
checkpoint_path = cfg.ckpt_dir / name / "latest"
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
load_path = cfg.ckpt_dir / name / "fp32.pth" load_path = cfg.ckpt_dir / name / "fp32.pth"
if not loads_state_dict and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists(): # actually use the lora-specific checkpoint if available
if cfg.lora is not None:
checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
print("Checkpoint missing, but weights found.") print("Checkpoint missing, but weights found.")
loads_state_dict = True loads_state_dict = True

View File

@ -1,6 +1,7 @@
import torch import torch
import torchaudio import torchaudio
import soundfile import soundfile
import time
from torch import Tensor from torch import Tensor
from einops import rearrange from einops import rearrange
@ -8,7 +9,7 @@ from pathlib import Path
from .emb import g2p, qnt from .emb import g2p, qnt
from .emb.qnt import trim, trim_random from .emb.qnt import trim, trim_random
from .utils import to_device from .utils import to_device, set_seed, wrapper as ml
from .config import cfg from .config import cfg
from .models import get_models from .models import get_models
@ -133,6 +134,9 @@ class TTS():
beam_width=0, beam_width=0,
mirostat_tau=0, mirostat_tau=0,
mirostat_eta=0.1, mirostat_eta=0.1,
seed = None,
out_path=None out_path=None
): ):
lines = text.split("\n") lines = text.split("\n")
@ -152,9 +156,14 @@ class TTS():
if "nar" in engine.hyper_config.capabilities: if "nar" in engine.hyper_config.capabilities:
model_nar = engine.module model_nar = engine.module
set_seed(seed)
for line in lines: for line in lines:
if out_path is None: if out_path is None:
out_path = f"./data/{cfg.start_time}.wav" output_dir = Path("./data/results/")
if not output_dir.exists():
output_dir.mkdir(parents=True, exist_ok=True)
out_path = output_dir / f"{time.time()}.wav"
prom = self.encode_audio( references, trim_length=input_prompt_length ) prom = self.encode_audio( references, trim_length=input_prompt_length )
phns = self.encode_text( line, language=language ) phns = self.encode_text( line, language=language )

View File

@ -7,4 +7,5 @@ from .utils import (
to_device, to_device,
tree_map, tree_map,
do_gc, do_gc,
set_seed,
) )

View File

@ -131,10 +131,6 @@ def train(
_logger.info(cfg) _logger.info(cfg)
""" """
# Setup global engines
global _engines
_engines = engines
events = [] events = []
eval_fn = global_leader_only(eval_fn) eval_fn = global_leader_only(eval_fn)

View File

@ -7,8 +7,11 @@ from .distributed import global_rank, local_rank, global_leader_only
import gc import gc
import logging import logging
import pandas as pd import pandas as pd
import numpy as np
import re import re
import torch import torch
import random
import time
from coloredlogs import ColoredFormatter from coloredlogs import ColoredFormatter
from logging import StreamHandler from logging import StreamHandler
@ -35,6 +38,14 @@ def flatten_dict(d):
return records[0] if records else {} return records[0] if records else {}
def set_seed(seed=None):
if not seed:
seed = int(time.time())
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def _get_named_modules(module, attrname): def _get_named_modules(module, attrname):
for name, module in module.named_modules(): for name, module in module.named_modules():
if hasattr(module, attrname): if hasattr(module, attrname):