add shuffle to samplers that can support it

mrq 2024-06-30 11:36:46 -05:00
parent 396af541c5
commit 312a8e3ead
3 changed files with 29 additions and 39 deletions


@@ -157,9 +157,10 @@ class Dataset:
p_resp_append: float = 1.0
sample_type: str = "path" # path | speaker
sample_order: str = "shuffle" # duration
sample_order: str = "interleaved" # duration
sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batch, 0 to disable
# for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
sample_shuffle: bool = True # whether to shuffle the sample order, for samplers that support it
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
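For reference, a minimal sketch of flipping the new options from code rather than from a config file; the field names come from the hunk above, while the import path and the literal values are assumptions:

from vall_e.config import cfg  # assumed package path for the global config object used in data.py

cfg.dataset.sample_order = "duration"           # bucket utterances by duration instead of interleaving
cfg.dataset.sample_max_duration_batch = 120.0   # seconds of utterances per batch; 0 disables duration batching
cfg.dataset.sample_shuffle = True               # new flag: shuffle wherever the chosen sampler supports it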


@@ -424,7 +424,6 @@ class Dataset(_Dataset):
):
super().__init__()
self._head = None
self.shuffle = False
self.sampler = None
self.paths = []
@@ -434,6 +433,7 @@ class Dataset(_Dataset):
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
self.sampler_type = cfg.dataset.sample_type # if self.dataset_type == "training" else "group"
self.sampler_order = cfg.dataset.sample_order
self.sampler_shuffle = cfg.dataset.sample_shuffle
# to-do: do not do validation if there's nothing in the validation dataset
# this just keeps it happy
@@ -510,7 +510,7 @@ class Dataset(_Dataset):
flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
# flatten paths
self.paths = list(itertools.chain.from_iterable(flattened.values()))
elif self.sampler_order == "shuffle":
else:
# just interleave
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
@@ -547,29 +547,28 @@ class Dataset(_Dataset):
if len(self.paths) == 0:
raise ValueError(f"No valid path is found for {self.dataset_type}")
sampler_path = cfg.rel_path / self.sampler_state_dict_path
if self.sampler_type == "path":
if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
self.sampler = BatchedOrderedSampler(
self.duration_buckets if not sampler_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
cfg.dataset.sample_max_duration_batch,
cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size
self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
max_duration=cfg.dataset.sample_max_duration_batch,
max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size,
shuffle=self.sampler_shuffle
)
else:
self.sampler = OrderedSampler( len(self) )
self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
self.samplers = {}
self.spkr_samplers = {}
else:
self.sampler = RandomSampler( len(self) )
self.samplers = { name: PoolSampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
self.load_state_dict()
@cached_property
def sampler_state_dict_path(self):
return f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
return cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
def get_speaker(self, path):
if isinstance(path, str):
@@ -602,7 +601,7 @@ class Dataset(_Dataset):
def save_state_dict(self, path = None):
if path is None:
path = cfg.rel_path / self.sampler_state_dict_path
path = self.sampler_state_dict_path
if self.sampler_type == "path":
state_dict = self.sampler.get_state()
@@ -615,7 +614,7 @@ class Dataset(_Dataset):
def load_state_dict(self, path = None):
if path is None:
path = cfg.rel_path / self.sampler_state_dict_path
path = self.sampler_state_dict_path
if not path.exists():
return
@@ -652,13 +651,6 @@ class Dataset(_Dataset):
def _get_task_symmap(self):
return get_task_symmap()
"""
def get_task_token( self, token, levels=cfg.model.max_levels ):
if not hasattr(self, "task_symmap"):
self.task_symmap = self._get_task_symmap()
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
"""
def sample_noise(self):
path = random.choice(self.noise_paths)
@@ -756,12 +748,13 @@ class Dataset(_Dataset):
else:
resps, metadata = _load_quants(path, return_metadata=True)
text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
#text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
# append additional prompts in an attempt to artificially increase lengths / offer new data
"""
# disabled because I haven't actually needed to use it myself, and I can't be assed to validate if it still works
# it probably is better to pad with silence instead of just stitching utterances and ruining things
if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
choices = [*(set(self.paths_by_spkr_name[spkr_name]) - {path})]
@@ -997,13 +990,6 @@ class Dataset(_Dataset):
return min(len(self.spkrs), self._head or len(self.spkrs))
return min(len(self.paths), self._head or len(self.paths))
def pin_memory(self):
self.text = self.text.pin_memory()
self.proms = self.proms.pin_memory()
self.resps = self.resps.pin_memory()
self.resp = self.resp.pin_memory()
return self
def collate_fn(samples: list[dict]):
batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
@@ -1017,14 +1003,8 @@ def _seed_worker(worker_id):
def _create_dataloader(dataset, training):
"""
if cfg.distributed and training:
sampler = DistributedSampler(dataset)
shuffle = False
"""
kwargs = dict(
shuffle=dataset.shuffle,
shuffle=False,
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
drop_last=training,
sampler=dataset.sampler,
@@ -1037,7 +1017,7 @@ def _create_dataloader(dataset, training):
num_workers=cfg.dataset.workers,
collate_fn=collate_fn,
persistent_workers=cfg.dataset.workers > 1,
pin_memory=False, # True,
pin_memory=False,
worker_init_fn=_seed_worker,
**kwargs,
)
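The hard-coded shuffle=False above follows from moving all shuffling into the samplers themselves: torch's DataLoader refuses a custom sampler combined with shuffle=True. A minimal sketch of that constraint, using torch's own RandomSampler and a made-up stand-in dataset rather than anything from this repo:

from torch.utils.data import DataLoader, RandomSampler

data = list(range(10))  # stand-in dataset
loader = DataLoader(data, batch_size=2, shuffle=False, sampler=RandomSampler(data))  # fine
# DataLoader(data, batch_size=2, shuffle=True, sampler=RandomSampler(data))
#   -> ValueError: sampler option is mutually exclusive with shuffle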


@@ -9,14 +9,17 @@ from .distributed import global_rank, local_rank, world_size
# Randomly picks an index from an array of indices
class PoolSampler():
def __init__( self, pool = [], keep_all = False ):
def __init__( self, pool = [], keep_all = False, shuffle = False ):
self.length = len(pool)
self.shuffle = shuffle
self.global_pool = pool if keep_all else None
self.global_indices = [ i for i in range(self.length) ]
self.reset()
def reset(self):
self.current_pool = [ i for i in self.global_indices ]
if self.shuffle:
random.shuffle(self.current_pool)
def sample(self, pool = None):
if pool is None:
@@ -78,9 +81,10 @@ class OrderedSampler(Sampler):
# Like the above, but will batch based on token count
class BatchedOrderedSampler(Sampler):
def __init__( self, buckets, max_duration=0, max_batch_size=0 ):
def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
self.position = 0
self.batches = []
self.shuffle = shuffle
assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size must both be non-zero"
@@ -105,12 +109,17 @@ class BatchedOrderedSampler(Sampler):
current_index += 1
current_size += duration
if self.shuffle:
random.shuffle(self.batches)
def __len__(self):
return len(self.batches)
def __iter__(self):
if self.position >= len(self.batches):
self.position = 0
if self.shuffle:
random.shuffle(self.batches)
while self.position < len(self.batches):
yield self.batches[self.position]
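
An illustrative sketch of the two samplers with the new flag; the constructor arguments follow the signatures in this diff, but the import path, the file names, and the bucket layout (values as (path, duration) pairs, which is what Dataset.duration_buckets appears to feed in) are assumptions:

from vall_e.utils.sampler import PoolSampler, BatchedOrderedSampler  # assumed module path

pool = PoolSampler(["spk1/a.qnt.pt", "spk1/b.qnt.pt", "spk1/c.qnt.pt"], keep_all=True, shuffle=True)
path = pool.sample()  # reset() now shuffles the pool when shuffle=True

buckets = { 2: [("a.qnt.pt", 1.9), ("b.qnt.pt", 2.1)], 4: [("c.qnt.pt", 3.8)] }  # hypothetical layout
sampler = BatchedOrderedSampler(buckets, max_duration=120, max_batch_size=8, shuffle=True)
for batch in sampler:  # yields lists of dataset indices; order is reshuffled each pass when shuffle=True
    pass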