normalize sampler index by batch size (if not using batched sampler), add option to cap out utterances for a speaker, some other things
This commit is contained in:
parent
6cfdf94bf9
commit
4a71981456
|
@ -171,6 +171,7 @@ class Dataset:
|
|||
cache: bool = True # use diskcache to cache the dataset
|
||||
|
||||
min_utterances: int = 2 # minimum number of utterances a speaker can have
|
||||
max_utterances: int = 0 # max number of utterances a speaker can have (0 to disable)
|
||||
duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset
|
||||
|
||||
sample_type: str = "path" # path | speaker
|
||||
|
|
|
@ -689,6 +689,11 @@ class Dataset(_Dataset):
|
|||
self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True
|
||||
|
||||
self.dataset_hash_key = cfg.dataset.hash_key(sorted(self.dataset))
|
||||
|
||||
self.duration = 0
|
||||
self.duration_buckets = {}
|
||||
self.current_index = 0
|
||||
self.batch_size = cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size
|
||||
|
||||
# to-do: do not do validation if there's nothing in the validation
|
||||
# this just makes it be happy
|
||||
|
@ -701,31 +706,24 @@ class Dataset(_Dataset):
|
|||
|
||||
# dict of paths keyed by speaker names
|
||||
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
|
||||
# do it here due to the above
|
||||
self.duration = 0
|
||||
self.duration_map = _get_duration_map( self.dataset_type )
|
||||
self.duration_buckets = {}
|
||||
|
||||
# cull speakers if they do not have enough utterances
|
||||
if cfg.dataset.min_utterances > 0:
|
||||
# cull speakers if they do not have enough utterances (or cull speakers with too many utternaces)
|
||||
if cfg.dataset.min_utterances > 0 or cfg.dataset.max_utterances > 0:
|
||||
keys = list(self.paths_by_spkr_name.keys())
|
||||
for key in keys:
|
||||
if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
|
||||
del self.paths_by_spkr_name[key]
|
||||
|
||||
# slice away extraneous utterances
|
||||
if cfg.dataset.max_utterances:
|
||||
self.paths_by_spkr_name[key] = self.paths_by_spkr_name[key][:cfg.dataset.max_utterances]
|
||||
|
||||
# flatten paths
|
||||
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
|
||||
|
||||
# split dataset accordingly per GPU
|
||||
if cfg.distributed and self.training:
|
||||
"""
|
||||
batches = len(self.paths) // world_size()
|
||||
start = batches * global_rank()
|
||||
end = batches * (global_rank() + 1)
|
||||
|
||||
self.paths = self.paths[start:end]
|
||||
"""
|
||||
|
||||
self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]
|
||||
|
||||
# recreate paths_by_spkr_name
|
||||
|
@ -819,9 +817,10 @@ class Dataset(_Dataset):
|
|||
self.sampler = BatchedOrderedSampler(
|
||||
self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
|
||||
max_duration=cfg.dataset.sample_max_duration_batch,
|
||||
max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size,
|
||||
max_batch_size=self.batch_size,
|
||||
shuffle=self.sampler_shuffle,
|
||||
)
|
||||
self.batch_size = 1
|
||||
else:
|
||||
self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
|
||||
self.samplers = {}
|
||||
|
@ -1084,6 +1083,8 @@ class Dataset(_Dataset):
|
|||
return prom
|
||||
|
||||
def __getitem__(self, index):
|
||||
self.current_index = index
|
||||
|
||||
if self.empty_text is None:
|
||||
self.empty_text = tokenize(" ")
|
||||
|
||||
|
@ -1385,7 +1386,7 @@ class Dataset(_Dataset):
|
|||
self.training = value
|
||||
|
||||
def index(self):
|
||||
return self.sampler.index() if self.sampler is not None else -1
|
||||
return (self.sampler.index() if self.sampler is not None else -1) // self.batch_size
|
||||
|
||||
def __len__(self):
|
||||
if self.sampler_type == "group":
|
||||
|
@ -1471,8 +1472,8 @@ def create_train_val_dataloader():
|
|||
val_dl = _create_dataloader(val_dataset, training=False)
|
||||
|
||||
_logger.info(str(train_dataset.phone_symmap))
|
||||
_logger.info(str(train_dataset.spkr_symmap))
|
||||
_logger.info(str(train_dataset.spkr_group_symmap))
|
||||
_logger.info(f'#speakers (train): {len(train_dataset.spkr_symmap)}')
|
||||
_logger.info(f'#groups (train): {len(train_dataset.spkr_group_symmap)}')
|
||||
|
||||
_logger.info(f"#samples (train): {len(train_dataset)}.")
|
||||
_logger.info(f"#samples (val): {len(val_dataset)}.")
|
||||
|
|
Loading…
Reference in New Issue
Block a user