add shuffle to samplers that can support it
commit 312a8e3ead (parent 396af541c5)
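In brief: this commit threads a new `sample_shuffle` knob from the `Dataset` config down into the sampler objects themselves (`PoolSampler`, `OrderedSampler`/`RandomSampler`, `BatchedOrderedSampler`) and retires the DataLoader-level `shuffle` flag. The split is forced by PyTorch: `DataLoader` treats `shuffle=True` and a custom `sampler` as mutually exclusive, so once every dataset goes through a sampler, shuffling has to live in the sampler. A short demonstration of that constraint (plain PyTorch, not repo code):

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

ds = TensorDataset(torch.arange(10))
# DataLoader(ds, shuffle=True, sampler=RandomSampler(ds))  # ValueError: mutually exclusive
loader = DataLoader(ds, shuffle=False, sampler=RandomSampler(ds))  # shuffle via the sampler instead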
@@ -157,9 +157,10 @@ class Dataset:
 	p_resp_append: float = 1.0
 
 	sample_type: str = "path" # path | speaker
-	sample_order: str = "shuffle" # duration
+	sample_order: str = "interleaved" # duration
 	sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batch, 0 to disable
 	# for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
+	sample_shuffle: bool = True #
 
 	tasks_list: list[str] = field(default_factory=lambda: ["tts"])
 
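For orientation, here is how the two config knobs compose once the rest of the diff is applied; `pick_sampler` is a paraphrase of the selection logic in `Dataset.__init__` further down, not code from the repo:

def pick_sampler(order: str, shuffle: bool, max_duration_batch: float = 0.0):
    # "duration" batching only kicks in when sample_max_duration_batch > 0
    if order == "duration" and max_duration_batch > 0:
        # batches come from duration buckets; shuffle reorders whole batches
        return ("BatchedOrderedSampler", {"shuffle": shuffle})
    # otherwise: a flat index sampler over the interleaved path list
    return ("RandomSampler", {}) if shuffle else ("OrderedSampler", {})

assert pick_sampler("interleaved", True)[0] == "RandomSampler"
assert pick_sampler("duration", True, 120.0)[1]["shuffle"] is True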
@@ -424,7 +424,6 @@ class Dataset(_Dataset):
 	):
 		super().__init__()
 		self._head = None
-		self.shuffle = False
 		self.sampler = None
 
 		self.paths = []
@@ -434,6 +433,7 @@ class Dataset(_Dataset):
 		self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
 		self.sampler_type = cfg.dataset.sample_type # if self.dataset_type == "training" else "group"
 		self.sampler_order = cfg.dataset.sample_order
+		self.sampler_shuffle = cfg.dataset.sample_shuffle
 
 		# to-do: do not do validation if there's nothing in the validation
 		# this just makes it be happy
@@ -510,7 +510,7 @@ class Dataset(_Dataset):
 				flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
 			# flatten paths
 			self.paths = list(itertools.chain.from_iterable(flattened.values()))
-		elif self.sampler_order == "shuffle":
+		else:
 			# just interleave
 			self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
 
@@ -547,29 +547,28 @@ class Dataset(_Dataset):
 		if len(self.paths) == 0:
 			raise ValueError(f"No valid path is found for {self.dataset_type}")
 
-		sampler_path = cfg.rel_path / self.sampler_state_dict_path
-
 		if self.sampler_type == "path":
 			if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
 				self.sampler = BatchedOrderedSampler(
-					self.duration_buckets if not sampler_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
-					cfg.dataset.sample_max_duration_batch,
-					cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size
+					self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
+					max_duration=cfg.dataset.sample_max_duration_batch,
+					max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size,
+					shuffle=self.sampler_shuffle
 				)
 			else:
-				self.sampler = OrderedSampler( len(self) )
+				self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
 			self.samplers = {}
 			self.spkr_samplers = {}
 		else:
 			self.sampler = RandomSampler( len(self) )
-			self.samplers = { name: PoolSampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
-			self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
+			self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
+			self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
 		self.load_state_dict()
 
 	@cached_property
 	def sampler_state_dict_path(self):
-		return f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+		return cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
 
 	def get_speaker(self, path):
 		if isinstance(path, str):
@@ -602,7 +601,7 @@ class Dataset(_Dataset):
 
 	def save_state_dict(self, path = None):
 		if path is None:
-			path = cfg.rel_path / self.sampler_state_dict_path
+			path = self.sampler_state_dict_path
 
 		if self.sampler_type == "path":
			state_dict = self.sampler.get_state()
@@ -615,7 +614,7 @@ class Dataset(_Dataset):
 
 	def load_state_dict(self, path = None):
 		if path is None:
-			path = cfg.rel_path / self.sampler_state_dict_path
+			path = self.sampler_state_dict_path
 
 		if not path.exists():
 			return
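A small refactor rides along in the three hunks above: `sampler_state_dict_path` now returns the full `cfg.rel_path / ...` path rather than a bare filename, so `save_state_dict`, `load_state_dict`, and the `BatchedOrderedSampler` construction no longer compose it themselves. Roughly, assuming `cfg.rel_path` is a `pathlib.Path` (`rel_path` below is a stand-in):

from pathlib import Path

rel_path = Path("./training")  # stand-in for cfg.rel_path

def sampler_state_dict_path(sampler_type: str = "path", rank: int = 0) -> Path:
    # after this commit the property hands back the finished Path
    return rel_path / f"sampler.{sampler_type}.rank{rank}.pt"

assert sampler_state_dict_path() == rel_path / "sampler.path.rank0.pt"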
@@ -652,13 +651,6 @@ class Dataset(_Dataset):
 	def _get_task_symmap(self):
 		return get_task_symmap()
 
-	"""
-	def get_task_token( self, token, levels=cfg.model.max_levels ):
-		if not hasattr(self, "task_symmap"):
-			self.task_symmap = self._get_task_symmap()
-		return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
-	"""
-
 
 	def sample_noise(self):
 		path = random.choice(self.noise_paths)
@@ -756,12 +748,13 @@ class Dataset(_Dataset):
 		else:
 			resps, metadata = _load_quants(path, return_metadata=True)
 			text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
-			#text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
 
 		lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
 
 		# append additional prompts in an attempt to artificially increase lengths / offer new data
 		"""
+		# disabled because I haven't actually needed to use it myself, and I can't be assed to validate if it still works
+		# it probably is better to pad with silence instead of just stitching utterances and ruining things
 		if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
 			choices = [*(set(self.paths_by_spkr_name[spkr_name]) - {path})]
 
@@ -997,13 +990,6 @@ class Dataset(_Dataset):
 			return min(len(self.spkrs), self._head or len(self.spkrs))
 		return min(len(self.paths), self._head or len(self.paths))
 
-	def pin_memory(self):
-		self.text = self.text.pin_memory()
-		self.proms = self.proms.pin_memory()
-		self.resps = self.resps.pin_memory()
-		self.resp = self.resp.pin_memory()
-		return self
-
 
 def collate_fn(samples: list[dict]):
 	batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
@@ -1017,14 +1003,8 @@ def _seed_worker(worker_id):
 
 
 def _create_dataloader(dataset, training):
-	"""
-	if cfg.distributed and training:
-		sampler = DistributedSampler(dataset)
-		shuffle = False
-	"""
-
 	kwargs = dict(
-		shuffle=dataset.shuffle,
+		shuffle=False,
 		batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
 		drop_last=training,
 		sampler=dataset.sampler,
@@ -1037,7 +1017,7 @@ def _create_dataloader(dataset, training):
 		num_workers=cfg.dataset.workers,
 		collate_fn=collate_fn,
 		persistent_workers=cfg.dataset.workers > 1,
-		pin_memory=False, # True,
+		pin_memory=False,
 		worker_init_fn=_seed_worker,
 		**kwargs,
 	)
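With the samplers owning randomization, the loader is now built with `shuffle=False` unconditionally; as noted above, passing `dataset.sampler` alongside `shuffle=True` would raise. Dropping the `pin_memory()` batch hook earlier is consistent with `pin_memory=False` here, presumably dead code once page-locked staging was turned off. A generic sketch of the resulting loader shape (stand-in dataset and sampler, not the repo's):

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

ds = TensorDataset(torch.arange(8))
loader = DataLoader(
    ds,
    shuffle=False,                  # the sampler owns the ordering now
    sampler=SequentialSampler(ds),  # stand-in for dataset.sampler
    batch_size=2,
    drop_last=True,
)
print([batch[0].tolist() for batch in loader])  # [[0, 1], [2, 3], [4, 5], [6, 7]]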
@@ -9,14 +9,17 @@ from .distributed import global_rank, local_rank, world_size
 
 # Randomly picks an index from an array of indices
 class PoolSampler():
-	def __init__( self, pool = [], keep_all = False ):
+	def __init__( self, pool = [], keep_all = False, shuffle = False ):
 		self.length = len(pool)
+		self.shuffle = shuffle
 		self.global_pool = pool if keep_all else None
 		self.global_indices = [ i for i in range(self.length) ]
 		self.reset()
 
 	def reset(self):
 		self.current_pool = [ i for i in self.global_indices ]
+		if self.shuffle:
+			random.shuffle(self.current_pool)
 
 	def sample(self, pool = None):
 		if pool is None:
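What the flag buys `PoolSampler`: every refill of the pool now reshuffles it, so each pass over a speaker's utterances comes out in a fresh order while still visiting everything exactly once. A minimal self-contained sketch of the idea — `MiniPool` is illustrative, not the repo class, and assumes `sample()` pops indices from `current_pool` while `reset()` refills it, as the hunk above suggests:

import random

class MiniPool:
    def __init__(self, pool, shuffle=False):
        self.pool = pool
        self.shuffle = shuffle
        self.reset()

    def reset(self):
        self.current_pool = list(range(len(self.pool)))
        if self.shuffle:
            random.shuffle(self.current_pool)  # the behavior this commit adds

    def sample(self):
        if not self.current_pool:
            self.reset()  # exhausted: refill (and reshuffle) for the next pass
        return self.pool[self.current_pool.pop()]

pool = MiniPool(["a.qnt.pt", "b.qnt.pt", "c.qnt.pt"], shuffle=True)
epoch = [pool.sample() for _ in range(3)]  # each utterance exactly once, random order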
@@ -78,9 +81,10 @@ class OrderedSampler(Sampler):
 
 # Like the above, but will batch based on token count
 class BatchedOrderedSampler(Sampler):
-	def __init__( self, buckets, max_duration=0, max_batch_size=0 ):
+	def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
 		self.position = 0
 		self.batches = []
+		self.shuffle = shuffle
 
 		assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
 
@@ -105,12 +109,17 @@ class BatchedOrderedSampler(Sampler):
 				current_index += 1
 				current_size += duration
 
+		if self.shuffle:
+			random.shuffle(self.batches)
+
 	def __len__(self):
 		return len(self.batches)
 
 	def __iter__(self):
 		if self.position >= len(self.batches):
 			self.position = 0
+			if self.shuffle:
+				random.shuffle(self.batches)
 
 		while self.position < len(self.batches):
 			yield self.batches[self.position]
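The granularity point worth spelling out: `BatchedOrderedSampler` shuffles the list of batches, once after building them and again whenever `__iter__` wraps around, so the duration-bucketed batches keep their membership and only their order changes. Schematically (indices stand in for utterances):

import random

# batches as built from duration buckets: similar-length items share a batch
batches = [[0, 1, 2], [3, 4], [5, 6, 7, 8]]

random.shuffle(batches)  # epoch order changes...
flat = [i for batch in batches for i in batch]
assert sorted(flat) == list(range(9))  # ...but membership is untouched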