added `--eval-random-text-prompts` to use random text prompts for the eval pass; added `--random-prompts` for the demo page and `--lora` to also include a sample with the LoRA disabled; probably finally fixed the validation dataloader breaking on eval

mrq 2024-10-10 13:40:25 -05:00
parent 52299127ab
commit 2ea978f318
8 changed files with 161 additions and 82 deletions

View File

@@ -96,6 +96,10 @@ You can enter `save` to save the state at any time, or `quit` to save and quit t
The `lr` command will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
Some additional flags can be passed as well:
* `--eval`: only run the evaluation / validation pass, then exit afterwards.
* `--eval-random-text-prompts`: use random text prompts for the evaluation pass, rather than the provided text prompts in the dataset.
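For example, passing both `--eval` and `--eval-random-text-prompts` runs a single evaluation pass where the sampled prompts are random sentences rather than the dataset's own transcriptions, then exits.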
### Finetuning
Finetuning can be done by training the full model, or using a LoRA.

View File

@@ -35,6 +35,72 @@ from tqdm.auto import tqdm

_logger = logging.getLogger(__name__)

@cache
def get_random_prompts( validation=True, length=0, tokenized=False ):
	sentences = [
		"The birch canoe slid on the smooth planks.",
		"Glue the sheet to the dark blue background.",
		"It's easy to tell the depth of a well.",
		"These days a chicken leg is a rare dish.",
		"Rice is often served in round bowls.",
		"The juice of lemons makes fine punch.",
		"The box was thrown beside the parked truck.",
		"The hogs were fed chopped corn and garbage.",
		"Four hours of steady work faced us.",
		"A large size in stockings is hard to sell.",
		"The boy was there when the sun rose.",
		"A rod is used to catch pink salmon.",
		"The source of the huge river is the clear spring.",
		"Kick the ball straight and follow through.",
		"Help the woman get back to her feet.",
		"A pot of tea helps to pass the evening.",
		"Smoky fires lack flame and heat.",
		"The soft cushion broke the man's fall.",
		"The salt breeze came across from the sea.",
		"The girl at the booth sold fifty bonds.",
		"The small pup gnawed a hole in the sock.",
		"The fish twisted and turned on the bent hook.",
		"Press the pants and sew a button on the vest.",
		"The swan dive was far short of perfect.",
		"The beauty of the view stunned the young boy.",
		"Two blue fish swam in the tank.",
		"Her purse was full of useless trash.",
		"The colt reared and threw the tall rider.",
		"It snowed, rained, and hailed the same morning.",
		"Read verse out loud for pleasure.",
	]

	# pull from the validation dataset if it exists + is requested
	if validation and cfg.dataset.validation:
		paths = _load_paths(cfg.dataset.validation, type="validation")
		paths = list(itertools.chain.from_iterable(paths.values()))

		for path in paths:
			text_string = ""
			if cfg.dataset.use_hdf5:
				key = _get_hdf5_path(path)
				metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
				text_string = metadata["text"] if "text" in metadata else ""
			else:
				_, metadata = _load_quants(path, return_metadata=True)
				text_string = metadata["text"] if "text" in metadata else ""

			if len( text_string ) < length:
				continue

			sentences.append( text_string )

	if tokenized:
		return [ torch.tensor( tokenize( encode_phns( text ) ) ).to(dtype=torch.uint8) for text in sentences ]

	return sentences
# samples a random text prompt
def get_random_prompt( *args, **kwargs ):
	# Harvard sentences
	return random.choice(get_random_prompts( *args, **kwargs ))
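A quick usage sketch, mirroring how the rest of this commit calls the helper; the snippet itself is illustrative, not part of the diff:

```python
# demo page: a random sentence as a plain string
text = get_random_prompt()

# eval pass: the same draw, but tokenized to a uint8 tensor and moved onto the training device
text = get_random_prompt( tokenized=True ).to(device=cfg.device)
```

Since `get_random_prompts` is wrapped in `@cache`, the validation transcriptions are only scanned on the first call per argument combination; later draws just sample the cached list.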
# fold into a typical LLM sequence (one embedding rather than split embeddings)
def fold_inputs(
	text_list = [],
@@ -718,7 +784,7 @@ class Dataset(_Dataset):

		if len(self.paths) == 0:
			raise ValueError(f"No valid path is found for {self.dataset_type}")

		if self.sampler_type == "path":
		if self.sampler_type == "path" and self.training:
			if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
				self.sampler = BatchedOrderedSampler(
					self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
@@ -735,9 +801,7 @@

			self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
			self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }

		# loading validation state dict causes issues
		if self.dataset_type != "validation":
			self.load_state_dict()
		self.load_state_dict()

	@cached_property
	def sampler_state_dict_path(self):
@@ -789,6 +853,9 @@

		torch_save(state_dict, path)

	def load_state_dict(self, path = None):
		# a validation dataloader has no sampler state worth restoring;
		# loading one is what kept breaking the eval pass
		if not self.training:
			return

		if path is None:
			path = self.sampler_state_dict_path

View File

@@ -26,7 +26,7 @@ from pathlib import Path
from .inference import TTS
from .config import cfg
from .data import create_train_dataloader, create_val_dataloader
from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
from .emb.qnt import decode_to_file
from tqdm import tqdm, trange
@@ -78,6 +78,9 @@ def main():

	parser.add_argument("--device", type=str, default=None)
	parser.add_argument("--amp", action="store_true")
	parser.add_argument("--dtype", type=str, default=None)

	parser.add_argument("--random-prompts", action="store_true")
	parser.add_argument("--lora", action="store_true")

	args = parser.parse_args()
@@ -142,7 +145,7 @@

		metadata = batch["metadata"]

		text = metadata["text"]
		text = get_random_prompt() if args.random_prompts else metadata["text"]
		language = metadata["language"].lower()

		prompt = dir / "prompt.wav"
@@ -174,8 +177,9 @@

		prompt = dir / "prompt.wav"
		reference = dir / "reference.wav"
		out_path = dir / "out" / "ours.wav"
		out_path_lora = dir / "out" / "ours_lora.wav"

		extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else []
		extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])

		samples.append((
			text,
@@ -185,27 +189,36 @@

		if args.skip_existing and out_path.exists():
			continue

		kwargs = dict(
			text=text,
			references=[prompt],
			language=language,
			input_prompt_length=args.input_prompt_length,
			max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
			ar_temp=args.ar_temp, nar_temp=args.nar_temp,
			min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
			top_p=args.top_p, top_k=args.top_k,
			repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
			length_penalty=args.length_penalty,
			beam_width=args.beam_width,
			mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
			seed=args.seed,
			tqdm=False,
		)

		# render a sample with the LoRA enabled first, then disable it for the baseline sample below
		if args.lora:
			tts.enable_lora()
			try:
				tts.inference( out_path=out_path_lora, **kwargs )
			except Exception as e:
				print(f'Error while processing {out_path_lora}: {e}')
			tts.disable_lora()
		try:
			tts.inference(
				text=text,
				references=[prompt],
				language=language,
				out_path=out_path,
				input_prompt_length=args.input_prompt_length,
				max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
				ar_temp=args.ar_temp, nar_temp=args.nar_temp,
				min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
				top_p=args.top_p, top_k=args.top_k,
				repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
				length_penalty=args.length_penalty,
				beam_width=args.beam_width,
				mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
				seed=args.seed,
				tqdm=False,
			)
			tts.inference( out_path=out_path, **kwargs )
		except Exception as e:
			print(f'Error while processing {out_path}: {e}')

	# collate entries into HTML
	samples = [
		f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+

View File

@@ -221,6 +221,7 @@ def main():

	parser = argparse.ArgumentParser()
	parser.add_argument("--input-speaker", type=Path, default=None)
	parser.add_argument("--input-voice", type=str, default=None)
	parser.add_argument("--use-dataset", action="store_true")

	parser.add_argument("--yaml", type=Path)
@@ -254,6 +255,9 @@

		if "LibriTTS-R" in speaker_name:
			speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
		"""

		if args.input_voice and speaker_name != args.input_voice:
			return

		metadata_path = cfg.metadata_dir / f'{speaker_name}.json'
		metadata = json_read( metadata_path, default={} )

View File

@@ -16,6 +16,7 @@ from .utils import to_device, set_seed, wrapper as ml
from .config import cfg, Config
from .models import get_models
from .models.lora import enable_lora
from .engines import load_engines, deepspeed_available
from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize
@@ -77,6 +78,13 @@ class TTS():

		self.symmap = get_phone_symmap()

		_logger.info("Loaded model")

	def enable_lora( self, enabled=True ):
		for name, engine in self.engines.items():
			enable_lora( engine.module, mode = enabled )

	def disable_lora( self ):
		return self.enable_lora( enabled=False )

	def encode_text( self, text, language="en" ):
		# already a tensor, return it
		if isinstance( text, Tensor ):
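This toggle is what lets the demo script above render the same line twice from one model load. A minimal usage sketch (assuming a constructed `TTS` instance and the `kwargs` dict from `demo.py`; illustrative, not part of the diff):

```python
# sketch: A/B the same prompt with and without the LoRA applied
tts.enable_lora()                                  # toggle the LoRA on for every engine
tts.inference( out_path=out_path_lora, **kwargs )  # sample with the LoRA active
tts.disable_lora()                                 # shorthand for enable_lora(enabled=False)
tts.inference( out_path=out_path, **kwargs )       # baseline sample from the base weights
```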

View File

@@ -156,8 +156,10 @@ class AR(Base):

		)

		# is AR
		"""
		if cfg.lora is not None:
			enable_lora( self, cfg.lora.active_level( 0 ) )
		"""

		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
		stopped = torch.zeros(batch_size, device=device).bool()

View File

@@ -1,8 +1,8 @@
# todo: clean this mess up
from .config import cfg
from .data import create_train_val_dataloader
from .emb import qnt
from .data import create_train_val_dataloader, get_random_prompt, tokenize
from .emb import qnt, g2p
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
from .data import fold_inputs, unfold_outputs
@@ -57,10 +57,13 @@ def train_feeder(engine, batch):

	return loss, stats

@torch.inference_mode()
def run_eval(engines, eval_name, dl):
def run_eval(engines, eval_name, dl, args=None):
	stats = defaultdict(list)
	stats['loss'] = []

	if cfg.evaluation.size == 0:
		return

	def process( name, batch, resps_list ):
		for speaker, path, ref, hyp, prom, task in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"], batch["task"]):
			if len(hyp) == 0:
@@ -84,16 +87,21 @@ def run_eval(engines, eval_name, dl):

			ref_path.parent.mkdir(parents=True, exist_ok=True)
			prom_path.parent.mkdir(parents=True, exist_ok=True)

			ref_audio, sr = qnt.decode_to_file(ref, ref_path)
			hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)

			if ref is not None:
				ref_audio, sr = qnt.decode_to_file(ref, ref_path)

			if prom is not None:
				prom_audio, sr = qnt.decode_to_file(prom, prom_path)

			# pseudo loss calculation since we don't get the logits during eval
			min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
			ref_audio = ref_audio[..., 0:min_length]
			hyp_audio = hyp_audio[..., 0:min_length]
			stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())

			# naive loss calculation
			# to-do: find a better way to calculate this / a better metric
			if ref is not None:
				min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
				ref_audio = ref_audio[..., 0:min_length]
				hyp_audio = hyp_audio[..., 0:min_length]
				stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())

	processed = 0
	while processed < cfg.evaluation.size:
@@ -105,19 +113,25 @@ def run_eval(engines, eval_name, dl):

		batch_size = len(batch["text"])

		processed += batch_size

		# to-do: eval for text tasks
		has_stt = False
		for i, task in enumerate( batch["task"] ):
			# easier to just change it to a tts task than drop stt tasks from the batch
			if task == "stt":
				# has_stt = True
				batch["task"][i] = "tts"
				batch["proms"][i] = batch["resps"][i][:75*3, :]

		# random prompts requested
		if args and args.eval_random_text_prompts and eval_name == "subtrain":
			for i, _ in enumerate(batch["text"]):
				batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
				batch["resps"][i] = None

		processed += batch_size
		for name in engines:
			engine = engines[name]

			# to-do: eval for text tasks
			has_stt = False
			for i, task in enumerate( batch["task"] ):
				# easier to just change it to a tts task than drop stt tasks from the batch
				if task == "stt":
					# has_stt = True
					batch["task"][i] = "tts"
					batch["proms"][i] = batch["resps"][i][:75*3, :]

			kwargs = dict(
				text_list=batch["text"],
@@ -157,7 +171,7 @@ def run_eval(engines, eval_name, dl):

				_logger.info(f"Validation Metrics (STT): {text_list}")

	# average each metric, skipping any left empty (e.g. the pseudo-loss when references were dropped for random prompts)
	stats = {k: sum(v) / len(v) for k, v in stats.items()}
	stats = {k: sum(v) / len(v) for k, v in stats.items() if v}

	engines_stats = {
		f'{name}.{eval_name}': stats,
		"it": engines.global_step,
@@ -170,6 +184,8 @@ def run_eval(engines, eval_name, dl):

def train():
	parser = argparse.ArgumentParser("VALL-E TTS")
	parser.add_argument("--eval", action="store_true", default=None)
	parser.add_argument("--eval-random-text-prompts", action="store_true", default=None)
	#parser.add_argument("--eval-random-audio-prompts", action="store_true", default=None)
	args, unknown = parser.parse_known_args()

	# create log folder
@@ -185,8 +201,8 @@ def train():

		engines.eval()

		# wrapped in a try block because it's sometimes prone to breaking
		try:
			run_eval(engines, "subtrain", subtrain_dl)
			run_eval(engines, "val", val_dl)
			run_eval(engines, "subtrain", subtrain_dl, args)
			run_eval(engines, "val", val_dl, args)
		except Exception as e:
			_logger.warning(f"Error occurred while performing eval: {str(e)}")
			_logger.warning(traceback.format_exc())
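The eval hunks above interlock: random text prompts null out `batch["resps"]`, the `ref is not None` guards then skip the mel-STFT pseudo-loss, and the `if v` filter keeps the now-empty `loss` list from breaking the averaging step. A standalone illustration of that last guard (hypothetical values, not code from this commit):

```python
from collections import defaultdict

stats = defaultdict(list)
stats['loss'] = []               # stays empty when every reference was dropped
stats['stt.loss'].append(0.5)    # hypothetical metric that did get populated

# without `if v`, the empty list would raise ZeroDivisionError on len(v) == 0
averaged = {k: sum(v) / len(v) for k, v in stats.items() if v}
print(averaged)                  # {'stt.loss': 0.5}
```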

View File

@@ -19,7 +19,7 @@ from .train import train
from .utils import get_devices, setup_logging, timer
from .utils.io import json_read, json_stringify
from .emb.qnt import decode_to_wave
from .data import get_lang_symmap
from .data import get_lang_symmap, get_random_prompt
tts = None
@@ -284,42 +284,7 @@ def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):

	while True:
		metrics = next(it)
		yield metrics
"""
def get_random_prompt():
harvard_sentences=[
"The birch canoe slid on the smooth planks.",
"Glue the sheet to the dark blue background.",
"It's easy to tell the depth of a well.",
"These days a chicken leg is a rare dish.",
"Rice is often served in round bowls.",
"The juice of lemons makes fine punch.",
"The box was thrown beside the parked truck.",
"The hogs were fed chopped corn and garbage.",
"Four hours of steady work faced us.",
"A large size in stockings is hard to sell.",
"The boy was there when the sun rose.",
"A rod is used to catch pink salmon.",
"The source of the huge river is the clear spring.",
"Kick the ball straight and follow through.",
"Help the woman get back to her feet.",
"A pot of tea helps to pass the evening.",
"Smoky fires lack flame and heat.",
"The soft cushion broke the man's fall.",
"The salt breeze came across from the sea.",
"The girl at the booth sold fifty bonds.",
"The small pup gnawed a hole in the sock.",
"The fish twisted and turned on the bent hook.",
"Press the pants and sew a button on the vest.",
"The swan dive was far short of perfect.",
"The beauty of the view stunned the young boy.",
"Two blue fish swam in the tank.",
"Her purse was full of useless trash.",
"The colt reared and threw the tall rider.",
"It snowed, rained, and hailed the same morning.",
"Read verse out loud for pleasure.",
]
return random.choice(harvard_sentences)
"""
# setup args
parser = argparse.ArgumentParser(allow_abbrev=False)