add shuffle to samplers that can support it
commit 312a8e3ead
parent 396af541c5
@@ -157,9 +157,10 @@ class Dataset:
     p_resp_append: float = 1.0
 
     sample_type: str = "path" # path | speaker
-    sample_order: str = "shuffle" # duration
+    sample_order: str = "interleaved" # duration
     sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batched, 0 to disable
                                             # for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
+    sample_shuffle: bool = True #
 
     tasks_list: list[str] = field(default_factory=lambda: ["tts"])
@@ -424,7 +424,6 @@ class Dataset(_Dataset):
     ):
         super().__init__()
         self._head = None
-        self.shuffle = False
         self.sampler = None
 
         self.paths = []
@@ -434,6 +433,7 @@ class Dataset(_Dataset):
         self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
         self.sampler_type = cfg.dataset.sample_type # if self.dataset_type == "training" else "group"
         self.sampler_order = cfg.dataset.sample_order
+        self.sampler_shuffle = cfg.dataset.sample_shuffle
 
         # to-do: do not do validation if there's nothing in the validation
         # this just makes it be happy
@@ -510,7 +510,7 @@ class Dataset(_Dataset):
                 flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
             # flatten paths
             self.paths = list(itertools.chain.from_iterable(flattened.values()))
-        elif self.sampler_order == "shuffle":
+        else:
             # just interleave
             self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
 
@@ -547,29 +547,28 @@ class Dataset(_Dataset):
         if len(self.paths) == 0:
             raise ValueError(f"No valid path is found for {self.dataset_type}")
 
-        sampler_path = cfg.rel_path / self.sampler_state_dict_path
-
         if self.sampler_type == "path":
             if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
                 self.sampler = BatchedOrderedSampler(
-                    self.duration_buckets if not sampler_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
-                    cfg.dataset.sample_max_duration_batch,
-                    cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size
+                    self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
+                    max_duration=cfg.dataset.sample_max_duration_batch,
+                    max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size,
+                    shuffle=self.sampler_shuffle
                 )
             else:
-                self.sampler = OrderedSampler( len(self) )
+                self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
             self.samplers = {}
             self.spkr_samplers = {}
         else:
             self.sampler = RandomSampler( len(self) )
-            self.samplers = { name: PoolSampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
-            self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
+            self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
+            self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
         self.load_state_dict()
 
     @cached_property
     def sampler_state_dict_path(self):
-        return f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+        return cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
 
     def get_speaker(self, path):
         if isinstance(path, str):
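For context on the ordered-vs-shuffled branch above: when sample_shuffle is enabled, the path-ordered case swaps an in-order index sampler for a randomly permuted one. A minimal, self-contained sketch of those two behaviors; the Tiny* classes and make_sampler below are simplified stand-ins for illustration, not the repository's OrderedSampler/RandomSampler.

import random
from typing import Iterator, List

class TinyOrderedSampler:
    """Simplified stand-in: yields dataset indices in a fixed 0..N-1 order."""
    def __init__(self, length: int):
        self.indices: List[int] = list(range(length))

    def __iter__(self) -> Iterator[int]:
        yield from self.indices

class TinyRandomSampler:
    """Simplified stand-in: yields a fresh random permutation each epoch."""
    def __init__(self, length: int):
        self.length = length

    def __iter__(self) -> Iterator[int]:
        order = list(range(self.length))
        random.shuffle(order)
        yield from order

def make_sampler(length: int, sample_shuffle: bool):
    # mirrors the ternary above: ordered indices when shuffling is off, a permutation otherwise
    return TinyRandomSampler(length) if sample_shuffle else TinyOrderedSampler(length)

print(list(make_sampler(5, sample_shuffle=False)))  # [0, 1, 2, 3, 4]
print(list(make_sampler(5, sample_shuffle=True)))   # e.g. [3, 0, 4, 1, 2]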
@@ -602,7 +601,7 @@ class Dataset(_Dataset):
 
     def save_state_dict(self, path = None):
         if path is None:
-            path = cfg.rel_path / self.sampler_state_dict_path
+            path = self.sampler_state_dict_path
 
         if self.sampler_type == "path":
             state_dict = self.sampler.get_state()
@@ -615,7 +614,7 @@ class Dataset(_Dataset):
 
     def load_state_dict(self, path = None):
         if path is None:
-            path = cfg.rel_path / self.sampler_state_dict_path
+            path = self.sampler_state_dict_path
 
         if not path.exists():
             return
@@ -652,13 +651,6 @@ class Dataset(_Dataset):
     def _get_task_symmap(self):
         return get_task_symmap()
 
-    """
-    def get_task_token( self, token, levels=cfg.model.max_levels ):
-        if not hasattr(self, "task_symmap"):
-            self.task_symmap = self._get_task_symmap()
-        return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
-    """
-
     def sample_noise(self):
         path = random.choice(self.noise_paths)
 
@@ -756,12 +748,13 @@ class Dataset(_Dataset):
         else:
             resps, metadata = _load_quants(path, return_metadata=True)
             text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
             #text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
 
         lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
 
         # append additional prompts in an attempt to artifically increase lengths / offer new data
+        """
+        # disabled because I haven't actually needed to use it myself, and I can't be assed to validate if it still works
+        # it probably is better to pad with silence instead of just stitching utterances and ruining things
         if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
             choices = [*(set(self.paths_by_spkr_name[spkr_name]) - {path})]
@@ -997,13 +990,6 @@ class Dataset(_Dataset):
             return min(len(self.spkrs), self._head or len(self.spkrs))
         return min(len(self.paths), self._head or len(self.paths))
 
-    def pin_memory(self):
-        self.text = self.text.pin_memory()
-        self.proms = self.proms.pin_memory()
-        self.resps = self.resps.pin_memory()
-        self.resp = self.resp.pin_memory()
-        return self
-
 
 def collate_fn(samples: list[dict]):
     batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
@@ -1017,14 +1003,8 @@ def _seed_worker(worker_id):
 
 
 def _create_dataloader(dataset, training):
-    """
-    if cfg.distributed and training:
-        sampler = DistributedSampler(dataset)
-        shuffle = False
-    """
-
     kwargs = dict(
-        shuffle=dataset.shuffle,
+        shuffle=False,
         batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
         drop_last=training,
         sampler=dataset.sampler,
@@ -1037,7 +1017,7 @@ def _create_dataloader(dataset, training):
         num_workers=cfg.dataset.workers,
         collate_fn=collate_fn,
         persistent_workers=cfg.dataset.workers > 1,
-        pin_memory=False, # True,
+        pin_memory=False,
         worker_init_fn=_seed_worker,
         **kwargs,
     )
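The dataloader now passes shuffle=False unconditionally because ordering is delegated entirely to the dataset's custom sampler; PyTorch's DataLoader rejects shuffle=True whenever a sampler is supplied. A small illustration of that pattern, with a hypothetical toy dataset standing in for the real one.

import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler

class ToyDataset(Dataset):
    """Hypothetical dataset used only to illustrate the sampler/shuffle interaction."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.tensor(idx)

dataset = ToyDataset()
sampler = RandomSampler(dataset)  # ordering (including any shuffling) lives in the sampler
loader = DataLoader(dataset, batch_size=4, sampler=sampler, shuffle=False)  # so shuffle stays False

for batch in loader:
    print(batch)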
@@ -9,14 +9,17 @@ from .distributed import global_rank, local_rank, world_size
 
 # Randomly picks an index from an array of indices
 class PoolSampler():
-    def __init__( self, pool = [], keep_all = False ):
+    def __init__( self, pool = [], keep_all = False, shuffle = False ):
         self.length = len(pool)
+        self.shuffle = shuffle
         self.global_pool = pool if keep_all else None
         self.global_indices = [ i for i in range(self.length) ]
         self.reset()
 
     def reset(self):
         self.current_pool = [ i for i in self.global_indices ]
+        if self.shuffle:
+            random.shuffle(self.current_pool)
 
     def sample(self, pool = None):
         if pool is None:
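The new shuffle flag on the pool sampler implements shuffle-on-reset: indices are drawn without replacement, and the draw order is re-randomized each time the pool refills. A condensed, self-contained sketch of that behavior; TinyPoolSampler is a simplified stand-in for illustration, not the repository's PoolSampler.

import random

class TinyPoolSampler:
    """Simplified stand-in: draws items without replacement and, when the pool
    is exhausted, refills it (reshuffling the draw order if enabled)."""
    def __init__(self, pool, shuffle=False):
        self.pool = list(pool)
        self.shuffle = shuffle
        self.reset()

    def reset(self):
        self.remaining = list(range(len(self.pool)))
        if self.shuffle:
            random.shuffle(self.remaining)  # randomize the draw order for this pass

    def sample(self):
        if not self.remaining:
            self.reset()  # pool exhausted: refill (and reshuffle if enabled)
        return self.pool[self.remaining.pop()]

sampler = TinyPoolSampler(["a", "b", "c"], shuffle=True)
print([sampler.sample() for _ in range(6)])  # two full passes, each pass in its own order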
@@ -78,9 +81,10 @@ class OrderedSampler(Sampler):
 
 # Like the above, but will batch based on token count
 class BatchedOrderedSampler(Sampler):
-    def __init__( self, buckets, max_duration=0, max_batch_size=0 ):
+    def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
         self.position = 0
         self.batches = []
+        self.shuffle = shuffle
 
         assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
 
@@ -105,12 +109,17 @@ class BatchedOrderedSampler(Sampler):
             current_index += 1
             current_size += duration
 
+        if self.shuffle:
+            random.shuffle(self.batches)
+
     def __len__(self):
         return len(self.batches)
 
     def __iter__(self):
         if self.position >= len(self.batches):
             self.position = 0
+            if self.shuffle:
+                random.shuffle(self.batches)
 
         while self.position < len(self.batches):
             yield self.batches[self.position]
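The batched sampler shuffles whole batches rather than individual indices, so each batch keeps its duration-capped composition while the epoch order varies. A rough sketch of that idea; it is simplified (no max_batch_size cap, no bucket sorting) and build_duration_batches is an illustrative name, not a function from the repository.

import random

def build_duration_batches(durations, max_duration, shuffle=False):
    """Greedily pack indices into batches capped by total duration, then optionally
    shuffle the batch order; the contents of each batch stay intact."""
    batches, current, current_size = [], [], 0.0
    for idx, dur in enumerate(durations):
        if current and current_size + dur > max_duration:
            batches.append(current)
            current, current_size = [], 0.0
        current.append(idx)
        current_size += dur
    if current:
        batches.append(current)
    if shuffle:
        random.shuffle(batches)  # reorder batches, not their members
    return batches

durations = [3.0, 4.5, 2.0, 6.0, 1.5, 5.0]  # hypothetical utterance lengths in seconds
print(build_duration_batches(durations, max_duration=8.0, shuffle=True))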