added sample_type that samples from speakers to truly balance an epoch by speakers rather than the entire dataset and a sampler that tries to balance by speakers

This commit is contained in:
mrq 2023-08-16 19:39:21 -05:00
parent 599e47a813
commit 44c08d828e
2 changed files with 22 additions and 7 deletions

View File

@ -129,6 +129,8 @@ class Dataset:
max_prompts: int = 3 max_prompts: int = 3
prompt_duration: float = 3.0 prompt_duration: float = 3.0
sample_type: str = "path" # path | speaker
@dataclass() @dataclass()
class Model: class Model:
name: str = "" name: str = ""

View File

@ -119,6 +119,7 @@ class Dataset(_Dataset):
max_duration=cfg.dataset.duration_range[1], max_duration=cfg.dataset.duration_range[1],
training=False, training=False,
extra_paths_by_spkr_name: dict[str, list] = {}, extra_paths_by_spkr_name: dict[str, list] = {},
sample_type=cfg.dataset.sample_type # path | speaker
): ):
super().__init__() super().__init__()
self._head = None self._head = None
@ -126,6 +127,7 @@ class Dataset(_Dataset):
self.max_phones = max_phones self.max_phones = max_phones
self.min_duration = min_duration self.min_duration = min_duration
self.max_duration = max_duration self.max_duration = max_duration
self.sample_type = sample_type
if cfg.dataset.validate: if cfg.dataset.validate:
self.paths = [ self.paths = [
@ -167,7 +169,7 @@ class Dataset(_Dataset):
else: else:
self.durations[spkr_id] += duration self.durations[spkr_id] += duration
if training and not cfg.distributed: if training and not cfg.distributed and self.sample_type == "path":
self.sampler = Sampler(self.paths, [cfg.get_spkr]) self.sampler = Sampler(self.paths, [cfg.get_spkr])
else: else:
self.sampler = None self.sampler = None
@ -200,10 +202,14 @@ class Dataset(_Dataset):
choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore} choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
choices = [*choices] choices = [*choices]
# no other utterances, it'd make more sense to prune speakers with only one utterance in the validatoin step
if len(choices) == 0: if len(choices) == 0:
choices = [*set(self.paths_by_spkr_name[spkr_name])]
"""
raise ValueError( raise ValueError(
f"Failed to find another different utterance for {spkr_name}." f"Failed to find another different utterance for {spkr_name}."
) )
"""
# shuffle it up a bit # shuffle it up a bit
offset = random.randint(-16, 16) offset = random.randint(-16, 16)
@ -248,13 +254,17 @@ class Dataset(_Dataset):
return prom return prom
def __getitem__(self, index): def __getitem__(self, index):
if self.training and self.sampler is not None: if hasattr(self, "sample_type") and self.sample_type == "speaker":
path = self.sampler.sample() spkr_name = self.spkrs[index]
spkr_id = self.spkr_symmap[spkr_name]
path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
else: else:
path = self.paths[index] if self.training and self.sampler is not None:
path = self.sampler.sample()
spkr_name = cfg.get_spkr(path) else:
spkr_id = self.spkr_symmap[spkr_name] path = self.paths[index]
spkr_name = cfg.get_spkr(path)
spkr_id = self.spkr_symmap[spkr_name]
if cfg.dataset.use_hdf5: if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path) key = _get_hdf5_path(path)
@ -287,6 +297,8 @@ class Dataset(_Dataset):
self.paths = [*_interleaved_reorder(self.paths, fn)] self.paths = [*_interleaved_reorder(self.paths, fn)]
def __len__(self): def __len__(self):
if hasattr(self, "sample_type") and self.sample_type == "speaker":
return min(len(self.spkrs), self._head or len(self.spkrs))
return min(len(self.paths), self._head or len(self.paths)) return min(len(self.paths), self._head or len(self.paths))
def pin_memory(self): def pin_memory(self):
@ -444,6 +456,7 @@ def create_datasets():
def create_train_val_dataloader(): def create_train_val_dataloader():
train_dataset, val_dataset = create_datasets() train_dataset, val_dataset = create_datasets()
#train_dataset.sample_type = "speaker"
subtrain_dataset = copy.deepcopy(train_dataset) subtrain_dataset = copy.deepcopy(train_dataset)
subtrain_dataset.head_(cfg.evaluation.size) subtrain_dataset.head_(cfg.evaluation.size)