normalize sampler index by batch size (if not using batched sampler), add option to cap out utterances for a speaker, some other things

This commit is contained in:
mrq 2024-11-18 12:46:50 -06:00
parent 6cfdf94bf9
commit 4a71981456
2 changed files with 19 additions and 17 deletions

View File

@ -171,6 +171,7 @@ class Dataset:
cache: bool = True # use diskcache to cache the dataset
min_utterances: int = 2 # minimum number of utterances a speaker can have
max_utterances: int = 0 # max number of utterances a speaker can have (0 to disable)
duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset
sample_type: str = "path" # path | speaker

View File

@ -689,6 +689,11 @@ class Dataset(_Dataset):
self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True
self.dataset_hash_key = cfg.dataset.hash_key(sorted(self.dataset))
self.duration = 0
self.duration_buckets = {}
self.current_index = 0
self.batch_size = cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size
# to-do: do not do validation if there's nothing in the validation
# this just makes it be happy
@ -701,31 +706,24 @@ class Dataset(_Dataset):
# dict of paths keyed by speaker names
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
# do it here due to the above
self.duration = 0
self.duration_map = _get_duration_map( self.dataset_type )
self.duration_buckets = {}
# cull speakers if they do not have enough utterances
if cfg.dataset.min_utterances > 0:
# cull speakers if they do not have enough utterances (or cull speakers with too many utternaces)
if cfg.dataset.min_utterances > 0 or cfg.dataset.max_utterances > 0:
keys = list(self.paths_by_spkr_name.keys())
for key in keys:
if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
del self.paths_by_spkr_name[key]
# slice away extraneous utterances
if cfg.dataset.max_utterances:
self.paths_by_spkr_name[key] = self.paths_by_spkr_name[key][:cfg.dataset.max_utterances]
# flatten paths
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
# split dataset accordingly per GPU
if cfg.distributed and self.training:
"""
batches = len(self.paths) // world_size()
start = batches * global_rank()
end = batches * (global_rank() + 1)
self.paths = self.paths[start:end]
"""
self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]
# recreate paths_by_spkr_name
@ -819,9 +817,10 @@ class Dataset(_Dataset):
self.sampler = BatchedOrderedSampler(
self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
max_duration=cfg.dataset.sample_max_duration_batch,
max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size,
max_batch_size=self.batch_size,
shuffle=self.sampler_shuffle,
)
self.batch_size = 1
else:
self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
self.samplers = {}
@ -1084,6 +1083,8 @@ class Dataset(_Dataset):
return prom
def __getitem__(self, index):
self.current_index = index
if self.empty_text is None:
self.empty_text = tokenize(" ")
@ -1385,7 +1386,7 @@ class Dataset(_Dataset):
self.training = value
def index(self):
return self.sampler.index() if self.sampler is not None else -1
return (self.sampler.index() if self.sampler is not None else -1) // self.batch_size
def __len__(self):
if self.sampler_type == "group":
@ -1471,8 +1472,8 @@ def create_train_val_dataloader():
val_dl = _create_dataloader(val_dataset, training=False)
_logger.info(str(train_dataset.phone_symmap))
_logger.info(str(train_dataset.spkr_symmap))
_logger.info(str(train_dataset.spkr_group_symmap))
_logger.info(f'#speakers (train): {len(train_dataset.spkr_symmap)}')
_logger.info(f'#groups (train): {len(train_dataset.spkr_group_symmap)}')
_logger.info(f"#samples (train): {len(train_dataset)}.")
_logger.info(f"#samples (val): {len(val_dataset)}.")