more demo page tweaks; added an arg to force-enable/disable LoRAs for inference (to-do: set up arg flags to handle this, and a checkbox in the web UI)
commit 75a4c866d6 (parent 96d05be73c)
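For the to-do (proper arg flags for forcing LoRAs on or off), one hedged option is a tri-state CLI flag that stays None unless explicitly set, matching the use_lora=None convention this commit introduces; the --force-lora flag name is hypothetical, not part of this commit:

import argparse

parser = argparse.ArgumentParser()
# Absent -> None (defer to cfg.lora), --force-lora -> True, --no-force-lora -> False.
# argparse.BooleanOptionalAction requires Python 3.9+.
parser.add_argument("--force-lora", action=argparse.BooleanOptionalAction, default=None)
args = parser.parse_args()

# Later, thread the tri-state value through to inference:
# kwargs["use_lora"] = args.force_lora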
@@ -25,7 +25,6 @@
 				<tr>
 					<th>Text</th>
 					<th>Prompt</th>
-					<th>Our VALL-E (No LoRA)</th>
 					<th>Our VALL-E</th>
 					<th>Ground Truth</th>
 				</tr>
@@ -36,7 +36,7 @@ from tqdm.auto import tqdm
 _logger = logging.getLogger(__name__)
 
 @cache
-def get_random_prompts( validation=True, length=0, tokenized=False ):
+def get_random_prompts( validation=True, min_length=0, min_duration=6, tokenized=False ):
 	sentences = [
 		"The birch canoe slid on the smooth planks.",
 		"Glue the sheet to the dark blue background.",

@@ -76,21 +76,27 @@ def get_random_prompts( validation=True, length=0, tokenized=False ):
 		paths = list(itertools.chain.from_iterable(paths.values()))
 
 		for path in paths:
+			duration = 0
 			text_string = ""
 			if cfg.dataset.use_hdf5:
 				key = _get_hdf5_path(path)
 
 				metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
+				metadata = process_artifact_metadata( { "metadata": metadata } )
 				text_string = metadata["text"] if "text" in metadata else ""
+				duration = metadata['duration'] if "duration" in metadata else 0
 			else:
 				_, metadata = _load_quants(path, return_metadata=True)
+				metadata = process_artifact_metadata( { "metadata": metadata } )
 				text_string = metadata["text"] if "text" in metadata else ""
+				duration = metadata['duration'] if "duration" in metadata else 0
 
-			if len( text_string ) < length:
+			if len( text_string ) < min_length or duration < min_duration:
 				continue
 
 			sentences.append( text_string )
 
 	# tokenize here because our harvard sentences need to be phonemized anyways
 	if tokenized:
 		return [ torch.tensor( tokenize( encode_phns( text ) ) ).to(dtype=torch.uint8) for text in sentences ]
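A quick usage sketch of the updated helper above; the argument values are illustrative, and the import path is assumed:

# Pull validation-split prompts that are at least 32 characters and
# 6 seconds long, phonemized and tokenized into uint8 tensors.
prompts = get_random_prompts( validation=True, min_length=32, min_duration=6, tokenized=True )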
@@ -19,6 +19,7 @@ import argparse
 import base64
+import random
 import logging
 import time
 
 _logger = logging.getLogger(__name__)
 

@@ -42,6 +43,7 @@ def main():
 	parser.add_argument("--demo-dir", type=Path, default=None)
 	parser.add_argument("--skip-existing", action="store_true")
+	parser.add_argument("--dataset-dir-name", type=str, default="dataset")
 	parser.add_argument("--sample-from-dataset", action="store_true")
 	parser.add_argument("--skip-loading-dataloader", action="store_true")
 	parser.add_argument("--dataset-samples", type=int, default=0)

@@ -118,8 +120,8 @@ def main():
 		"librispeech": args.demo_dir / "librispeech",
 	}
 
-	if (args.demo_dir / "dataset").exists():
-		samples_dirs["dataset"] = args.demo_dir / "dataset"
+	if (args.demo_dir / args.dataset_dir_name).exists():
+		samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
 
 	# pull from dataset samples
 	if args.sample_from_dataset:

@@ -127,7 +129,7 @@ def main():
 		cfg.dataset.sample_type = "path" if args.lora else "speaker"
 		cfg.dataset.tasks_list = [ 'tts' ]
 
-		samples_dirs["dataset"] = args.demo_dir / "dataset"
+		samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
 
 		_logger.info("Loading dataloader...")
 		dataloader = create_train_dataloader()

@@ -139,7 +141,7 @@ def main():
 		for i in trange( num, desc="Sampling dataset for samples" ):
 			batch = dataloader.dataset[i]
 
-			dir = args.demo_dir / "dataset" / f'{i}'
+			dir = args.demo_dir / args.dataset_dir_name / f'{i}'
 
 			(dir / "out").mkdir(parents=True, exist_ok=True)

@@ -181,14 +183,19 @@ def main():
 
 			extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
 
+			if not args.random_prompts:
+				extra_sources += [ reference ]
+
 			samples.append((
 				text,
-				[ prompt, out_path ] + extra_sources + [ reference ],
+				[ prompt, out_path ] + extra_sources,
 			))
 
 			if args.skip_existing and out_path.exists():
 				continue
 
+			seed = args.seed if args.seed else int(time.time())
+
 			kwargs = dict(
 				text=text,
 				references=[prompt],

@@ -202,17 +209,19 @@ def main():
 				length_penalty=args.length_penalty,
 				beam_width=args.beam_width,
 				mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
-				seed=args.seed,
+				seed=seed,
 				tqdm=False,
 			)
 
 			if args.lora:
-				tts.enable_lora()
+				tts.enable_lora() # I don't think this is necessary with the below
+				kwargs["use_lora"] = True
 				try:
 					tts.inference( out_path=out_path_lora, **kwargs )
 				except Exception as e:
 					print(f'Error while processing {out_path}: {e}')
 				tts.disable_lora()
+				kwargs["use_lora"] = False
 			try:
 				tts.inference( out_path=out_path, **kwargs )
 			except Exception as e:

@@ -233,8 +242,11 @@ def main():
 		# write audio into template
 		html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
 
-		if not args.lora:
-			html = html.replace("\n\t\t\t\t\t<th>Our VALL-E (No LoRA)</th>", "")
+		if args.lora:
+			if args.random_prompts:
+				html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
+			else:
+				html = html.replace("<th>Our VALL-E</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
 
 		# write demo page
 		open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )
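Distilled from the loop above, the dual-pass flow when --lora is set; a sketch reusing tts, kwargs, and the output paths from the diff, with the try/except wrappers omitted:

if args.lora:
	kwargs["use_lora"] = True   # force the adapter on for the LoRA column
	tts.inference( out_path=out_path_lora, **kwargs )
	kwargs["use_lora"] = False  # force base weights for the plain column

tts.inference( out_path=out_path, **kwargs )

With the explicit True/False override, the paired enable_lora()/disable_lora() calls become redundant, which is what the "I don't think this is necessary with the below" comment is noting.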
@@ -211,6 +211,7 @@ class TTS():
 		out_path=None,
 
 		tqdm=True,
+		use_lora=None,
 	):
 		lines = text.split("\n")
 

@@ -255,6 +256,7 @@ class TTS():
 					sampling_dry_allowed_length=dry_allowed_length,
 
 					disable_tqdm=not tqdm,
+					use_lora=use_lora,
 				)
 			else:
 				raise Exception("!")

@@ -298,6 +300,7 @@ class TTS():
 					sampling_dry_allowed_length=dry_allowed_length,
 
 					disable_tqdm=not tqdm,
+					use_lora=use_lora,
 				)
 				resps_list = model_nar(
 					text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,

@@ -309,6 +312,7 @@ class TTS():
 					sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
 
 					disable_tqdm=not tqdm,
+					use_lora=use_lora,
 				)
 			elif model_len is not None:
 				len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10, disable_tqdm=not tqdm ) # don't need more than that

@@ -320,6 +324,7 @@ class TTS():
 					sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
 
 					disable_tqdm=not tqdm,
+					use_lora=use_lora,
 				)
 			else:
 				raise Exception("!")
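A hedged usage sketch of the new use_lora parameter on TTS.inference; the prompt path and output path are hypothetical, and only keywords visible in the diff are assumed:

# use_lora=None (the default) defers to cfg.lora; an explicit bool forces it.
tts.inference(
	text="The birch canoe slid on the smooth planks.",
	references=["./prompt.wav"],   # hypothetical reference clip
	out_path="./out/no-lora.wav",  # hypothetical output path
	tqdm=False,
	use_lora=False,                # run on the base weights only
)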
@@ -58,6 +58,7 @@ class AR(Base):
 		sampling_dry_allowed_length=2,
 
 		disable_tqdm=False,
+		use_lora=None,
 	):
 		device = text_list[0].device
 		batch_size = len(text_list)

@@ -156,10 +157,8 @@ class AR(Base):
 			)
 
 		# is AR
-		"""
-		if cfg.lora is not None:
-			enable_lora( self, cfg.lora.active_level( 0 ) )
-		"""
+		enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
 
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
 		stopped = torch.zeros(batch_size, device=device).bool()
@@ -65,6 +65,7 @@ class AR_NAR(Base):
 		sampling_dry_allowed_length=2,
 
 		disable_tqdm=False,
+		use_lora=None,
 	):
 		text_task = [ "stt" ]
 

@@ -204,7 +205,7 @@ class AR_NAR(Base):
 					break
 
 			if cfg.lora is not None:
-				enable_lora( self, cfg.lora.active_level( level ) )
+				enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
 
 			quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
 

@@ -243,14 +244,11 @@ class AR_NAR(Base):
 
 				prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
 
-			if cfg.lora is not None:
-				enable_lora( self )
 
 			return prev_list
 
 		# is AR
-		if cfg.lora is not None:
-			enable_lora( self, cfg.lora.active_level( 0 ) )
+		enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
 
 		# STT
 		start_slice = [ 0 for _ in range(batch_size) ]
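The same tri-state idiom now appears in both AR and AR_NAR; as a standalone sketch (the resolve_lora helper name is hypothetical; enable_lora and cfg as used in the diffs above):

def resolve_lora( model, use_lora=None, level=0 ):
	# None keeps the config-driven behavior; an explicit True/False wins.
	if cfg.lora is None:
		return
	enable_lora( model, cfg.lora.active_level( level ) if use_lora is None else use_lora )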