more demo page tweaks, added an arg to force enable/disable LoRAs for inferencing (to-do: set up arg flags to handle this, and add a checkbox in the web UI)

This commit is contained in:
mrq 2024-10-10 19:04:12 -05:00
parent 96d05be73c
commit 75a4c866d6
6 changed files with 39 additions and 20 deletions

View File

@ -25,7 +25,6 @@
<tr> <tr>
<th>Text</th> <th>Text</th>
<th>Prompt</th> <th>Prompt</th>
<th>Our VALL-E (No LoRA)</th>
<th>Our VALL-E</th> <th>Our VALL-E</th>
<th>Ground Truth</th> <th>Ground Truth</th>
</tr> </tr>

View File

@ -36,7 +36,7 @@ from tqdm.auto import tqdm
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@cache @cache
def get_random_prompts( validation=True, length=0, tokenized=False ): def get_random_prompts( validation=True, min_length=0, min_duration=6, tokenized=False ):
sentences = [ sentences = [
"The birch canoe slid on the smooth planks.", "The birch canoe slid on the smooth planks.",
"Glue the sheet to the dark blue background.", "Glue the sheet to the dark blue background.",
@ -76,21 +76,27 @@ def get_random_prompts( validation=True, length=0, tokenized=False ):
paths = list(itertools.chain.from_iterable(paths.values())) paths = list(itertools.chain.from_iterable(paths.values()))
for path in paths: for path in paths:
duration = 0
text_string = "" text_string = ""
if cfg.dataset.use_hdf5: if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path) key = _get_hdf5_path(path)
metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() } metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
metadata = process_artifact_metadata( { "metadata": metadata } )
text_string = metadata["text"] if "text" in metadata else "" text_string = metadata["text"] if "text" in metadata else ""
duration = metadata['duration'] if "duration" in metadata else 0
else: else:
_, metadata = _load_quants(path, return_metadata=True) _, metadata = _load_quants(path, return_metadata=True)
metadata = process_artifact_metadata( { "metadata": metadata } )
text_string = metadata["text"] if "text" in metadata else "" text_string = metadata["text"] if "text" in metadata else ""
duration = metadata['duration'] if "duration" in metadata else 0
if len( text_string ) < length: if len( text_string ) < min_length or duration < min_duration:
continue continue
sentences.append( text_string ) sentences.append( text_string )
# tokenize here because our Harvard sentences need to be phonemized anyway
if tokenized: if tokenized:
return [ torch.tensor( tokenize( encode_phns( text ) ) ).to(dtype=torch.uint8) for text in sentences ] return [ torch.tensor( tokenize( encode_phns( text ) ) ).to(dtype=torch.uint8) for text in sentences ]

View File

@ -19,6 +19,7 @@ import argparse
import base64 import base64
import random import random
import logging import logging
import time
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -42,6 +43,7 @@ def main():
parser.add_argument("--demo-dir", type=Path, default=None) parser.add_argument("--demo-dir", type=Path, default=None)
parser.add_argument("--skip-existing", action="store_true") parser.add_argument("--skip-existing", action="store_true")
parser.add_argument("--dataset-dir-name", type=str, default="dataset")
parser.add_argument("--sample-from-dataset", action="store_true") parser.add_argument("--sample-from-dataset", action="store_true")
parser.add_argument("--skip-loading-dataloader", action="store_true") parser.add_argument("--skip-loading-dataloader", action="store_true")
parser.add_argument("--dataset-samples", type=int, default=0) parser.add_argument("--dataset-samples", type=int, default=0)
@ -118,8 +120,8 @@ def main():
"librispeech": args.demo_dir / "librispeech", "librispeech": args.demo_dir / "librispeech",
} }
if (args.demo_dir / "dataset").exists(): if (args.demo_dir / args.dataset_dir_name).exists():
samples_dirs["dataset"] = args.demo_dir / "dataset" samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
# pull from dataset samples # pull from dataset samples
if args.sample_from_dataset: if args.sample_from_dataset:
@ -127,7 +129,7 @@ def main():
cfg.dataset.sample_type = "path" if args.lora else "speaker" cfg.dataset.sample_type = "path" if args.lora else "speaker"
cfg.dataset.tasks_list = [ 'tts' ] cfg.dataset.tasks_list = [ 'tts' ]
samples_dirs["dataset"] = args.demo_dir / "dataset" samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
_logger.info("Loading dataloader...") _logger.info("Loading dataloader...")
dataloader = create_train_dataloader() dataloader = create_train_dataloader()
@ -139,7 +141,7 @@ def main():
for i in trange( num, desc="Sampling dataset for samples" ): for i in trange( num, desc="Sampling dataset for samples" ):
batch = dataloader.dataset[i] batch = dataloader.dataset[i]
dir = args.demo_dir / "dataset" / f'{i}' dir = args.demo_dir / args.dataset_dir_name / f'{i}'
(dir / "out").mkdir(parents=True, exist_ok=True) (dir / "out").mkdir(parents=True, exist_ok=True)
@ -181,14 +183,19 @@ def main():
extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else []) extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
if not args.random_prompts:
extra_sources += [ reference ]
samples.append(( samples.append((
text, text,
[ prompt, out_path ] + extra_sources + [ reference ], [ prompt, out_path ] + extra_sources,
)) ))
if args.skip_existing and out_path.exists(): if args.skip_existing and out_path.exists():
continue continue
seed = args.seed if args.seed else int(time.time())
kwargs = dict( kwargs = dict(
text=text, text=text,
references=[prompt], references=[prompt],
@ -202,17 +209,19 @@ def main():
length_penalty=args.length_penalty, length_penalty=args.length_penalty,
beam_width=args.beam_width, beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta, mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
seed=args.seed, seed=seed,
tqdm=False, tqdm=False,
) )
if args.lora: if args.lora:
tts.enable_lora() tts.enable_lora() # I don't think this is necessary with the below
kwargs["use_lora"] = True
try: try:
tts.inference( out_path=out_path_lora, **kwargs ) tts.inference( out_path=out_path_lora, **kwargs )
except Exception as e: except Exception as e:
print(f'Error while processing {out_path}: {e}') print(f'Error while processing {out_path}: {e}')
tts.disable_lora() tts.disable_lora()
kwargs["use_lora"] = False
try: try:
tts.inference( out_path=out_path, **kwargs ) tts.inference( out_path=out_path, **kwargs )
except Exception as e: except Exception as e:
@ -233,8 +242,11 @@ def main():
# write audio into template # write audio into template
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) ) html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
if not args.lora: if args.lora:
html = html.replace("\n\t\t\t\t\t<th>Our VALL-E (No LoRA)</th>", "") if args.random_prompts:
html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
else:
html = html.replace("<th>Our VALL-E</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
# write demo page # write demo page
open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html ) open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )

View File

@ -211,6 +211,7 @@ class TTS():
out_path=None, out_path=None,
tqdm=True, tqdm=True,
use_lora=None,
): ):
lines = text.split("\n") lines = text.split("\n")
@ -255,6 +256,7 @@ class TTS():
sampling_dry_allowed_length=dry_allowed_length, sampling_dry_allowed_length=dry_allowed_length,
disable_tqdm=not tqdm, disable_tqdm=not tqdm,
use_lora=use_lora,
) )
else: else:
raise Exception("!") raise Exception("!")
@ -298,6 +300,7 @@ class TTS():
sampling_dry_allowed_length=dry_allowed_length, sampling_dry_allowed_length=dry_allowed_length,
disable_tqdm=not tqdm, disable_tqdm=not tqdm,
use_lora=use_lora,
) )
resps_list = model_nar( resps_list = model_nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
@ -309,6 +312,7 @@ class TTS():
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
disable_tqdm=not tqdm, disable_tqdm=not tqdm,
use_lora=use_lora,
) )
elif model_len is not None: elif model_len is not None:
len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10, disable_tqdm=not tqdm ) # don't need more than that len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10, disable_tqdm=not tqdm ) # don't need more than that
@ -320,6 +324,7 @@ class TTS():
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
disable_tqdm=not tqdm, disable_tqdm=not tqdm,
use_lora=use_lora,
) )
else: else:
raise Exception("!") raise Exception("!")

View File

@ -58,6 +58,7 @@ class AR(Base):
sampling_dry_allowed_length=2, sampling_dry_allowed_length=2,
disable_tqdm=False, disable_tqdm=False,
use_lora=None,
): ):
device = text_list[0].device device = text_list[0].device
batch_size = len(text_list) batch_size = len(text_list)
@ -156,10 +157,8 @@ class AR(Base):
) )
# is AR # is AR
"""
if cfg.lora is not None: if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( 0 ) ) enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
"""
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ] sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool() stopped = torch.zeros(batch_size, device=device).bool()

View File

@ -65,6 +65,7 @@ class AR_NAR(Base):
sampling_dry_allowed_length=2, sampling_dry_allowed_length=2,
disable_tqdm=False, disable_tqdm=False,
use_lora=None,
): ):
text_task = [ "stt" ] text_task = [ "stt" ]
@ -204,7 +205,7 @@ class AR_NAR(Base):
break break
if cfg.lora is not None: if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( level ) ) enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level) quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
@ -243,14 +244,11 @@ class AR_NAR(Base):
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ] prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
if cfg.lora is not None:
enable_lora( self )
return prev_list return prev_list
# is AR # is AR
if cfg.lora is not None: if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( 0 ) ) enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
# STT # STT
start_slice = [ 0 for _ in range(batch_size) ] start_slice = [ 0 for _ in range(batch_size) ]