added sample_type that samples from speakers to truly balance an epoch by speakers rather than the entire dataset and a sampler that tries to balance by speakers

This commit is contained in:
mrq 2023-08-16 19:39:21 -05:00
parent 599e47a813
commit 44c08d828e
2 changed files with 22 additions and 7 deletions

View File

@ -129,6 +129,8 @@ class Dataset:
max_prompts: int = 3
prompt_duration: float = 3.0
sample_type: str = "path" # path | speaker
@dataclass()
class Model:
name: str = ""

View File

@ -119,6 +119,7 @@ class Dataset(_Dataset):
max_duration=cfg.dataset.duration_range[1],
training=False,
extra_paths_by_spkr_name: dict[str, list] = {},
sample_type=cfg.dataset.sample_type # path | speaker
):
super().__init__()
self._head = None
@ -126,6 +127,7 @@ class Dataset(_Dataset):
self.max_phones = max_phones
self.min_duration = min_duration
self.max_duration = max_duration
self.sample_type = sample_type
if cfg.dataset.validate:
self.paths = [
@ -167,7 +169,7 @@ class Dataset(_Dataset):
else:
self.durations[spkr_id] += duration
if training and not cfg.distributed:
if training and not cfg.distributed and self.sample_type == "path":
self.sampler = Sampler(self.paths, [cfg.get_spkr])
else:
self.sampler = None
@ -200,10 +202,14 @@ class Dataset(_Dataset):
choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
choices = [*choices]
# no other utterances, it'd make more sense to prune speakers with only one utterance in the validatoin step
if len(choices) == 0:
choices = [*set(self.paths_by_spkr_name[spkr_name])]
"""
raise ValueError(
f"Failed to find another different utterance for {spkr_name}."
)
"""
# shuffle it up a bit
offset = random.randint(-16, 16)
@ -248,13 +254,17 @@ class Dataset(_Dataset):
return prom
def __getitem__(self, index):
if self.training and self.sampler is not None:
path = self.sampler.sample()
if hasattr(self, "sample_type") and self.sample_type == "speaker":
spkr_name = self.spkrs[index]
spkr_id = self.spkr_symmap[spkr_name]
path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
else:
path = self.paths[index]
spkr_name = cfg.get_spkr(path)
spkr_id = self.spkr_symmap[spkr_name]
if self.training and self.sampler is not None:
path = self.sampler.sample()
else:
path = self.paths[index]
spkr_name = cfg.get_spkr(path)
spkr_id = self.spkr_symmap[spkr_name]
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
@ -287,6 +297,8 @@ class Dataset(_Dataset):
self.paths = [*_interleaved_reorder(self.paths, fn)]
def __len__(self):
if hasattr(self, "sample_type") and self.sample_type == "speaker":
return min(len(self.spkrs), self._head or len(self.spkrs))
return min(len(self.paths), self._head or len(self.paths))
def pin_memory(self):
@ -444,6 +456,7 @@ def create_datasets():
def create_train_val_dataloader():
train_dataset, val_dataset = create_datasets()
#train_dataset.sample_type = "speaker"
subtrain_dataset = copy.deepcopy(train_dataset)
subtrain_dataset.head_(cfg.evaluation.size)