normalize sampler index by batch size (if not using batched sampler), add option to cap utterances per speaker, some other things
This commit is contained in:
parent 6cfdf94bf9
commit 4a71981456
@@ -171,6 +171,7 @@ class Dataset:
 	cache: bool = True # use diskcache to cache the dataset

 	min_utterances: int = 2 # minimum number of utterances a speaker can have
+	max_utterances: int = 0 # max number of utterances a speaker can have (0 to disable)
 	duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset

 	sample_type: str = "path" # path | speaker
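
For context, the new knob is just another field on the Dataset config dataclass, sitting next to min_utterances. A minimal sketch of overriding it; the field names and defaults come from the diff, the construction itself is illustrative:

from dataclasses import dataclass, field

@dataclass
class Dataset:
    cache: bool = True
    min_utterances: int = 2    # speakers below this are dropped
    max_utterances: int = 0    # 0 leaves speakers uncapped
    duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0])
    sample_type: str = "path"  # path | speaker

# hypothetical override: keep speakers with at least 4 utterances, cap each at 500
dataset_cfg = Dataset(min_utterances=4, max_utterances=500)
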
@@ -690,6 +690,11 @@ class Dataset(_Dataset):
 		self.dataset_hash_key = cfg.dataset.hash_key(sorted(self.dataset))

+		self.duration = 0
+		self.duration_buckets = {}
+		self.current_index = 0
+		self.batch_size = cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size
+
 		# to-do: do not do validation if there's nothing in the validation
 		# this just makes it be happy
 		if len(self.dataset) == 0:
@@ -701,31 +706,24 @@ class Dataset(_Dataset):
 		# dict of paths keyed by speaker names
 		self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
-		# do it here due to the above
-		self.duration = 0
 		self.duration_map = _get_duration_map( self.dataset_type )
-		self.duration_buckets = {}

-		# cull speakers if they do not have enough utterances
-		if cfg.dataset.min_utterances > 0:
+		# cull speakers if they do not have enough utterances (or cull speakers with too many utterances)
+		if cfg.dataset.min_utterances > 0 or cfg.dataset.max_utterances > 0:
 			keys = list(self.paths_by_spkr_name.keys())
 			for key in keys:
 				if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
 					del self.paths_by_spkr_name[key]

+				# slice away extraneous utterances
+				if cfg.dataset.max_utterances:
+					self.paths_by_spkr_name[key] = self.paths_by_spkr_name[key][:cfg.dataset.max_utterances]
+
 		# flatten paths
 		self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))

 		# split dataset accordingly per GPU
 		if cfg.distributed and self.training:
-			"""
-			batches = len(self.paths) // world_size()
-			start = batches * global_rank()
-			end = batches * (global_rank() + 1)
-
-			self.paths = self.paths[start:end]
-			"""
-
 			self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]

 			# recreate paths_by_spkr_name
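
The cap piggybacks on the existing cull loop above. As a standalone illustration (not the repo's exact code), the same pass can be written as a pure function; the `continue` after the delete is my own addition so the max_utterances slice never touches a speaker that was just removed:

def cull_speakers(paths_by_spkr_name: dict[str, list], min_utterances: int = 0, max_utterances: int = 0) -> dict[str, list]:
    # iterate over a snapshot of the keys, since entries may be deleted mid-loop
    for key in list(paths_by_spkr_name.keys()):
        # drop speakers with too few utterances
        if min_utterances and len(paths_by_spkr_name[key]) < min_utterances:
            del paths_by_spkr_name[key]
            continue
        # slice away extraneous utterances past the cap
        if max_utterances:
            paths_by_spkr_name[key] = paths_by_spkr_name[key][:max_utterances]
    return paths_by_spkr_name

# e.g. {"a": ["a1"], "b": ["b1", "b2", "b3"]} with min=2, max=2 -> {"b": ["b1", "b2"]}
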
@@ -819,9 +817,10 @@ class Dataset(_Dataset):
 			self.sampler = BatchedOrderedSampler(
 				self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
 				max_duration=cfg.dataset.sample_max_duration_batch,
-				max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size,
+				max_batch_size=self.batch_size,
 				shuffle=self.sampler_shuffle,
 			)
+			self.batch_size = 1
 		else:
 			self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
 			self.samplers = {}
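
The batch size cached in __init__ now feeds the batched sampler as max_batch_size, and is then reset to 1 because that sampler already yields whole batches, one "index" per step. A rough sketch of how such a sampler typically plugs into a PyTorch DataLoader; the wiring below is assumed for illustration, not taken from this repo's _create_dataloader:

from torch.utils.data import DataLoader, Sampler
from torch.utils.data import Dataset as TorchDataset

def make_dataloader(dataset: TorchDataset, sampler: Sampler, batched: bool, batch_size: int) -> DataLoader:
    # Assumed wiring: a sampler that yields lists of indices goes in as
    # batch_sampler (one prebuilt batch per step), while a per-sample
    # sampler is paired with an explicit batch_size.
    if batched:
        return DataLoader(dataset, batch_sampler=sampler)
    return DataLoader(dataset, sampler=sampler, batch_size=batch_size)
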
@@ -1084,6 +1083,8 @@ class Dataset(_Dataset):
 		return prom

 	def __getitem__(self, index):
+		self.current_index = index
+
 		if self.empty_text is None:
 			self.empty_text = tokenize(" ")

@@ -1385,7 +1386,7 @@ class Dataset(_Dataset):
 		self.training = value

 	def index(self):
-		return self.sampler.index() if self.sampler is not None else -1
+		return (self.sampler.index() if self.sampler is not None else -1) // self.batch_size

 	def __len__(self):
 		if self.sampler_type == "group":
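
This is the "normalize sampler index by batch size" part of the commit: a plain OrderedSampler/RandomSampler counts individual samples, while training steps count batches, so the floor division turns the sampler position into a step index; in the BatchedOrderedSampler branch, batch_size was already reset to 1, so the value passes through unchanged. A worked example with illustrative numbers:

# Illustrative only: mapping a sample-level cursor to a step index.
batch_size = 8
sampler_index = 24                 # 24 samples consumed by a per-sample sampler
step = sampler_index // batch_size
assert step == 3                   # the dataloader is on its 4th batch (0-based step 3)

# with the batched sampler, batch_size was reset to 1, so the index is already step-aligned
batched_batch_size = 1
assert 3 // batched_batch_size == 3
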
@@ -1471,8 +1472,8 @@ def create_train_val_dataloader():
 	val_dl = _create_dataloader(val_dataset, training=False)

 	_logger.info(str(train_dataset.phone_symmap))
-	_logger.info(str(train_dataset.spkr_symmap))
-	_logger.info(str(train_dataset.spkr_group_symmap))
+	_logger.info(f'#speakers (train): {len(train_dataset.spkr_symmap)}')
+	_logger.info(f'#groups (train): {len(train_dataset.spkr_group_symmap)}')

 	_logger.info(f"#samples (train): {len(train_dataset)}.")
 	_logger.info(f"#samples (val): {len(val_dataset)}.")