added a sample_type option that samples by speaker, to truly balance an epoch by speakers rather than by the entire dataset, plus a sampler that tries to balance by speakers
parent 599e47a813
commit 44c08d828e
@@ -129,6 +129,8 @@ class Dataset:
     max_prompts: int = 3
     prompt_duration: float = 3.0
 
+    sample_type: str = "path" # path | speaker
+
 @dataclass()
 class Model:
     name: str = ""
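For context, the new field only changes how an epoch is enumerated; a trimmed-down restatement of the knob and what each value means (the explanatory comments are mine, not from the repo):

from dataclasses import dataclass

@dataclass()
class Dataset:
    # "path":    every utterance path is one epoch item (previous behaviour)
    # "speaker": every speaker is one epoch item, with a random utterance drawn per speaker
    sample_type: str = "path" # path | speaker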
@@ -119,6 +119,7 @@ class Dataset(_Dataset):
         max_duration=cfg.dataset.duration_range[1],
         training=False,
         extra_paths_by_spkr_name: dict[str, list] = {},
+        sample_type=cfg.dataset.sample_type # path | speaker
     ):
         super().__init__()
         self._head = None
@@ -126,6 +127,7 @@ class Dataset(_Dataset):
         self.max_phones = max_phones
         self.min_duration = min_duration
         self.max_duration = max_duration
+        self.sample_type = sample_type
 
         if cfg.dataset.validate:
             self.paths = [
@@ -167,7 +169,7 @@ class Dataset(_Dataset):
             else:
                 self.durations[spkr_id] += duration
 
-        if training and not cfg.distributed:
+        if training and not cfg.distributed and self.sample_type == "path":
             self.sampler = Sampler(self.paths, [cfg.get_spkr])
         else:
             self.sampler = None
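The Sampler keyed on cfg.get_spkr appears to be the "sampler that tries to balance by speakers" from the commit message; it stays wired up only for the default "path" mode, while "speaker" mode does its balancing in __getitem__ instead. A minimal, self-contained sketch of the same balancing idea, independent of this repo's Sampler class (function name hypothetical):

import random
from collections import defaultdict

def make_speaker_balanced_sampler(paths, get_spkr):
    # Group utterance paths by speaker, then draw a speaker uniformly and an
    # utterance uniformly within that speaker, so speakers with many utterances
    # no longer dominate training draws.
    by_spkr = defaultdict(list)
    for path in paths:
        by_spkr[get_spkr(path)].append(path)
    speakers = list(by_spkr)

    def sample():
        spkr = random.choice(speakers)
        return random.choice(by_spkr[spkr])

    return sample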
@@ -200,10 +202,14 @@ class Dataset(_Dataset):
         choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
         choices = [*choices]
 
+        # no other utterances; it'd make more sense to prune speakers with only one utterance in the validation step
         if len(choices) == 0:
+            choices = [*set(self.paths_by_spkr_name[spkr_name])]
+            """
             raise ValueError(
                 f"Failed to find another different utterance for {spkr_name}."
             )
+            """
 
         # shuffle it up a bit
         offset = random.randint(-16, 16)
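The hunk above replaces a hard failure with a fallback: a speaker with only one utterance now reuses that utterance as its own prompt instead of raising. A small self-contained sketch of the resulting behaviour (helper name hypothetical):

import random

def pick_prompt_path(paths_by_spkr_name, spkr_name, ignore):
    # Prefer a different utterance from the same speaker for the prompt...
    choices = set(paths_by_spkr_name[spkr_name]) - {ignore}
    if len(choices) == 0:
        # ...but fall back to the speaker's only utterance rather than erroring;
        # pruning single-utterance speakers during validation would be cleaner.
        choices = set(paths_by_spkr_name[spkr_name])
    return random.choice([*choices])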
@@ -248,13 +254,17 @@ class Dataset(_Dataset):
         return prom
 
     def __getitem__(self, index):
-        if self.training and self.sampler is not None:
-            path = self.sampler.sample()
+        if hasattr(self, "sample_type") and self.sample_type == "speaker":
+            spkr_name = self.spkrs[index]
+            spkr_id = self.spkr_symmap[spkr_name]
+            path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
         else:
-            path = self.paths[index]
-
-        spkr_name = cfg.get_spkr(path)
-        spkr_id = self.spkr_symmap[spkr_name]
+            if self.training and self.sampler is not None:
+                path = self.sampler.sample()
+            else:
+                path = self.paths[index]
+            spkr_name = cfg.get_spkr(path)
+            spkr_id = self.spkr_symmap[spkr_name]
 
         if cfg.dataset.use_hdf5:
             key = _get_hdf5_path(path)
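In "speaker" mode the dataset is indexed by speaker rather than by path: index i names the i-th speaker and one of its utterances is drawn at random, so a single pass touches every speaker exactly once regardless of how much audio each one has. A stripped-down sketch of that indexing scheme (class and names hypothetical, not the repo's Dataset):

import random

class SpeakerIndexedView:
    """len() is the number of speakers; item i is a random utterance of speaker i."""

    def __init__(self, paths_by_spkr_name):
        self.paths_by_spkr_name = paths_by_spkr_name
        self.spkrs = sorted(paths_by_spkr_name)

    def __len__(self):
        return len(self.spkrs)

    def __getitem__(self, index):
        spkr_name = self.spkrs[index]
        path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
        return spkr_name, path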
@@ -287,6 +297,8 @@ class Dataset(_Dataset):
         self.paths = [*_interleaved_reorder(self.paths, fn)]
 
     def __len__(self):
+        if hasattr(self, "sample_type") and self.sample_type == "speaker":
+            return min(len(self.spkrs), self._head or len(self.spkrs))
         return min(len(self.paths), self._head or len(self.paths))
 
     def pin_memory(self):
@@ -444,6 +456,7 @@ def create_datasets():
 
 def create_train_val_dataloader():
     train_dataset, val_dataset = create_datasets()
+    #train_dataset.sample_type = "speaker"
 
     subtrain_dataset = copy.deepcopy(train_dataset)
     subtrain_dataset.head_(cfg.evaluation.size)
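Because __getitem__ and __len__ check the attribute with hasattr at call time, the mode can also be flipped on an already-constructed dataset, which is what the commented-out line above hints at. A hedged usage sketch against this module (not part of the commit):

train_dataset, val_dataset = create_datasets()

# hypothetical: switch the training split to speaker-balanced epochs after the fact;
# no rebuild is needed, and the epoch length becomes the number of speakers
train_dataset.sample_type = "speaker"
assert len(train_dataset) == len(train_dataset.spkrs)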