More demo page tweaks; added an argument to force-enable/disable LoRAs for inference (to-do: set up arg flags to handle this, and add a checkbox to the web UI)

This commit is contained in:
mrq 2024-10-10 19:04:12 -05:00
parent 96d05be73c
commit 75a4c866d6
6 changed files with 39 additions and 20 deletions

View File

@ -25,7 +25,6 @@
<tr>
<th>Text</th>
<th>Prompt</th>
<th>Our VALL-E (No LoRA)</th>
<th>Our VALL-E</th>
<th>Ground Truth</th>
</tr>

View File

@ -36,7 +36,7 @@ from tqdm.auto import tqdm
_logger = logging.getLogger(__name__)
@cache
def get_random_prompts( validation=True, length=0, tokenized=False ):
def get_random_prompts( validation=True, min_length=0, min_duration=6, tokenized=False ):
sentences = [
"The birch canoe slid on the smooth planks.",
"Glue the sheet to the dark blue background.",
@ -76,21 +76,27 @@ def get_random_prompts( validation=True, length=0, tokenized=False ):
paths = list(itertools.chain.from_iterable(paths.values()))
for path in paths:
duration = 0
text_string = ""
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
metadata = process_artifact_metadata( { "metadata": metadata } )
text_string = metadata["text"] if "text" in metadata else ""
duration = metadata['duration'] if "duration" in metadata else 0
else:
_, metadata = _load_quants(path, return_metadata=True)
metadata = process_artifact_metadata( { "metadata": metadata } )
text_string = metadata["text"] if "text" in metadata else ""
duration = metadata['duration'] if "duration" in metadata else 0
if len( text_string ) < length:
if len( text_string ) < min_length or duration < min_duration:
continue
sentences.append( text_string )
# tokenize here because our harvard sentences need to be phonemized anyways
if tokenized:
return [ torch.tensor( tokenize( encode_phns( text ) ) ).to(dtype=torch.uint8) for text in sentences ]

View File

@ -19,6 +19,7 @@ import argparse
import base64
import random
import logging
import time
_logger = logging.getLogger(__name__)
@ -42,6 +43,7 @@ def main():
parser.add_argument("--demo-dir", type=Path, default=None)
parser.add_argument("--skip-existing", action="store_true")
parser.add_argument("--dataset-dir-name", type=str, default="dataset")
parser.add_argument("--sample-from-dataset", action="store_true")
parser.add_argument("--skip-loading-dataloader", action="store_true")
parser.add_argument("--dataset-samples", type=int, default=0)
@ -118,8 +120,8 @@ def main():
"librispeech": args.demo_dir / "librispeech",
}
if (args.demo_dir / "dataset").exists():
samples_dirs["dataset"] = args.demo_dir / "dataset"
if (args.demo_dir / args.dataset_dir_name).exists():
samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
# pull from dataset samples
if args.sample_from_dataset:
@ -127,7 +129,7 @@ def main():
cfg.dataset.sample_type = "path" if args.lora else "speaker"
cfg.dataset.tasks_list = [ 'tts' ]
samples_dirs["dataset"] = args.demo_dir / "dataset"
samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
_logger.info("Loading dataloader...")
dataloader = create_train_dataloader()
@ -139,7 +141,7 @@ def main():
for i in trange( num, desc="Sampling dataset for samples" ):
batch = dataloader.dataset[i]
dir = args.demo_dir / "dataset" / f'{i}'
dir = args.demo_dir / args.dataset_dir_name / f'{i}'
(dir / "out").mkdir(parents=True, exist_ok=True)
@ -181,14 +183,19 @@ def main():
extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
if not args.random_prompts:
extra_sources += [ reference ]
samples.append((
text,
[ prompt, out_path ] + extra_sources + [ reference ],
[ prompt, out_path ] + extra_sources,
))
if args.skip_existing and out_path.exists():
continue
seed = args.seed if args.seed else int(time.time())
kwargs = dict(
text=text,
references=[prompt],
@ -202,17 +209,19 @@ def main():
length_penalty=args.length_penalty,
beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
seed=args.seed,
seed=seed,
tqdm=False,
)
if args.lora:
tts.enable_lora()
tts.enable_lora() # I don't think this is necessary with the below
kwargs["use_lora"] = True
try:
    # Render the LoRA-enabled sample to its own output path.
    tts.inference( out_path=out_path_lora, **kwargs )
except Exception as e:
    # Best-effort: keep generating the remaining samples, but report the
    # path this call was actually writing to (the LoRA output, not out_path).
    print(f'Error while processing {out_path_lora}: {e}')
tts.disable_lora()
kwargs["use_lora"] = False
try:
tts.inference( out_path=out_path, **kwargs )
except Exception as e:
@ -233,8 +242,11 @@ def main():
# write audio into template
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
if not args.lora:
html = html.replace("\n\t\t\t\t\t<th>Our VALL-E (No LoRA)</th>", "")
# With a LoRA each sample row carries two generated outputs, so the single
# "Our VALL-E" header in the template is split into "No LoRA" / "LoRA" columns.
if args.lora:
    if args.random_prompts:
        # The reference clip is only appended to a row when random prompts
        # are NOT used, so here the LoRA column takes over the
        # "Ground Truth" header slot instead of adding a new column.
        html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
    else:
        # Insert the LoRA column right after the base column; the template's
        # existing "Ground Truth" header is left in place.
        html = html.replace("<th>Our VALL-E</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
# write demo page
# Path.write_text opens, writes and closes the file in one call, so the
# handle is not left for the garbage collector to close.
(args.demo_dir / "index.html").write_text( html, encoding="utf-8" )

View File

@ -211,6 +211,7 @@ class TTS():
out_path=None,
tqdm=True,
use_lora=None,
):
lines = text.split("\n")
@ -255,6 +256,7 @@ class TTS():
sampling_dry_allowed_length=dry_allowed_length,
disable_tqdm=not tqdm,
use_lora=use_lora,
)
else:
raise Exception("!")
@ -298,6 +300,7 @@ class TTS():
sampling_dry_allowed_length=dry_allowed_length,
disable_tqdm=not tqdm,
use_lora=use_lora,
)
resps_list = model_nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
@ -309,6 +312,7 @@ class TTS():
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
disable_tqdm=not tqdm,
use_lora=use_lora,
)
elif model_len is not None:
len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10, disable_tqdm=not tqdm ) # don't need more than that
@ -320,6 +324,7 @@ class TTS():
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
disable_tqdm=not tqdm,
use_lora=use_lora,
)
else:
raise Exception("!")

View File

@ -58,6 +58,7 @@ class AR(Base):
sampling_dry_allowed_length=2,
disable_tqdm=False,
use_lora=None,
):
device = text_list[0].device
batch_size = len(text_list)
@ -156,10 +157,8 @@ class AR(Base):
)
# is AR
"""
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( 0 ) )
"""
enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool()

View File

@ -65,6 +65,7 @@ class AR_NAR(Base):
sampling_dry_allowed_length=2,
disable_tqdm=False,
use_lora=None,
):
text_task = [ "stt" ]
@ -204,7 +205,7 @@ class AR_NAR(Base):
break
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( level ) )
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
@ -243,14 +244,11 @@ class AR_NAR(Base):
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
if cfg.lora is not None:
enable_lora( self )
return prev_list
# is AR
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( 0 ) )
enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
# STT
start_slice = [ 0 for _ in range(batch_size) ]