Attempt to fix the evaluation dataset not being capped to cfg.evaluation.size
This commit is contained in:
parent
ee58db746f
commit
3ff7cf8341
|
@ -253,6 +253,10 @@ class Dataset(_Dataset):
|
|||
|
||||
return prom
|
||||
|
||||
@cached_property
|
||||
def tasks(self):
|
||||
return ["tts"] # "ns", "sr", "tse", "cse", "nse"
|
||||
|
||||
def __getitem__(self, index):
|
||||
if hasattr(self, "sample_type") and self.sample_type == "speaker":
|
||||
spkr_name = self.spkrs[index]
|
||||
|
@ -274,6 +278,8 @@ class Dataset(_Dataset):
|
|||
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
|
||||
resps = _load_quants(path)
|
||||
|
||||
task = random.choice(self.tasks)
|
||||
if task == "tts":
|
||||
# I could probably do some logic to directly use the resps, but I'm putting my faith in python aliasing
|
||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
|
||||
|
@ -282,6 +288,7 @@ class Dataset(_Dataset):
|
|||
path=path,
|
||||
spkr_name=spkr_name,
|
||||
spkr_id=spkr_id,
|
||||
task=task,
|
||||
text=text,
|
||||
proms=proms,
|
||||
resps=resps,
|
||||
|
@ -360,26 +367,12 @@ def _load_dataset_paths():
|
|||
|
||||
paths[type].extend([ f"{key}/{child.attrs['id']}" for child in cfg.hdf5[key].values() ])
|
||||
|
||||
# files = data_dir.rglob("*.qnt.pt")
|
||||
#paths[type].extend([ f'/{type}{_get_hdf5_path( str(file).replace(".qnt.pt", "") )}' for file in files ])
|
||||
|
||||
for data_dir in cfg.dataset.training:
|
||||
get_paths( data_dir, "training" )
|
||||
|
||||
for data_dir in cfg.dataset.validation:
|
||||
get_paths( data_dir, "validation" )
|
||||
|
||||
"""
|
||||
def process( entity ):
|
||||
if "id" in entity.attrs:
|
||||
paths[entity.attrs['type']].append( f"{entity.attrs['speaker']}{entity.attrs['id']}" )
|
||||
return
|
||||
|
||||
for child in entity.values():
|
||||
process( child )
|
||||
"""
|
||||
|
||||
|
||||
for _, type in enumerate(paths):
|
||||
dirs = paths[type]
|
||||
|
||||
|
@ -398,6 +391,7 @@ def _load_dataset_paths():
|
|||
|
||||
return datasets["training"], datasets["validation"]
|
||||
|
||||
# to-do: mirror the hdf5-based load function
|
||||
def _load_train_val_paths():
|
||||
paths = []
|
||||
train_paths = []
|
||||
|
@ -432,8 +426,10 @@ def _load_train_val_paths():
|
|||
train_paths, val_paths = map(sorted, [train_paths, val_paths])
|
||||
|
||||
if len(train_paths) == 0:
|
||||
raise RuntimeError(f"Failed to find any .qnt.pt file in {cfg.dataset.training}.")
|
||||
# to get it to shut up
|
||||
raise RuntimeError(f"Failed to find any .qnt.pt file in specified training dataset.")
|
||||
|
||||
# to-do: allow setting aside a fixed portion of the training dataset for validation
|
||||
# something like the last X percent of each speaker is set aside
|
||||
if len(val_paths) == 0:
|
||||
val_paths = [ train_paths[0] ]
|
||||
|
||||
|
|
|
@ -95,6 +95,7 @@ def run_eval(engines, eval_name, dl):
|
|||
stats['loss'].append(0)
|
||||
print(str(e))
|
||||
|
||||
processed = 0
|
||||
for batch in tqdm(dl):
|
||||
batch: dict = to_device(batch, cfg.device)
|
||||
|
||||
|
@ -131,6 +132,10 @@ def run_eval(engines, eval_name, dl):
|
|||
|
||||
process( name, batch, resps_list )
|
||||
|
||||
processed += len(batch["text"])
|
||||
if processed > cfg.evaluation.size:
|
||||
break
|
||||
|
||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||
engines_stats.update(flatten_dict({ name: stats }))
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user