More demo page tweaks; added an argument to force-enable/disable LoRAs during inference (to-do: set up arg flags to handle this, and add a checkbox to the web UI)
This commit is contained in:
parent
96d05be73c
commit
75a4c866d6
|
@ -25,7 +25,6 @@
|
||||||
<tr>
|
<tr>
|
||||||
<th>Text</th>
|
<th>Text</th>
|
||||||
<th>Prompt</th>
|
<th>Prompt</th>
|
||||||
<th>Our VALL-E (No LoRA)</th>
|
|
||||||
<th>Our VALL-E</th>
|
<th>Our VALL-E</th>
|
||||||
<th>Ground Truth</th>
|
<th>Ground Truth</th>
|
||||||
</tr>
|
</tr>
|
||||||
|
|
|
@ -36,7 +36,7 @@ from tqdm.auto import tqdm
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def get_random_prompts( validation=True, length=0, tokenized=False ):
|
def get_random_prompts( validation=True, min_length=0, min_duration=6, tokenized=False ):
|
||||||
sentences = [
|
sentences = [
|
||||||
"The birch canoe slid on the smooth planks.",
|
"The birch canoe slid on the smooth planks.",
|
||||||
"Glue the sheet to the dark blue background.",
|
"Glue the sheet to the dark blue background.",
|
||||||
|
@ -76,21 +76,27 @@ def get_random_prompts( validation=True, length=0, tokenized=False ):
|
||||||
paths = list(itertools.chain.from_iterable(paths.values()))
|
paths = list(itertools.chain.from_iterable(paths.values()))
|
||||||
|
|
||||||
for path in paths:
|
for path in paths:
|
||||||
|
duration = 0
|
||||||
text_string = ""
|
text_string = ""
|
||||||
if cfg.dataset.use_hdf5:
|
if cfg.dataset.use_hdf5:
|
||||||
key = _get_hdf5_path(path)
|
key = _get_hdf5_path(path)
|
||||||
|
|
||||||
metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
|
metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
|
||||||
|
metadata = process_artifact_metadata( { "metadata": metadata } )
|
||||||
text_string = metadata["text"] if "text" in metadata else ""
|
text_string = metadata["text"] if "text" in metadata else ""
|
||||||
|
duration = metadata['duration'] if "duration" in metadata else 0
|
||||||
else:
|
else:
|
||||||
_, metadata = _load_quants(path, return_metadata=True)
|
_, metadata = _load_quants(path, return_metadata=True)
|
||||||
|
metadata = process_artifact_metadata( { "metadata": metadata } )
|
||||||
text_string = metadata["text"] if "text" in metadata else ""
|
text_string = metadata["text"] if "text" in metadata else ""
|
||||||
|
duration = metadata['duration'] if "duration" in metadata else 0
|
||||||
|
|
||||||
if len( text_string ) < length:
|
if len( text_string ) < min_length or duration < min_duration:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
sentences.append( text_string )
|
sentences.append( text_string )
|
||||||
|
|
||||||
|
# tokenize here because our harvard sentences need to be phonemized anyways
|
||||||
if tokenized:
|
if tokenized:
|
||||||
return [ torch.tensor( tokenize( encode_phns( text ) ) ).to(dtype=torch.uint8) for text in sentences ]
|
return [ torch.tensor( tokenize( encode_phns( text ) ) ).to(dtype=torch.uint8) for text in sentences ]
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ import argparse
|
||||||
import base64
|
import base64
|
||||||
import random
|
import random
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -42,6 +43,7 @@ def main():
|
||||||
|
|
||||||
parser.add_argument("--demo-dir", type=Path, default=None)
|
parser.add_argument("--demo-dir", type=Path, default=None)
|
||||||
parser.add_argument("--skip-existing", action="store_true")
|
parser.add_argument("--skip-existing", action="store_true")
|
||||||
|
parser.add_argument("--dataset-dir-name", type=str, default="dataset")
|
||||||
parser.add_argument("--sample-from-dataset", action="store_true")
|
parser.add_argument("--sample-from-dataset", action="store_true")
|
||||||
parser.add_argument("--skip-loading-dataloader", action="store_true")
|
parser.add_argument("--skip-loading-dataloader", action="store_true")
|
||||||
parser.add_argument("--dataset-samples", type=int, default=0)
|
parser.add_argument("--dataset-samples", type=int, default=0)
|
||||||
|
@ -118,8 +120,8 @@ def main():
|
||||||
"librispeech": args.demo_dir / "librispeech",
|
"librispeech": args.demo_dir / "librispeech",
|
||||||
}
|
}
|
||||||
|
|
||||||
if (args.demo_dir / "dataset").exists():
|
if (args.demo_dir / args.dataset_dir_name).exists():
|
||||||
samples_dirs["dataset"] = args.demo_dir / "dataset"
|
samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
|
||||||
|
|
||||||
# pull from dataset samples
|
# pull from dataset samples
|
||||||
if args.sample_from_dataset:
|
if args.sample_from_dataset:
|
||||||
|
@ -127,7 +129,7 @@ def main():
|
||||||
cfg.dataset.sample_type = "path" if args.lora else "speaker"
|
cfg.dataset.sample_type = "path" if args.lora else "speaker"
|
||||||
cfg.dataset.tasks_list = [ 'tts' ]
|
cfg.dataset.tasks_list = [ 'tts' ]
|
||||||
|
|
||||||
samples_dirs["dataset"] = args.demo_dir / "dataset"
|
samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
|
||||||
|
|
||||||
_logger.info("Loading dataloader...")
|
_logger.info("Loading dataloader...")
|
||||||
dataloader = create_train_dataloader()
|
dataloader = create_train_dataloader()
|
||||||
|
@ -139,7 +141,7 @@ def main():
|
||||||
for i in trange( num, desc="Sampling dataset for samples" ):
|
for i in trange( num, desc="Sampling dataset for samples" ):
|
||||||
batch = dataloader.dataset[i]
|
batch = dataloader.dataset[i]
|
||||||
|
|
||||||
dir = args.demo_dir / "dataset" / f'{i}'
|
dir = args.demo_dir / args.dataset_dir_name / f'{i}'
|
||||||
|
|
||||||
(dir / "out").mkdir(parents=True, exist_ok=True)
|
(dir / "out").mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
@ -181,14 +183,19 @@ def main():
|
||||||
|
|
||||||
extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
|
extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
|
||||||
|
|
||||||
|
if not args.random_prompts:
|
||||||
|
extra_sources += [ reference ]
|
||||||
|
|
||||||
samples.append((
|
samples.append((
|
||||||
text,
|
text,
|
||||||
[ prompt, out_path ] + extra_sources + [ reference ],
|
[ prompt, out_path ] + extra_sources,
|
||||||
))
|
))
|
||||||
|
|
||||||
if args.skip_existing and out_path.exists():
|
if args.skip_existing and out_path.exists():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
seed = args.seed if args.seed else int(time.time())
|
||||||
|
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
text=text,
|
text=text,
|
||||||
references=[prompt],
|
references=[prompt],
|
||||||
|
@ -202,17 +209,19 @@ def main():
|
||||||
length_penalty=args.length_penalty,
|
length_penalty=args.length_penalty,
|
||||||
beam_width=args.beam_width,
|
beam_width=args.beam_width,
|
||||||
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
||||||
seed=args.seed,
|
seed=seed,
|
||||||
tqdm=False,
|
tqdm=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.lora:
|
if args.lora:
|
||||||
tts.enable_lora()
|
tts.enable_lora() # I don't think this is necessary with the below
|
||||||
|
kwargs["use_lora"] = True
|
||||||
try:
|
try:
|
||||||
tts.inference( out_path=out_path_lora, **kwargs )
|
tts.inference( out_path=out_path_lora, **kwargs )
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'Error while processing {out_path}: {e}')
|
print(f'Error while processing {out_path}: {e}')
|
||||||
tts.disable_lora()
|
tts.disable_lora()
|
||||||
|
kwargs["use_lora"] = False
|
||||||
try:
|
try:
|
||||||
tts.inference( out_path=out_path, **kwargs )
|
tts.inference( out_path=out_path, **kwargs )
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -233,8 +242,11 @@ def main():
|
||||||
# write audio into template
|
# write audio into template
|
||||||
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
|
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
|
||||||
|
|
||||||
if not args.lora:
|
if args.lora:
|
||||||
html = html.replace("\n\t\t\t\t\t<th>Our VALL-E (No LoRA)</th>", "")
|
if args.random_prompts:
|
||||||
|
html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
|
||||||
|
else:
|
||||||
|
html = html.replace("<th>Our VALL-E</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
|
||||||
|
|
||||||
# write demo page
|
# write demo page
|
||||||
open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )
|
open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )
|
||||||
|
|
|
@ -211,6 +211,7 @@ class TTS():
|
||||||
out_path=None,
|
out_path=None,
|
||||||
|
|
||||||
tqdm=True,
|
tqdm=True,
|
||||||
|
use_lora=None,
|
||||||
):
|
):
|
||||||
lines = text.split("\n")
|
lines = text.split("\n")
|
||||||
|
|
||||||
|
@ -255,6 +256,7 @@ class TTS():
|
||||||
sampling_dry_allowed_length=dry_allowed_length,
|
sampling_dry_allowed_length=dry_allowed_length,
|
||||||
|
|
||||||
disable_tqdm=not tqdm,
|
disable_tqdm=not tqdm,
|
||||||
|
use_lora=use_lora,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception("!")
|
raise Exception("!")
|
||||||
|
@ -298,6 +300,7 @@ class TTS():
|
||||||
sampling_dry_allowed_length=dry_allowed_length,
|
sampling_dry_allowed_length=dry_allowed_length,
|
||||||
|
|
||||||
disable_tqdm=not tqdm,
|
disable_tqdm=not tqdm,
|
||||||
|
use_lora=use_lora,
|
||||||
)
|
)
|
||||||
resps_list = model_nar(
|
resps_list = model_nar(
|
||||||
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
|
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
|
||||||
|
@ -309,6 +312,7 @@ class TTS():
|
||||||
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||||
|
|
||||||
disable_tqdm=not tqdm,
|
disable_tqdm=not tqdm,
|
||||||
|
use_lora=use_lora,
|
||||||
)
|
)
|
||||||
elif model_len is not None:
|
elif model_len is not None:
|
||||||
len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10, disable_tqdm=not tqdm ) # don't need more than that
|
len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10, disable_tqdm=not tqdm ) # don't need more than that
|
||||||
|
@ -320,6 +324,7 @@ class TTS():
|
||||||
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||||
|
|
||||||
disable_tqdm=not tqdm,
|
disable_tqdm=not tqdm,
|
||||||
|
use_lora=use_lora,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception("!")
|
raise Exception("!")
|
||||||
|
|
|
@ -58,6 +58,7 @@ class AR(Base):
|
||||||
sampling_dry_allowed_length=2,
|
sampling_dry_allowed_length=2,
|
||||||
|
|
||||||
disable_tqdm=False,
|
disable_tqdm=False,
|
||||||
|
use_lora=None,
|
||||||
):
|
):
|
||||||
device = text_list[0].device
|
device = text_list[0].device
|
||||||
batch_size = len(text_list)
|
batch_size = len(text_list)
|
||||||
|
@ -156,10 +157,8 @@ class AR(Base):
|
||||||
)
|
)
|
||||||
|
|
||||||
# is AR
|
# is AR
|
||||||
"""
|
|
||||||
if cfg.lora is not None:
|
if cfg.lora is not None:
|
||||||
enable_lora( self, cfg.lora.active_level( 0 ) )
|
enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
|
||||||
"""
|
|
||||||
|
|
||||||
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
|
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
|
||||||
stopped = torch.zeros(batch_size, device=device).bool()
|
stopped = torch.zeros(batch_size, device=device).bool()
|
||||||
|
|
|
@ -65,6 +65,7 @@ class AR_NAR(Base):
|
||||||
sampling_dry_allowed_length=2,
|
sampling_dry_allowed_length=2,
|
||||||
|
|
||||||
disable_tqdm=False,
|
disable_tqdm=False,
|
||||||
|
use_lora=None,
|
||||||
):
|
):
|
||||||
text_task = [ "stt" ]
|
text_task = [ "stt" ]
|
||||||
|
|
||||||
|
@ -204,7 +205,7 @@ class AR_NAR(Base):
|
||||||
break
|
break
|
||||||
|
|
||||||
if cfg.lora is not None:
|
if cfg.lora is not None:
|
||||||
enable_lora( self, cfg.lora.active_level( level ) )
|
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
|
||||||
|
|
||||||
quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
|
quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
|
||||||
|
|
||||||
|
@ -243,14 +244,11 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
|
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
|
||||||
|
|
||||||
if cfg.lora is not None:
|
|
||||||
enable_lora( self )
|
|
||||||
|
|
||||||
return prev_list
|
return prev_list
|
||||||
|
|
||||||
# is AR
|
# is AR
|
||||||
if cfg.lora is not None:
|
if cfg.lora is not None:
|
||||||
enable_lora( self, cfg.lora.active_level( 0 ) )
|
enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
|
||||||
|
|
||||||
# STT
|
# STT
|
||||||
start_slice = [ 0 for _ in range(batch_size) ]
|
start_slice = [ 0 for _ in range(batch_size) ]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user