added --eval-random-text-prompts to use random text prompts for the eval pass, added --random-prompts for the demo page, added --lora to also include a sample with the LoRA disabled, and probably finally fixed the validation dataloader breaking on eval
parent 52299127ab
commit 2ea978f318
@@ -96,6 +96,10 @@ You can enter `save` to save the state at any time, or `quit` to save and quit training.
 
 The `lr` command will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
 
+Some additional flags can be passed as well:
+* `--eval`: only run the evaluation / validation pass, then exit afterwards.
+* `--eval-random-text-prompts`: use random text prompts for the evaluation pass, rather than the provided text prompts in the dataset.
+
 ### Finetuning
 
 Finetuning can be done by training the full model, or using a LoRA.
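For orientation, a condensed sketch (assembled from the train.py hunks later in this commit) of how `--eval-random-text-prompts` is parsed and consumed; this is not a standalone script:

```python
# condensed from the train.py changes in this commit
import argparse

parser = argparse.ArgumentParser("VALL-E TTS")
parser.add_argument("--eval", action="store_true", default=None)
parser.add_argument("--eval-random-text-prompts", action="store_true", default=None)
args, unknown = parser.parse_known_args()

# args is then passed into run_eval(); for the "subtrain" split, each batch's
# text is swapped for a random tokenized prompt and its reference audio dropped:
#   batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
#   batch["resps"][i] = None
```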
@@ -35,6 +35,72 @@ from tqdm.auto import tqdm
 
 _logger = logging.getLogger(__name__)
 
+@cache
+def get_random_prompts( validation=True, length=0, tokenized=False ):
+    sentences = [
+        "The birch canoe slid on the smooth planks.",
+        "Glue the sheet to the dark blue background.",
+        "It's easy to tell the depth of a well.",
+        "These days a chicken leg is a rare dish.",
+        "Rice is often served in round bowls.",
+        "The juice of lemons makes fine punch.",
+        "The box was thrown beside the parked truck.",
+        "The hogs were fed chopped corn and garbage.",
+        "Four hours of steady work faced us.",
+        "A large size in stockings is hard to sell.",
+        "The boy was there when the sun rose.",
+        "A rod is used to catch pink salmon.",
+        "The source of the huge river is the clear spring.",
+        "Kick the ball straight and follow through.",
+        "Help the woman get back to her feet.",
+        "A pot of tea helps to pass the evening.",
+        "Smoky fires lack flame and heat.",
+        "The soft cushion broke the man's fall.",
+        "The salt breeze came across from the sea.",
+        "The girl at the booth sold fifty bonds.",
+        "The small pup gnawed a hole in the sock.",
+        "The fish twisted and turned on the bent hook.",
+        "Press the pants and sew a button on the vest.",
+        "The swan dive was far short of perfect.",
+        "The beauty of the view stunned the young boy.",
+        "Two blue fish swam in the tank.",
+        "Her purse was full of useless trash.",
+        "The colt reared and threw the tall rider.",
+        "It snowed, rained, and hailed the same morning.",
+        "Read verse out loud for pleasure.",
+    ]
+
+    # Pull from validation dataset if existing + requested
+    if validation and cfg.dataset.validation:
+        paths = _load_paths(cfg.dataset.validation, type="validation")
+        paths = list(itertools.chain.from_iterable(paths.values()))
+
+        for path in paths:
+            text_string = ""
+            if cfg.dataset.use_hdf5:
+                key = _get_hdf5_path(path)
+
+                metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
+                text_string = metadata["text"] if "text" in metadata else ""
+            else:
+                _, metadata = _load_quants(path, return_metadata=True)
+                text_string = metadata["text"] if "text" in metadata else ""
+
+            if len( text_string ) < length:
+                continue
+
+            sentences.append( text_string )
+
+    if tokenized:
+        return [ torch.tensor( tokenize( encode_phns( text ) ) ).to(dtype=torch.uint8) for text in sentences ]
+
+    return sentences
+
+# samples a random text prompt
+def get_random_prompt( *args, **kwargs ):
+    # Harvard sentences
+    return random.choice(get_random_prompts( *args, **kwargs ))
+
 # fold into a typical LLM sequence (one embedding rather than split embeddings)
 def fold_inputs(
     text_list = [],
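A minimal usage sketch of the new helpers; the `vall_e.data` import path is an assumption inferred from the relative imports elsewhere in this diff:

```python
# a minimal sketch, assuming the helpers above land in vall_e/data.py
from vall_e.data import get_random_prompt, get_random_prompts

text = get_random_prompt()                  # one random Harvard sentence or validation transcript
tokens = get_random_prompt(tokenized=True)  # phonemized + tokenized, as a torch.uint8 tensor

# get_random_prompts() is wrapped in @cache, so the sentence pool (including any
# transcripts harvested from the validation set) is built once per argument combination
pool = get_random_prompts(validation=False, length=0)
```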
@@ -718,7 +784,7 @@ class Dataset(_Dataset):
         if len(self.paths) == 0:
             raise ValueError(f"No valid path is found for {self.dataset_type}")
 
-        if self.sampler_type == "path":
+        if self.sampler_type == "path" and self.training:
             if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
                 self.sampler = BatchedOrderedSampler(
                     self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
@@ -735,8 +801,6 @@ class Dataset(_Dataset):
             self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
             self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
-        # loading validation state dict causes issues
-        if self.dataset_type != "validation":
             self.load_state_dict()
 
     @cached_property
@@ -789,6 +853,9 @@ class Dataset(_Dataset):
         torch_save(state_dict, path)
 
     def load_state_dict(self, path = None):
+        if not self.training:
+            return
+
         if path is None:
             path = self.sampler_state_dict_path
 
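Taken together, the three `Dataset` hunks above are the "probably finally fixed the validation dataloader breaking on eval" part of the commit message: path-type samplers are now only constructed for training datasets, the old `dataset_type != "validation"` guard is dropped, and `load_state_dict()` itself no-ops for non-training datasets, so the validation dataloader never attempts to restore sampler state.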
@@ -26,7 +26,7 @@ from pathlib import Path
 
 from .inference import TTS
 from .config import cfg
-from .data import create_train_dataloader, create_val_dataloader
+from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
 from .emb.qnt import decode_to_file
 
 from tqdm import tqdm, trange
@@ -79,6 +79,9 @@ def main():
     parser.add_argument("--amp", action="store_true")
     parser.add_argument("--dtype", type=str, default=None)
 
+    parser.add_argument("--random-prompts", action="store_true")
+    parser.add_argument("--lora", action="store_true")
+
     args = parser.parse_args()
 
     tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
@@ -142,7 +145,7 @@ def main():
 
             metadata = batch["metadata"]
 
-            text = metadata["text"]
+            text = get_random_prompt() if args.random_prompts else metadata["text"]
             language = metadata["language"].lower()
 
             prompt = dir / "prompt.wav"
@@ -174,8 +177,9 @@ def main():
             prompt = dir / "prompt.wav"
             reference = dir / "reference.wav"
             out_path = dir / "out" / "ours.wav"
+            out_path_lora = dir / "out" / "ours_lora.wav"
 
-            extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else []
+            extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
 
             samples.append((
                 text,
@@ -185,12 +189,10 @@ def main():
             if args.skip_existing and out_path.exists():
                 continue
 
-            try:
-                tts.inference(
+            kwargs = dict(
                 text=text,
                 references=[prompt],
                 language=language,
-                out_path=out_path,
                 input_prompt_length=args.input_prompt_length,
                 max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
                 ar_temp=args.ar_temp, nar_temp=args.nar_temp,
@@ -203,8 +205,19 @@ def main():
                 seed=args.seed,
                 tqdm=False,
             )
+
+            if args.lora:
+                tts.enable_lora()
+                try:
+                    tts.inference( out_path=out_path_lora, **kwargs )
                 except Exception as e:
                     print(f'Error while processing {out_path}: {e}')
+                tts.disable_lora()
+
+            try:
+                tts.inference( out_path=out_path, **kwargs )
+            except Exception as e:
+                print(f'Error while processing {out_path}: {e}')
 
     # collate entries into HTML
     samples = [
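Design note on the hunk above: the inference arguments are now built once as `kwargs = dict(...)` and reused for both passes, so the LoRA sample and the base sample differ only in the LoRA toggle and `out_path`. One small wrinkle carried over from the original code: the unchanged `except` branch now serves the LoRA pass, so its error message still prints `out_path` rather than `out_path_lora`.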
@@ -221,6 +221,7 @@ def main():
     parser = argparse.ArgumentParser()
 
     parser.add_argument("--input-speaker", type=Path, default=None)
+    parser.add_argument("--input-voice", type=str, default=None)
     parser.add_argument("--use-dataset", action="store_true")
 
     parser.add_argument("--yaml", type=Path)
@@ -255,6 +256,9 @@ def main():
         speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
         """
 
+        if args.input_voice and speaker_name != args.input_voice:
+            return
+
         metadata_path = cfg.metadata_dir / f'{speaker_name}.json'
         metadata = json_read( metadata_path, default={} )
         metadata_keys = list(metadata.keys()) if metadata else []
@@ -16,6 +16,7 @@ from .utils import to_device, set_seed, wrapper as ml
 
 from .config import cfg, Config
 from .models import get_models
+from .models.lora import enable_lora
 from .engines import load_engines, deepspeed_available
 from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize
 
@@ -77,6 +78,13 @@ class TTS():
         self.symmap = get_phone_symmap()
         _logger.info("Loaded model")
 
+    def enable_lora( self, enabled=True ):
+        for name, engine in self.engines.items():
+            enable_lora( engine.module, mode = enabled )
+
+    def disable_lora( self ):
+        return self.enable_lora( enabled=False )
+
     def encode_text( self, text, language="en" ):
         # already a tensor, return it
         if isinstance( text, Tensor ):
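A hedged usage sketch of the new `TTS.enable_lora()` / `TTS.disable_lora()` methods; the config and audio paths are placeholders, while the `TTS(config=...)` and `inference(...)` keywords match those used elsewhere in this diff:

```python
# placeholder paths; the constructor and inference keywords are as used in this commit
tts = TTS( config="./config.yaml" )

tts.enable_lora()    # toggle LoRA weights on for every loaded engine's module
tts.inference( text="Two blue fish swam in the tank.", references=["./prompt.wav"], language="en", out_path="./out_lora.wav" )

tts.disable_lora()   # shorthand for enable_lora(enabled=False)
tts.inference( text="Two blue fish swam in the tank.", references=["./prompt.wav"], language="en", out_path="./out_base.wav" )
```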
@@ -156,8 +156,10 @@ class AR(Base):
         )
 
         # is AR
+        """
         if cfg.lora is not None:
             enable_lora( self, cfg.lora.active_level( 0 ) )
+        """
 
         sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
         stopped = torch.zeros(batch_size, device=device).bool()
@@ -1,8 +1,8 @@
 # todo: clean this mess up
 
 from .config import cfg
-from .data import create_train_val_dataloader
-from .emb import qnt
+from .data import create_train_val_dataloader, get_random_prompt, tokenize
+from .emb import qnt, g2p
 
 from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
 from .data import fold_inputs, unfold_outputs
@@ -57,10 +57,13 @@ def train_feeder(engine, batch):
     return loss, stats
 
 @torch.inference_mode()
-def run_eval(engines, eval_name, dl):
+def run_eval(engines, eval_name, dl, args=None):
     stats = defaultdict(list)
     stats['loss'] = []
 
+    if cfg.evaluation.size == 0:
+        return
+
     def process( name, batch, resps_list ):
         for speaker, path, ref, hyp, prom, task in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"], batch["task"]):
             if len(hyp) == 0:
@@ -84,12 +87,17 @@ def run_eval(engines, eval_name, dl):
             ref_path.parent.mkdir(parents=True, exist_ok=True)
             prom_path.parent.mkdir(parents=True, exist_ok=True)
 
-            ref_audio, sr = qnt.decode_to_file(ref, ref_path)
             hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
+
+            if ref is not None:
+                ref_audio, sr = qnt.decode_to_file(ref, ref_path)
 
             if prom is not None:
                 prom_audio, sr = qnt.decode_to_file(prom, prom_path)
 
-            # pseudo loss calculation since we don't get the logits during eval
+            # naive loss calculation
+            # to-do: find a better way to calculate this / a better metric
+            if ref is not None:
                 min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
                 ref_audio = ref_audio[..., 0:min_length]
                 hyp_audio = hyp_audio[..., 0:min_length]
@@ -105,11 +113,6 @@ def run_eval(engines, eval_name, dl):
 
         batch_size = len(batch["text"])
 
-        processed += batch_size
-
-        for name in engines:
-            engine = engines[name]
-
         # to-do: eval for text tasks
         has_stt = False
         for i, task in enumerate( batch["task"] ):
@@ -119,6 +122,17 @@
                 batch["task"][i] = "tts"
                 batch["proms"][i] = batch["resps"][i][:75*3, :]
 
+        # random prompts requested
+        if args and args.eval_random_text_prompts and eval_name == "subtrain":
+            for i, _ in enumerate(batch["text"]):
+                batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
+                batch["resps"][i] = None
+
+        processed += batch_size
+        for name in engines:
+            engine = engines[name]
+
+
             kwargs = dict(
                 text_list=batch["text"],
                 proms_list=batch["proms"],
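Note how the pieces fit together: when random text prompts are substituted, `batch["resps"][i]` becomes `None`, so there is no reference audio to decode or score. That is why the earlier hunk guards `qnt.decode_to_file(ref, ref_path)` and the length-matched comparison behind `if ref is not None:`, and why the stats aggregation below gains the `if v` filter (averaging an empty `stats['loss']` list would otherwise divide by zero). A condensed sketch of the guarded flow:

```python
# condensed from the run_eval changes in this commit
hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)

if ref is not None:  # None when --eval-random-text-prompts replaced the batch text
    ref_audio, sr = qnt.decode_to_file(ref, ref_path)
    # compare only the overlapping region of reference and hypothesis audio
    min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
    ref_audio = ref_audio[..., 0:min_length]
    hyp_audio = hyp_audio[..., 0:min_length]
```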
@@ -157,7 +171,7 @@
 
             _logger.info(f"Validation Metrics (STT): {text_list}")
 
-    stats = {k: sum(v) / len(v) for k, v in stats.items()}
+    stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
     engines_stats = {
         f'{name}.{eval_name}': stats,
         "it": engines.global_step,
@@ -170,6 +184,8 @@
 def train():
     parser = argparse.ArgumentParser("VALL-E TTS")
     parser.add_argument("--eval", action="store_true", default=None)
+    parser.add_argument("--eval-random-text-prompts", action="store_true", default=None)
+    #parser.add_argument("--eval-random-audio-prompts", action="store_true", default=None)
     args, unknown = parser.parse_known_args()
 
     # create log folder
@@ -185,8 +201,8 @@ def train():
             engines.eval()
             # wrapped in a try block because it's sometimes prone to breaking
             try:
-                run_eval(engines, "subtrain", subtrain_dl)
-                run_eval(engines, "val", val_dl)
+                run_eval(engines, "subtrain", subtrain_dl, args)
+                run_eval(engines, "val", val_dl, args)
             except Exception as e:
                 _logger.warning(f"Error occurred while performing eval: {str(e)}")
                 _logger.warning(traceback.format_exc())
@@ -19,7 +19,7 @@ from .train import train
 from .utils import get_devices, setup_logging, timer
 from .utils.io import json_read, json_stringify
 from .emb.qnt import decode_to_wave
-from .data import get_lang_symmap
+from .data import get_lang_symmap, get_random_prompt
 
 tts = None
 
@@ -286,41 +286,6 @@ def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         yield metrics
     """
 
-def get_random_prompt():
-    harvard_sentences=[
-        "The birch canoe slid on the smooth planks.",
-        "Glue the sheet to the dark blue background.",
-        "It's easy to tell the depth of a well.",
-        "These days a chicken leg is a rare dish.",
-        "Rice is often served in round bowls.",
-        "The juice of lemons makes fine punch.",
-        "The box was thrown beside the parked truck.",
-        "The hogs were fed chopped corn and garbage.",
-        "Four hours of steady work faced us.",
-        "A large size in stockings is hard to sell.",
-        "The boy was there when the sun rose.",
-        "A rod is used to catch pink salmon.",
-        "The source of the huge river is the clear spring.",
-        "Kick the ball straight and follow through.",
-        "Help the woman get back to her feet.",
-        "A pot of tea helps to pass the evening.",
-        "Smoky fires lack flame and heat.",
-        "The soft cushion broke the man's fall.",
-        "The salt breeze came across from the sea.",
-        "The girl at the booth sold fifty bonds.",
-        "The small pup gnawed a hole in the sock.",
-        "The fish twisted and turned on the bent hook.",
-        "Press the pants and sew a button on the vest.",
-        "The swan dive was far short of perfect.",
-        "The beauty of the view stunned the young boy.",
-        "Two blue fish swam in the tank.",
-        "Her purse was full of useless trash.",
-        "The colt reared and threw the tall rider.",
-        "It snowed, rained, and hailed the same morning.",
-        "Read verse out loud for pleasure.",
-    ]
-    return random.choice(harvard_sentences)
-
 # setup args
 parser = argparse.ArgumentParser(allow_abbrev=False)
 parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too