diff --git a/README.md b/README.md
index e8136c3..1c3eb9b 100755
--- a/README.md
+++ b/README.md
@@ -96,6 +96,10 @@ You can enter `save` to save the state at any time, or `quit` to save and quit t
The `lr` command will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
+Some additional flags can be passed as well (see the example below):
+* `--eval`: run only the evaluation / validation pass, then exit.
+* `--eval-random-text-prompts`: use random text prompts for the evaluation pass instead of the text prompts provided in the dataset.
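+
+For example, a one-off validation pass using random text prompts might look like this (the config path is illustrative):
+
+```bash
+python3 -m vall_e.train --yaml="./training/config.yaml" --eval --eval-random-text-prompts
+```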
+
### Finetuning
Finetuning can be done by training the full model, or using a LoRA.
diff --git a/vall_e/data.py b/vall_e/data.py
index 209e0de..d80b179 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -35,6 +35,72 @@ from tqdm.auto import tqdm
_logger = logging.getLogger(__name__)
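+# Returns a pool of text prompts: the Harvard sentences below, plus (optionally) transcripts
+# pulled from the validation dataset. Cached per unique argument combination.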
+@cache
+def get_random_prompts( validation=True, length=0, tokenized=False ):
+ sentences = [
+ "The birch canoe slid on the smooth planks.",
+ "Glue the sheet to the dark blue background.",
+ "It's easy to tell the depth of a well.",
+ "These days a chicken leg is a rare dish.",
+ "Rice is often served in round bowls.",
+ "The juice of lemons makes fine punch.",
+ "The box was thrown beside the parked truck.",
+ "The hogs were fed chopped corn and garbage.",
+ "Four hours of steady work faced us.",
+ "A large size in stockings is hard to sell.",
+ "The boy was there when the sun rose.",
+ "A rod is used to catch pink salmon.",
+ "The source of the huge river is the clear spring.",
+ "Kick the ball straight and follow through.",
+ "Help the woman get back to her feet.",
+ "A pot of tea helps to pass the evening.",
+ "Smoky fires lack flame and heat.",
+ "The soft cushion broke the man's fall.",
+ "The salt breeze came across from the sea.",
+ "The girl at the booth sold fifty bonds.",
+ "The small pup gnawed a hole in the sock.",
+ "The fish twisted and turned on the bent hook.",
+ "Press the pants and sew a button on the vest.",
+ "The swan dive was far short of perfect.",
+ "The beauty of the view stunned the young boy.",
+ "Two blue fish swam in the tank.",
+ "Her purse was full of useless trash.",
+ "The colt reared and threw the tall rider.",
+ "It snowed, rained, and hailed the same morning.",
+ "Read verse out loud for pleasure.",
+ ]
+
+ # Pull additional sentences from the validation dataset, if one exists and was requested
+ if validation and cfg.dataset.validation:
+ paths = _load_paths(cfg.dataset.validation, type="validation")
+ paths = list(itertools.chain.from_iterable(paths.values()))
+
+ for path in paths:
+ text_string = ""
+ if cfg.dataset.use_hdf5:
+ key = _get_hdf5_path(path)
+
+ metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
+ text_string = metadata["text"] if "text" in metadata else ""
+ else:
+ _, metadata = _load_quants(path, return_metadata=True)
+ text_string = metadata["text"] if "text" in metadata else ""
+
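+ # skip transcripts shorter than the requested minimum length (in characters)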
+ if len( text_string ) < length:
+ continue
+
+ sentences.append( text_string )
+
+ if tokenized:
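+ # token ids are stored as uint8 to save memory; this assumes the phoneme vocabulary fits in 256 entries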
+ return [ torch.tensor( tokenize( encode_phns( text ) ) ).to(dtype=torch.uint8) for text in sentences ]
+
+ return sentences
+
+# Samples a single random text prompt (a Harvard sentence, or a validation transcript when available)
+def get_random_prompt( *args, **kwargs ):
+ return random.choice(get_random_prompts( *args, **kwargs ))
+
# fold into a typical LLM sequence (one embedding rather than split embeddings)
def fold_inputs(
text_list = [],
@@ -718,7 +784,7 @@ class Dataset(_Dataset):
if len(self.paths) == 0:
raise ValueError(f"No valid path is found for {self.dataset_type}")
- if self.sampler_type == "path":
+ if self.sampler_type == "path" and self.training:
if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
self.sampler = BatchedOrderedSampler(
self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
@@ -735,9 +801,7 @@ class Dataset(_Dataset):
self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
- # loading validation state dict causes issues
- if self.dataset_type != "validation":
- self.load_state_dict()
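+ # safe to call unconditionally now: load_state_dict() is a no-op when not training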
+ self.load_state_dict()
@cached_property
def sampler_state_dict_path(self):
@@ -789,6 +853,9 @@ class Dataset(_Dataset):
torch_save(state_dict, path)
def load_state_dict(self, path = None):
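+ # sampler state only matters during training; loading it for validation datasets caused issues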
+ if not self.training:
+ return
+
if path is None:
path = self.sampler_state_dict_path
diff --git a/vall_e/demo.py b/vall_e/demo.py
index 80da923..1abfd8d 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -26,7 +26,7 @@ from pathlib import Path
from .inference import TTS
from .config import cfg
-from .data import create_train_dataloader, create_val_dataloader
+from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
from .emb.qnt import decode_to_file
from tqdm import tqdm, trange
@@ -78,6 +78,9 @@ def main():
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default=None)
+
+ parser.add_argument("--random-prompts", action="store_true")
+ parser.add_argument("--lora", action="store_true")
args = parser.parse_args()
@@ -142,7 +145,7 @@ def main():
metadata = batch["metadata"]
- text = metadata["text"]
+ text = get_random_prompt() if args.random_prompts else metadata["text"]
language = metadata["language"].lower()
prompt = dir / "prompt.wav"
@@ -174,8 +177,9 @@ def main():
prompt = dir / "prompt.wav"
reference = dir / "reference.wav"
out_path = dir / "out" / "ours.wav"
+ out_path_lora = dir / "out" / "ours_lora.wav"
- extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else []
+ extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
samples.append((
text,
@@ -185,27 +189,36 @@ def main():
if args.skip_existing and out_path.exists():
continue
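+ # shared inference arguments for the base-model pass and the optional LoRA pass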
+ kwargs = dict(
+ text=text,
+ references=[prompt],
+ language=language,
+ input_prompt_length=args.input_prompt_length,
+ max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
+ ar_temp=args.ar_temp, nar_temp=args.nar_temp,
+ min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
+ top_p=args.top_p, top_k=args.top_k,
+ repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
+ length_penalty=args.length_penalty,
+ beam_width=args.beam_width,
+ mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
+ seed=args.seed,
+ tqdm=False,
+ )
+
+ if args.lora:
+ tts.enable_lora()
+ try:
+ tts.inference( out_path=out_path_lora, **kwargs )
+ except Exception as e:
+ print(f'Error while processing {out_path_lora}: {e}')
+ tts.disable_lora()
try:
- tts.inference(
- text=text,
- references=[prompt],
- language=language,
- out_path=out_path,
- input_prompt_length=args.input_prompt_length,
- max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
- ar_temp=args.ar_temp, nar_temp=args.nar_temp,
- min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
- top_p=args.top_p, top_k=args.top_k,
- repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
- length_penalty=args.length_penalty,
- beam_width=args.beam_width,
- mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
- seed=args.seed,
- tqdm=False,
- )
+ tts.inference( out_path=out_path, **kwargs )
except Exception as e:
print(f'Error while processing {out_path}: {e}')
+
# collate entries into HTML
samples = [
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py
index 8936094..3bd7193 100644
--- a/vall_e/emb/similar.py
+++ b/vall_e/emb/similar.py
@@ -221,6 +221,7 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input-speaker", type=Path, default=None)
+ parser.add_argument("--input-voice", type=str, default=None)
parser.add_argument("--use-dataset", action="store_true")
parser.add_argument("--yaml", type=Path)
@@ -254,6 +255,9 @@ def main():
if "LibriTTS-R" in speaker_name:
speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
"""
+
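+ # when --input-voice is given, skip every speaker except the requested one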
+ if args.input_voice and speaker_name != args.input_voice:
+ return
metadata_path = cfg.metadata_dir / f'{speaker_name}.json'
metadata = json_read( metadata_path, default={} )
diff --git a/vall_e/inference.py b/vall_e/inference.py
index ce8d1c3..9740c6a 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -16,6 +16,7 @@ from .utils import to_device, set_seed, wrapper as ml
from .config import cfg, Config
from .models import get_models
+from .models.lora import enable_lora
from .engines import load_engines, deepspeed_available
from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize
@@ -77,6 +78,13 @@ class TTS():
self.symmap = get_phone_symmap()
_logger.info("Loaded model")
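+ # Toggles LoRA adapter weights on each engine's underlying module (used by the demo's --lora pass)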
+ def enable_lora( self, enabled=True ):
+ for name, engine in self.engines.items():
+ enable_lora( engine.module, mode = enabled )
+
+ def disable_lora( self ):
+ return self.enable_lora( enabled=False )
+
def encode_text( self, text, language="en" ):
# already a tensor, return it
if isinstance( text, Tensor ):
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 383cba9..28aa653 100644
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -156,8 +156,10 @@ class AR(Base):
)
# is AR
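+ # NOTE: per-level LoRA toggling is disabled here; LoRA is now enabled/disabled externally (see TTS.enable_lora)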
+ """
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( 0 ) )
+ """
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool()
diff --git a/vall_e/train.py b/vall_e/train.py
index 58811be..9f0b969 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -1,8 +1,8 @@
# todo: clean this mess up
from .config import cfg
-from .data import create_train_val_dataloader
-from .emb import qnt
+from .data import create_train_val_dataloader, get_random_prompt, tokenize
+from .emb import qnt, g2p
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
from .data import fold_inputs, unfold_outputs
@@ -57,10 +57,13 @@ def train_feeder(engine, batch):
return loss, stats
@torch.inference_mode()
-def run_eval(engines, eval_name, dl):
+def run_eval(engines, eval_name, dl, args=None):
stats = defaultdict(list)
stats['loss'] = []
+ if cfg.evaluation.size == 0:
+ return
+
def process( name, batch, resps_list ):
for speaker, path, ref, hyp, prom, task in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"], batch["task"]):
if len(hyp) == 0:
@@ -84,16 +87,21 @@ def run_eval(engines, eval_name, dl):
ref_path.parent.mkdir(parents=True, exist_ok=True)
prom_path.parent.mkdir(parents=True, exist_ok=True)
- ref_audio, sr = qnt.decode_to_file(ref, ref_path)
hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
+
+ if ref is not None:
+ ref_audio, sr = qnt.decode_to_file(ref, ref_path)
+
if prom is not None:
prom_audio, sr = qnt.decode_to_file(prom, prom_path)
- # pseudo loss calculation since we don't get the logits during eval
- min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
- ref_audio = ref_audio[..., 0:min_length]
- hyp_audio = hyp_audio[..., 0:min_length]
- stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
+ # naive loss calculation
+ # to-do: find a better way to calculate this / a better metric
+ if ref is not None:
+ min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
+ ref_audio = ref_audio[..., 0:min_length]
+ hyp_audio = hyp_audio[..., 0:min_length]
+ stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
processed = 0
while processed < cfg.evaluation.size:
@@ -105,19 +113,25 @@ def run_eval(engines, eval_name, dl):
batch_size = len(batch["text"])
- processed += batch_size
+ # to-do: eval for text tasks
+ has_stt = False
+ for i, task in enumerate( batch["task"] ):
+ # easier to just change it to a tts task than drop stt tasks from the batch
+ if task == "stt":
+ # has_stt = True
+ batch["task"][i] = "tts"
+ batch["proms"][i] = batch["resps"][i][:75*3, :]
+ # swap in random text prompts (and drop the reference audio) when requested for the subtrain pass
+ if args and args.eval_random_text_prompts and eval_name == "subtrain":
+ for i, _ in enumerate(batch["text"]):
+ batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
+ batch["resps"][i] = None
+
+ processed += batch_size
for name in engines:
engine = engines[name]
- # to-do: eval for text tasks
- has_stt = False
- for i, task in enumerate( batch["task"] ):
- # easier to just change it to a tts task than drop stt tasks from the batch
- if task == "stt":
- # has_stt = True
- batch["task"][i] = "tts"
- batch["proms"][i] = batch["resps"][i][:75*3, :]
kwargs = dict(
text_list=batch["text"],
@@ -157,7 +171,7 @@ def run_eval(engines, eval_name, dl):
_logger.info(f"Validation Metrics (STT): {text_list}")
- stats = {k: sum(v) / len(v) for k, v in stats.items()}
+ stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
engines_stats = {
f'{name}.{eval_name}': stats,
"it": engines.global_step,
@@ -170,6 +184,8 @@ def run_eval(engines, eval_name, dl):
def train():
parser = argparse.ArgumentParser("VALL-E TTS")
parser.add_argument("--eval", action="store_true", default=None)
+ parser.add_argument("--eval-random-text-prompts", action="store_true", default=None)
+ #parser.add_argument("--eval-random-audio-prompts", action="store_true", default=None)
args, unknown = parser.parse_known_args()
# create log folder
@@ -185,8 +201,8 @@ def train():
engines.eval()
# wrapped in a try block because it's sometimes prone to breaking
try:
- run_eval(engines, "subtrain", subtrain_dl)
- run_eval(engines, "val", val_dl)
+ run_eval(engines, "subtrain", subtrain_dl, args)
+ run_eval(engines, "val", val_dl, args)
except Exception as e:
_logger.warning(f"Error occurred while performing eval: {str(e)}")
_logger.warning(traceback.format_exc())
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 7d5a5e7..f8307ca 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -19,7 +19,7 @@ from .train import train
from .utils import get_devices, setup_logging, timer
from .utils.io import json_read, json_stringify
from .emb.qnt import decode_to_wave
-from .data import get_lang_symmap
+from .data import get_lang_symmap, get_random_prompt
tts = None
@@ -284,42 +284,7 @@ def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
while True:
metrics = next(it)
yield metrics
-"""
-
-def get_random_prompt():
- harvard_sentences=[
- "The birch canoe slid on the smooth planks.",
- "Glue the sheet to the dark blue background.",
- "It's easy to tell the depth of a well.",
- "These days a chicken leg is a rare dish.",
- "Rice is often served in round bowls.",
- "The juice of lemons makes fine punch.",
- "The box was thrown beside the parked truck.",
- "The hogs were fed chopped corn and garbage.",
- "Four hours of steady work faced us.",
- "A large size in stockings is hard to sell.",
- "The boy was there when the sun rose.",
- "A rod is used to catch pink salmon.",
- "The source of the huge river is the clear spring.",
- "Kick the ball straight and follow through.",
- "Help the woman get back to her feet.",
- "A pot of tea helps to pass the evening.",
- "Smoky fires lack flame and heat.",
- "The soft cushion broke the man's fall.",
- "The salt breeze came across from the sea.",
- "The girl at the booth sold fifty bonds.",
- "The small pup gnawed a hole in the sock.",
- "The fish twisted and turned on the bent hook.",
- "Press the pants and sew a button on the vest.",
- "The swan dive was far short of perfect.",
- "The beauty of the view stunned the young boy.",
- "Two blue fish swam in the tank.",
- "Her purse was full of useless trash.",
- "The colt reared and threw the tall rider.",
- "It snowed, rained, and hailed the same morning.",
- "Read verse out loud for pleasure.",
- ]
- return random.choice(harvard_sentences)
+"""
# setup args
parser = argparse.ArgumentParser(allow_abbrev=False)