Added a sample_type option that samples from speakers, to truly balance an epoch by speaker rather than over the entire dataset, and a sampler that tries to balance by speakers
This commit is contained in:
parent
599e47a813
commit
44c08d828e
|
@ -129,6 +129,8 @@ class Dataset:
|
||||||
max_prompts: int = 3
|
max_prompts: int = 3
|
||||||
prompt_duration: float = 3.0
|
prompt_duration: float = 3.0
|
||||||
|
|
||||||
|
sample_type: str = "path" # path | speaker
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class Model:
|
class Model:
|
||||||
name: str = ""
|
name: str = ""
|
||||||
|
|
|
@ -119,6 +119,7 @@ class Dataset(_Dataset):
|
||||||
max_duration=cfg.dataset.duration_range[1],
|
max_duration=cfg.dataset.duration_range[1],
|
||||||
training=False,
|
training=False,
|
||||||
extra_paths_by_spkr_name: dict[str, list] = {},
|
extra_paths_by_spkr_name: dict[str, list] = {},
|
||||||
|
sample_type=cfg.dataset.sample_type # path | speaker
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._head = None
|
self._head = None
|
||||||
|
@ -126,6 +127,7 @@ class Dataset(_Dataset):
|
||||||
self.max_phones = max_phones
|
self.max_phones = max_phones
|
||||||
self.min_duration = min_duration
|
self.min_duration = min_duration
|
||||||
self.max_duration = max_duration
|
self.max_duration = max_duration
|
||||||
|
self.sample_type = sample_type
|
||||||
|
|
||||||
if cfg.dataset.validate:
|
if cfg.dataset.validate:
|
||||||
self.paths = [
|
self.paths = [
|
||||||
|
@ -167,7 +169,7 @@ class Dataset(_Dataset):
|
||||||
else:
|
else:
|
||||||
self.durations[spkr_id] += duration
|
self.durations[spkr_id] += duration
|
||||||
|
|
||||||
if training and not cfg.distributed:
|
if training and not cfg.distributed and self.sample_type == "path":
|
||||||
self.sampler = Sampler(self.paths, [cfg.get_spkr])
|
self.sampler = Sampler(self.paths, [cfg.get_spkr])
|
||||||
else:
|
else:
|
||||||
self.sampler = None
|
self.sampler = None
|
||||||
|
@ -200,10 +202,14 @@ class Dataset(_Dataset):
|
||||||
choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
|
choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
|
||||||
choices = [*choices]
|
choices = [*choices]
|
||||||
|
|
||||||
|
# no other utterances; it'd make more sense to prune speakers with only one utterance in the validation step
|
||||||
if len(choices) == 0:
|
if len(choices) == 0:
|
||||||
|
choices = [*set(self.paths_by_spkr_name[spkr_name])]
|
||||||
|
"""
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Failed to find another different utterance for {spkr_name}."
|
f"Failed to find another different utterance for {spkr_name}."
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
# shuffle it up a bit
|
# shuffle it up a bit
|
||||||
offset = random.randint(-16, 16)
|
offset = random.randint(-16, 16)
|
||||||
|
@ -248,13 +254,17 @@ class Dataset(_Dataset):
|
||||||
return prom
|
return prom
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
if self.training and self.sampler is not None:
|
if hasattr(self, "sample_type") and self.sample_type == "speaker":
|
||||||
path = self.sampler.sample()
|
spkr_name = self.spkrs[index]
|
||||||
|
spkr_id = self.spkr_symmap[spkr_name]
|
||||||
|
path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
|
||||||
else:
|
else:
|
||||||
path = self.paths[index]
|
if self.training and self.sampler is not None:
|
||||||
|
path = self.sampler.sample()
|
||||||
spkr_name = cfg.get_spkr(path)
|
else:
|
||||||
spkr_id = self.spkr_symmap[spkr_name]
|
path = self.paths[index]
|
||||||
|
spkr_name = cfg.get_spkr(path)
|
||||||
|
spkr_id = self.spkr_symmap[spkr_name]
|
||||||
|
|
||||||
if cfg.dataset.use_hdf5:
|
if cfg.dataset.use_hdf5:
|
||||||
key = _get_hdf5_path(path)
|
key = _get_hdf5_path(path)
|
||||||
|
@ -287,6 +297,8 @@ class Dataset(_Dataset):
|
||||||
self.paths = [*_interleaved_reorder(self.paths, fn)]
|
self.paths = [*_interleaved_reorder(self.paths, fn)]
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
if hasattr(self, "sample_type") and self.sample_type == "speaker":
|
||||||
|
return min(len(self.spkrs), self._head or len(self.spkrs))
|
||||||
return min(len(self.paths), self._head or len(self.paths))
|
return min(len(self.paths), self._head or len(self.paths))
|
||||||
|
|
||||||
def pin_memory(self):
|
def pin_memory(self):
|
||||||
|
@ -444,6 +456,7 @@ def create_datasets():
|
||||||
|
|
||||||
def create_train_val_dataloader():
|
def create_train_val_dataloader():
|
||||||
train_dataset, val_dataset = create_datasets()
|
train_dataset, val_dataset = create_datasets()
|
||||||
|
#train_dataset.sample_type = "speaker"
|
||||||
|
|
||||||
subtrain_dataset = copy.deepcopy(train_dataset)
|
subtrain_dataset = copy.deepcopy(train_dataset)
|
||||||
subtrain_dataset.head_(cfg.evaluation.size)
|
subtrain_dataset.head_(cfg.evaluation.size)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user