maybe fix evaluation dataset not being capped to cfg.evaluation.size

This commit is contained in:
mrq 2023-08-17 18:56:37 -05:00
parent ee58db746f
commit 3ff7cf8341
2 changed files with 19 additions and 18 deletions

View File

@ -253,6 +253,10 @@ class Dataset(_Dataset):
return prom
@cached_property
def tasks(self):
return ["tts"] # "ns", "sr", "tse", "cse", "nse"
def __getitem__(self, index):
if hasattr(self, "sample_type") and self.sample_type == "speaker":
spkr_name = self.spkrs[index]
@ -274,6 +278,8 @@ class Dataset(_Dataset):
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
resps = _load_quants(path)
task = random.choice(self.tasks)
if task == "tts":
# I could probably do some logic to directly use the resps, but I'm putting my faith in python aliasing
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
@ -282,6 +288,7 @@ class Dataset(_Dataset):
path=path,
spkr_name=spkr_name,
spkr_id=spkr_id,
task=task,
text=text,
proms=proms,
resps=resps,
@ -360,26 +367,12 @@ def _load_dataset_paths():
paths[type].extend([ f"{key}/{child.attrs['id']}" for child in cfg.hdf5[key].values() ])
# files = data_dir.rglob("*.qnt.pt")
#paths[type].extend([ f'/{type}{_get_hdf5_path( str(file).replace(".qnt.pt", "") )}' for file in files ])
for data_dir in cfg.dataset.training:
get_paths( data_dir, "training" )
for data_dir in cfg.dataset.validation:
get_paths( data_dir, "validation" )
"""
def process( entity ):
if "id" in entity.attrs:
paths[entity.attrs['type']].append( f"{entity.attrs['speaker']}{entity.attrs['id']}" )
return
for child in entity.values():
process( child )
"""
for _, type in enumerate(paths):
dirs = paths[type]
@ -398,6 +391,7 @@ def _load_dataset_paths():
return datasets["training"], datasets["validation"]
# to-do: mirror the hdf5-based load function
def _load_train_val_paths():
paths = []
train_paths = []
@ -432,8 +426,10 @@ def _load_train_val_paths():
train_paths, val_paths = map(sorted, [train_paths, val_paths])
if len(train_paths) == 0:
raise RuntimeError(f"Failed to find any .qnt.pt file in {cfg.dataset.training}.")
# to get it to shut up
raise RuntimeError(f"Failed to find any .qnt.pt file in specified training dataset.")
# to-do: allow setting aside a fixed portion of the training dataset for validation
# something like the last X percent of each speaker is set aside
if len(val_paths) == 0:
val_paths = [ train_paths[0] ]

View File

@ -95,6 +95,7 @@ def run_eval(engines, eval_name, dl):
stats['loss'].append(0)
print(str(e))
processed = 0
for batch in tqdm(dl):
batch: dict = to_device(batch, cfg.device)
@ -131,6 +132,10 @@ def run_eval(engines, eval_name, dl):
process( name, batch, resps_list )
processed += len(batch["text"])
if processed > cfg.evaluation.size:
break
stats = {k: sum(v) / len(v) for k, v in stats.items()}
engines_stats.update(flatten_dict({ name: stats }))