added --eval-random-text-prompts to use random text prompts for eval pass, added --random-prompts for demo page and --lora to use a sample with the lora disabled, probably finally fixed validation dataloader breaking on eval
This commit is contained in:
parent
52299127ab
commit
2ea978f318
|
@ -96,6 +96,10 @@ You can enter `save` to save the state at any time, or `quit` to save and quit t
|
|||
|
||||
The `lr` command will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
|
||||
|
||||
Some additional flags can be passed as well:
|
||||
* `--eval`: only run the evaluation / validation pass, then exit afterwards.
|
||||
* `--eval-random-text-prompts`: use random text prompts for the evaluation pass, rather than the provided text prompts in the dataset.
|
||||
|
||||
### Finetuning
|
||||
|
||||
Finetuning can be done by training the full model, or using a LoRA.
|
||||
|
|
|
@ -35,6 +35,72 @@ from tqdm.auto import tqdm
|
|||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
@cache
|
||||
def get_random_prompts( validation=True, length=0, tokenized=False ):
|
||||
sentences = [
|
||||
"The birch canoe slid on the smooth planks.",
|
||||
"Glue the sheet to the dark blue background.",
|
||||
"It's easy to tell the depth of a well.",
|
||||
"These days a chicken leg is a rare dish.",
|
||||
"Rice is often served in round bowls.",
|
||||
"The juice of lemons makes fine punch.",
|
||||
"The box was thrown beside the parked truck.",
|
||||
"The hogs were fed chopped corn and garbage.",
|
||||
"Four hours of steady work faced us.",
|
||||
"A large size in stockings is hard to sell.",
|
||||
"The boy was there when the sun rose.",
|
||||
"A rod is used to catch pink salmon.",
|
||||
"The source of the huge river is the clear spring.",
|
||||
"Kick the ball straight and follow through.",
|
||||
"Help the woman get back to her feet.",
|
||||
"A pot of tea helps to pass the evening.",
|
||||
"Smoky fires lack flame and heat.",
|
||||
"The soft cushion broke the man's fall.",
|
||||
"The salt breeze came across from the sea.",
|
||||
"The girl at the booth sold fifty bonds.",
|
||||
"The small pup gnawed a hole in the sock.",
|
||||
"The fish twisted and turned on the bent hook.",
|
||||
"Press the pants and sew a button on the vest.",
|
||||
"The swan dive was far short of perfect.",
|
||||
"The beauty of the view stunned the young boy.",
|
||||
"Two blue fish swam in the tank.",
|
||||
"Her purse was full of useless trash.",
|
||||
"The colt reared and threw the tall rider.",
|
||||
"It snowed, rained, and hailed the same morning.",
|
||||
"Read verse out loud for pleasure.",
|
||||
]
|
||||
|
||||
# Pull from validation dataset if existing + requested
|
||||
if validation and cfg.dataset.validation:
|
||||
paths = _load_paths(cfg.dataset.validation, type="validation")
|
||||
paths = list(itertools.chain.from_iterable(paths.values()))
|
||||
|
||||
for path in paths:
|
||||
text_string = ""
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
|
||||
metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
|
||||
text_string = metadata["text"] if "text" in metadata else ""
|
||||
else:
|
||||
_, metadata = _load_quants(path, return_metadata=True)
|
||||
text_string = metadata["text"] if "text" in metadata else ""
|
||||
|
||||
if len( text_string ) < length:
|
||||
continue
|
||||
|
||||
sentences.append( text_string )
|
||||
|
||||
if tokenized:
|
||||
return [ torch.tensor( tokenize( encode_phns( text ) ) ).to(dtype=torch.uint8) for text in sentences ]
|
||||
|
||||
return sentences
|
||||
|
||||
# samples a random text prompt
|
||||
def get_random_prompt( *args, **kwargs ):
|
||||
# Harvard sentences
|
||||
return random.choice(get_random_prompts( *args, **kwargs ))
|
||||
|
||||
# fold into a typical LLM sequence (one embedding rather than split embeddings)
|
||||
def fold_inputs(
|
||||
text_list = [],
|
||||
|
@ -718,7 +784,7 @@ class Dataset(_Dataset):
|
|||
if len(self.paths) == 0:
|
||||
raise ValueError(f"No valid path is found for {self.dataset_type}")
|
||||
|
||||
if self.sampler_type == "path":
|
||||
if self.sampler_type == "path" and self.training:
|
||||
if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
|
||||
self.sampler = BatchedOrderedSampler(
|
||||
self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
|
||||
|
@ -735,9 +801,7 @@ class Dataset(_Dataset):
|
|||
self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
|
||||
self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
|
||||
|
||||
# loading validation state dict causes issues
|
||||
if self.dataset_type != "validation":
|
||||
self.load_state_dict()
|
||||
self.load_state_dict()
|
||||
|
||||
@cached_property
|
||||
def sampler_state_dict_path(self):
|
||||
|
@ -789,6 +853,9 @@ class Dataset(_Dataset):
|
|||
torch_save(state_dict, path)
|
||||
|
||||
def load_state_dict(self, path = None):
|
||||
if not self.training:
|
||||
return
|
||||
|
||||
if path is None:
|
||||
path = self.sampler_state_dict_path
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ from pathlib import Path
|
|||
|
||||
from .inference import TTS
|
||||
from .config import cfg
|
||||
from .data import create_train_dataloader, create_val_dataloader
|
||||
from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
|
||||
from .emb.qnt import decode_to_file
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
@ -78,6 +78,9 @@ def main():
|
|||
parser.add_argument("--device", type=str, default=None)
|
||||
parser.add_argument("--amp", action="store_true")
|
||||
parser.add_argument("--dtype", type=str, default=None)
|
||||
|
||||
parser.add_argument("--random-prompts", action="store_true")
|
||||
parser.add_argument("--lora", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -142,7 +145,7 @@ def main():
|
|||
|
||||
metadata = batch["metadata"]
|
||||
|
||||
text = metadata["text"]
|
||||
text = get_random_prompt() if args.random_prompts else metadata["text"]
|
||||
language = metadata["language"].lower()
|
||||
|
||||
prompt = dir / "prompt.wav"
|
||||
|
@ -174,8 +177,9 @@ def main():
|
|||
prompt = dir / "prompt.wav"
|
||||
reference = dir / "reference.wav"
|
||||
out_path = dir / "out" / "ours.wav"
|
||||
out_path_lora = dir / "out" / "ours_lora.wav"
|
||||
|
||||
extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else []
|
||||
extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
|
||||
|
||||
samples.append((
|
||||
text,
|
||||
|
@ -185,27 +189,36 @@ def main():
|
|||
if args.skip_existing and out_path.exists():
|
||||
continue
|
||||
|
||||
kwargs = dict(
|
||||
text=text,
|
||||
references=[prompt],
|
||||
language=language,
|
||||
input_prompt_length=args.input_prompt_length,
|
||||
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
|
||||
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
|
||||
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
|
||||
top_p=args.top_p, top_k=args.top_k,
|
||||
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
||||
length_penalty=args.length_penalty,
|
||||
beam_width=args.beam_width,
|
||||
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
||||
seed=args.seed,
|
||||
tqdm=False,
|
||||
)
|
||||
|
||||
if args.lora:
|
||||
tts.enable_lora()
|
||||
try:
|
||||
tts.inference( out_path=out_path_lora, **kwargs )
|
||||
except Exception as e:
|
||||
print(f'Error while processing {out_path}: {e}')
|
||||
tts.disable_lora()
|
||||
try:
|
||||
tts.inference(
|
||||
text=text,
|
||||
references=[prompt],
|
||||
language=language,
|
||||
out_path=out_path,
|
||||
input_prompt_length=args.input_prompt_length,
|
||||
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
|
||||
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
|
||||
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
|
||||
top_p=args.top_p, top_k=args.top_k,
|
||||
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
||||
length_penalty=args.length_penalty,
|
||||
beam_width=args.beam_width,
|
||||
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
||||
seed=args.seed,
|
||||
tqdm=False,
|
||||
)
|
||||
tts.inference( out_path=out_path, **kwargs )
|
||||
except Exception as e:
|
||||
print(f'Error while processing {out_path}: {e}')
|
||||
|
||||
|
||||
# collate entries into HTML
|
||||
samples = [
|
||||
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
|
||||
|
|
|
@ -221,6 +221,7 @@ def main():
|
|||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--input-speaker", type=Path, default=None)
|
||||
parser.add_argument("--input-voice", type=str, default=None)
|
||||
parser.add_argument("--use-dataset", action="store_true")
|
||||
|
||||
parser.add_argument("--yaml", type=Path)
|
||||
|
@ -254,6 +255,9 @@ def main():
|
|||
if "LibriTTS-R" in speaker_name:
|
||||
speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
|
||||
"""
|
||||
|
||||
if args.input_voice and speaker_name != args.input_voice:
|
||||
return
|
||||
|
||||
metadata_path = cfg.metadata_dir / f'{speaker_name}.json'
|
||||
metadata = json_read( metadata_path, default={} )
|
||||
|
|
|
@ -16,6 +16,7 @@ from .utils import to_device, set_seed, wrapper as ml
|
|||
|
||||
from .config import cfg, Config
|
||||
from .models import get_models
|
||||
from .models.lora import enable_lora
|
||||
from .engines import load_engines, deepspeed_available
|
||||
from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize
|
||||
|
||||
|
@ -77,6 +78,13 @@ class TTS():
|
|||
self.symmap = get_phone_symmap()
|
||||
_logger.info("Loaded model")
|
||||
|
||||
def enable_lora( self, enabled=True ):
|
||||
for name, engine in self.engines.items():
|
||||
enable_lora( engine.module, mode = enabled )
|
||||
|
||||
def disable_lora( self ):
|
||||
return self.enable_lora( enabled=False )
|
||||
|
||||
def encode_text( self, text, language="en" ):
|
||||
# already a tensor, return it
|
||||
if isinstance( text, Tensor ):
|
||||
|
|
|
@ -156,8 +156,10 @@ class AR(Base):
|
|||
)
|
||||
|
||||
# is AR
|
||||
"""
|
||||
if cfg.lora is not None:
|
||||
enable_lora( self, cfg.lora.active_level( 0 ) )
|
||||
"""
|
||||
|
||||
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
|
||||
stopped = torch.zeros(batch_size, device=device).bool()
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
# todo: clean this mess up
|
||||
|
||||
from .config import cfg
|
||||
from .data import create_train_val_dataloader
|
||||
from .emb import qnt
|
||||
from .data import create_train_val_dataloader, get_random_prompt, tokenize
|
||||
from .emb import qnt, g2p
|
||||
|
||||
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
|
||||
from .data import fold_inputs, unfold_outputs
|
||||
|
@ -57,10 +57,13 @@ def train_feeder(engine, batch):
|
|||
return loss, stats
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_eval(engines, eval_name, dl):
|
||||
def run_eval(engines, eval_name, dl, args=None):
|
||||
stats = defaultdict(list)
|
||||
stats['loss'] = []
|
||||
|
||||
if cfg.evaluation.size == 0:
|
||||
return
|
||||
|
||||
def process( name, batch, resps_list ):
|
||||
for speaker, path, ref, hyp, prom, task in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"], batch["task"]):
|
||||
if len(hyp) == 0:
|
||||
|
@ -84,16 +87,21 @@ def run_eval(engines, eval_name, dl):
|
|||
ref_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
prom_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ref_audio, sr = qnt.decode_to_file(ref, ref_path)
|
||||
hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
|
||||
|
||||
if ref is not None:
|
||||
ref_audio, sr = qnt.decode_to_file(ref, ref_path)
|
||||
|
||||
if prom is not None:
|
||||
prom_audio, sr = qnt.decode_to_file(prom, prom_path)
|
||||
|
||||
# pseudo loss calculation since we don't get the logits during eval
|
||||
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
|
||||
ref_audio = ref_audio[..., 0:min_length]
|
||||
hyp_audio = hyp_audio[..., 0:min_length]
|
||||
stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
|
||||
# naive loss calculation
|
||||
# to-do: find a better way to calculate this / a better metric
|
||||
if ref is not None:
|
||||
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
|
||||
ref_audio = ref_audio[..., 0:min_length]
|
||||
hyp_audio = hyp_audio[..., 0:min_length]
|
||||
stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
|
||||
|
||||
processed = 0
|
||||
while processed < cfg.evaluation.size:
|
||||
|
@ -105,19 +113,25 @@ def run_eval(engines, eval_name, dl):
|
|||
|
||||
batch_size = len(batch["text"])
|
||||
|
||||
processed += batch_size
|
||||
# to-do: eval for text tasks
|
||||
has_stt = False
|
||||
for i, task in enumerate( batch["task"] ):
|
||||
# easier to just change it to a tts task than drop stt tasks from the batch
|
||||
if task == "stt":
|
||||
# has_stt = True
|
||||
batch["task"][i] = "tts"
|
||||
batch["proms"][i] = batch["resps"][i][:75*3, :]
|
||||
|
||||
# random prompts requested
|
||||
if args and args.eval_random_text_prompts and eval_name == "subtrain":
|
||||
for i, _ in enumerate(batch["text"]):
|
||||
batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
|
||||
batch["resps"][i] = None
|
||||
|
||||
processed += batch_size
|
||||
for name in engines:
|
||||
engine = engines[name]
|
||||
|
||||
# to-do: eval for text tasks
|
||||
has_stt = False
|
||||
for i, task in enumerate( batch["task"] ):
|
||||
# easier to just change it to a tts task than drop stt tasks from the batch
|
||||
if task == "stt":
|
||||
# has_stt = True
|
||||
batch["task"][i] = "tts"
|
||||
batch["proms"][i] = batch["resps"][i][:75*3, :]
|
||||
|
||||
kwargs = dict(
|
||||
text_list=batch["text"],
|
||||
|
@ -157,7 +171,7 @@ def run_eval(engines, eval_name, dl):
|
|||
|
||||
_logger.info(f"Validation Metrics (STT): {text_list}")
|
||||
|
||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||
stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
|
||||
engines_stats = {
|
||||
f'{name}.{eval_name}': stats,
|
||||
"it": engines.global_step,
|
||||
|
@ -170,6 +184,8 @@ def run_eval(engines, eval_name, dl):
|
|||
def train():
|
||||
parser = argparse.ArgumentParser("VALL-E TTS")
|
||||
parser.add_argument("--eval", action="store_true", default=None)
|
||||
parser.add_argument("--eval-random-text-prompts", action="store_true", default=None)
|
||||
#parser.add_argument("--eval-random-audio-prompts", action="store_true", default=None)
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
# create log folder
|
||||
|
@ -185,8 +201,8 @@ def train():
|
|||
engines.eval()
|
||||
# wrapped in a try block because it's sometimes prone to breaking
|
||||
try:
|
||||
run_eval(engines, "subtrain", subtrain_dl)
|
||||
run_eval(engines, "val", val_dl)
|
||||
run_eval(engines, "subtrain", subtrain_dl, args)
|
||||
run_eval(engines, "val", val_dl, args)
|
||||
except Exception as e:
|
||||
_logger.warning(f"Error occurred while performing eval: {str(e)}")
|
||||
_logger.warning(traceback.format_exc())
|
||||
|
|
|
@ -19,7 +19,7 @@ from .train import train
|
|||
from .utils import get_devices, setup_logging, timer
|
||||
from .utils.io import json_read, json_stringify
|
||||
from .emb.qnt import decode_to_wave
|
||||
from .data import get_lang_symmap
|
||||
from .data import get_lang_symmap, get_random_prompt
|
||||
|
||||
tts = None
|
||||
|
||||
|
@ -284,42 +284,7 @@ def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
while True:
|
||||
metrics = next(it)
|
||||
yield metrics
|
||||
"""
|
||||
|
||||
def get_random_prompt():
|
||||
harvard_sentences=[
|
||||
"The birch canoe slid on the smooth planks.",
|
||||
"Glue the sheet to the dark blue background.",
|
||||
"It's easy to tell the depth of a well.",
|
||||
"These days a chicken leg is a rare dish.",
|
||||
"Rice is often served in round bowls.",
|
||||
"The juice of lemons makes fine punch.",
|
||||
"The box was thrown beside the parked truck.",
|
||||
"The hogs were fed chopped corn and garbage.",
|
||||
"Four hours of steady work faced us.",
|
||||
"A large size in stockings is hard to sell.",
|
||||
"The boy was there when the sun rose.",
|
||||
"A rod is used to catch pink salmon.",
|
||||
"The source of the huge river is the clear spring.",
|
||||
"Kick the ball straight and follow through.",
|
||||
"Help the woman get back to her feet.",
|
||||
"A pot of tea helps to pass the evening.",
|
||||
"Smoky fires lack flame and heat.",
|
||||
"The soft cushion broke the man's fall.",
|
||||
"The salt breeze came across from the sea.",
|
||||
"The girl at the booth sold fifty bonds.",
|
||||
"The small pup gnawed a hole in the sock.",
|
||||
"The fish twisted and turned on the bent hook.",
|
||||
"Press the pants and sew a button on the vest.",
|
||||
"The swan dive was far short of perfect.",
|
||||
"The beauty of the view stunned the young boy.",
|
||||
"Two blue fish swam in the tank.",
|
||||
"Her purse was full of useless trash.",
|
||||
"The colt reared and threw the tall rider.",
|
||||
"It snowed, rained, and hailed the same morning.",
|
||||
"Read verse out loud for pleasure.",
|
||||
]
|
||||
return random.choice(harvard_sentences)
|
||||
"""
|
||||
|
||||
# setup args
|
||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
|
|
Loading…
Reference in New Issue
Block a user