maybe fix evaluation dataset not being capped to cfg.evaluation.size
This commit is contained in:
parent
ee58db746f
commit
3ff7cf8341
|
@ -253,6 +253,10 @@ class Dataset(_Dataset):
|
||||||
|
|
||||||
return prom
|
return prom
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def tasks(self):
|
||||||
|
return ["tts"] # "ns", "sr", "tse", "cse", "nse"
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
if hasattr(self, "sample_type") and self.sample_type == "speaker":
|
if hasattr(self, "sample_type") and self.sample_type == "speaker":
|
||||||
spkr_name = self.spkrs[index]
|
spkr_name = self.spkrs[index]
|
||||||
|
@ -274,14 +278,17 @@ class Dataset(_Dataset):
|
||||||
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
|
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
|
||||||
resps = _load_quants(path)
|
resps = _load_quants(path)
|
||||||
|
|
||||||
# I could probably do some logic to directly use the resps, but I'm putting my faith in python aliasing
|
task = random.choice(self.tasks)
|
||||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
if task == "tts":
|
||||||
|
# I could probably do some logic to directly use the resps, but I'm putting my faith in python aliasing
|
||||||
|
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
index=index,
|
index=index,
|
||||||
path=path,
|
path=path,
|
||||||
spkr_name=spkr_name,
|
spkr_name=spkr_name,
|
||||||
spkr_id=spkr_id,
|
spkr_id=spkr_id,
|
||||||
|
task=task,
|
||||||
text=text,
|
text=text,
|
||||||
proms=proms,
|
proms=proms,
|
||||||
resps=resps,
|
resps=resps,
|
||||||
|
@ -360,26 +367,12 @@ def _load_dataset_paths():
|
||||||
|
|
||||||
paths[type].extend([ f"{key}/{child.attrs['id']}" for child in cfg.hdf5[key].values() ])
|
paths[type].extend([ f"{key}/{child.attrs['id']}" for child in cfg.hdf5[key].values() ])
|
||||||
|
|
||||||
# files = data_dir.rglob("*.qnt.pt")
|
|
||||||
#paths[type].extend([ f'/{type}{_get_hdf5_path( str(file).replace(".qnt.pt", "") )}' for file in files ])
|
|
||||||
|
|
||||||
for data_dir in cfg.dataset.training:
|
for data_dir in cfg.dataset.training:
|
||||||
get_paths( data_dir, "training" )
|
get_paths( data_dir, "training" )
|
||||||
|
|
||||||
for data_dir in cfg.dataset.validation:
|
for data_dir in cfg.dataset.validation:
|
||||||
get_paths( data_dir, "validation" )
|
get_paths( data_dir, "validation" )
|
||||||
|
|
||||||
"""
|
|
||||||
def process( entity ):
|
|
||||||
if "id" in entity.attrs:
|
|
||||||
paths[entity.attrs['type']].append( f"{entity.attrs['speaker']}{entity.attrs['id']}" )
|
|
||||||
return
|
|
||||||
|
|
||||||
for child in entity.values():
|
|
||||||
process( child )
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
for _, type in enumerate(paths):
|
for _, type in enumerate(paths):
|
||||||
dirs = paths[type]
|
dirs = paths[type]
|
||||||
|
|
||||||
|
@ -398,6 +391,7 @@ def _load_dataset_paths():
|
||||||
|
|
||||||
return datasets["training"], datasets["validation"]
|
return datasets["training"], datasets["validation"]
|
||||||
|
|
||||||
|
# to-do: mirror the hdf5-based load function
|
||||||
def _load_train_val_paths():
|
def _load_train_val_paths():
|
||||||
paths = []
|
paths = []
|
||||||
train_paths = []
|
train_paths = []
|
||||||
|
@ -432,8 +426,10 @@ def _load_train_val_paths():
|
||||||
train_paths, val_paths = map(sorted, [train_paths, val_paths])
|
train_paths, val_paths = map(sorted, [train_paths, val_paths])
|
||||||
|
|
||||||
if len(train_paths) == 0:
|
if len(train_paths) == 0:
|
||||||
raise RuntimeError(f"Failed to find any .qnt.pt file in {cfg.dataset.training}.")
|
raise RuntimeError(f"Failed to find any .qnt.pt file in specified training dataset.")
|
||||||
# to get it to shut up
|
|
||||||
|
# to-do: allow setting aside a fixed portion of the training dataset for validation
|
||||||
|
# something like the last X percent of each speaker is set aside
|
||||||
if len(val_paths) == 0:
|
if len(val_paths) == 0:
|
||||||
val_paths = [ train_paths[0] ]
|
val_paths = [ train_paths[0] ]
|
||||||
|
|
||||||
|
|
|
@ -95,6 +95,7 @@ def run_eval(engines, eval_name, dl):
|
||||||
stats['loss'].append(0)
|
stats['loss'].append(0)
|
||||||
print(str(e))
|
print(str(e))
|
||||||
|
|
||||||
|
processed = 0
|
||||||
for batch in tqdm(dl):
|
for batch in tqdm(dl):
|
||||||
batch: dict = to_device(batch, cfg.device)
|
batch: dict = to_device(batch, cfg.device)
|
||||||
|
|
||||||
|
@ -131,6 +132,10 @@ def run_eval(engines, eval_name, dl):
|
||||||
|
|
||||||
process( name, batch, resps_list )
|
process( name, batch, resps_list )
|
||||||
|
|
||||||
|
processed += len(batch["text"])
|
||||||
|
if processed > cfg.evaluation.size:
|
||||||
|
break
|
||||||
|
|
||||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||||
engines_stats.update(flatten_dict({ name: stats }))
|
engines_stats.update(flatten_dict({ name: stats }))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user