backported additions from e-c-k-e-r/vall-e (paths sorted-by-duration and batched sampling)

mrq 2024-06-28 22:29:42 -05:00
parent e0a93a6400
commit 43d85d97aa
3 changed files with 74 additions and 10 deletions

[changed file 1 of 3]

@@ -163,6 +163,8 @@ class Dataset:
     sample_type: str = "path" # path | speaker
+    sample_order: str = "shuffle" # shuffle | duration
+    sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batch; 0 to disable
     tasks_list: list[str] = field(default_factory=lambda: ["tts"])
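For context, a minimal sketch of how the two new fields combine. Only `sample_order` and `sample_max_duration_batch` come from the hunk above; instantiating the config dataclass directly with keyword arguments is an assumption for illustration.

# Hypothetical usage sketch (field names from the diff; the call is illustrative)
dataset_cfg = Dataset(
    sample_type="path",               # sample individual paths, not speakers
    sample_order="duration",          # sort paths by utterance duration
    sample_max_duration_batch=120.0,  # pack up to ~120 seconds of audio per batch
)
# With sample_order="duration" and a nonzero sample_max_duration_batch, the
# dataloader changes below swap the fixed batch_size for duration-packed batches.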

[changed file 2 of 3]

@@ -12,7 +12,7 @@ import itertools

 from .config import cfg
 from .emb.mel import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
-from .utils.sampler import PoolSampler, OrderedSampler, RandomSampler
+from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size

 from collections import defaultdict
@@ -270,23 +270,29 @@ class Dataset(_Dataset):
             if self.sampler_order != "duration":
                 continue
-            bucket = str(int(round(duration)))
+            bucket = int(round(duration))
             if bucket not in self.duration_buckets:
                 self.duration_buckets[bucket] = []
             self.duration_buckets[bucket].append( ( Path(path), duration ) )

+        # ensure they're ordered
+        self.duration_buckets = dict(sorted(self.duration_buckets.items()))
+
         # sort by duration
         if self.sampler_order == "duration":
+            flattened = {}
             # sort and interleave
             for bucket in self.duration_buckets:
                 # sort by duration
                 self.duration_buckets[bucket].sort( key=lambda x: x[1] )
+                # split to retain tuples
+                flattened[bucket] = self.duration_buckets[bucket]
                 # replace with path
-                self.duration_buckets[bucket] = [ x[0] for x in self.duration_buckets[bucket] ]
+                flattened[bucket] = [ x[0] for x in flattened[bucket] ]
                 # flatten by paths
-                self.duration_buckets[bucket] = [*_interleaved_reorder(self.duration_buckets[bucket], self.get_speaker)]
+                flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
             # flatten paths
-            self.paths = list(itertools.chain.from_iterable(self.duration_buckets.values()))
+            self.paths = list(itertools.chain.from_iterable(flattened.values()))
         elif self.sampler_order == "shuffle":
             # just interleave
             self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
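The key change in this hunk: the old code overwrote `self.duration_buckets[bucket]` with bare paths, destroying the `(path, duration)` tuples; writing the path-only, interleaved lists into a separate `flattened` dict keeps the durations intact for the batched sampler wired up below. A toy illustration of the bucketing and interleaving, with made-up data and a sketch of what `_interleaved_reorder` appears to do (its real implementation is not shown in this diff):

from collections import defaultdict
from itertools import zip_longest

def interleaved_reorder(items, key):
    # sketch: round-robin items across groups (here, speakers) so no single
    # speaker dominates a stretch of consecutive paths
    groups = defaultdict(list)
    for item in items:
        groups[key(item)].append(item)
    for row in zip_longest(*[groups[k] for k in sorted(groups)]):
        yield from (item for item in row if item is not None)

paths = [("spk_a/1.wav", 1.2), ("spk_b/1.wav", 0.9), ("spk_a/2.wav", 1.4)]
buckets = defaultdict(list)
for path, duration in paths:
    buckets[int(round(duration))].append((path, duration))  # 1-second buckets
buckets = dict(sorted(buckets.items()))  # here: {1: [all three utterances]}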
@@ -328,7 +334,10 @@ class Dataset(_Dataset):
         sampler_path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"

         if self.sampler_type == "path":
-            self.sampler = OrderedSampler( len(self) )
+            if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
+                self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size )
+            else:
+                self.sampler = OrderedSampler( len(self) )
             self.samplers = {}
             self.spkr_samplers = {}
         else:
@@ -564,17 +573,23 @@ def _create_dataloader(dataset, training):
         shuffle = False
     """

+    kwargs = dict(
+        shuffle=dataset.shuffle,
+        batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
+        drop_last=training,
+        sampler=dataset.sampler,
+    ) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
+        batch_sampler=dataset.sampler,
+    )
+
     return DataLoader(
         dataset=dataset,
-        batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
-        shuffle=dataset.shuffle,
-        drop_last=training,
         num_workers=cfg.dataset.workers,
         collate_fn=collate_fn,
         persistent_workers=cfg.dataset.workers > 1,
         pin_memory=False, # True,
         worker_init_fn=_seed_worker,
-        **kwargs,
+        **kwargs,
     )

 def create_datasets():
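The kwargs split exists because torch's DataLoader treats batch_sampler as mutually exclusive with batch_size, shuffle, sampler, and drop_last, and raises a ValueError if they are combined; a batch sampler yields whole index lists, so batch size can vary from batch to batch. A self-contained sketch of the pattern with toy data (the dataset class and batches here are made up):

from torch.utils.data import DataLoader, Dataset as TorchDataset

class ToyDataset(TorchDataset):
    def __len__(self):
        return 10
    def __getitem__(self, index):
        return index

# each inner list is one batch of indices; lengths differ per batch,
# which is exactly what duration-based packing produces
batches = [[0, 1, 2], [3, 4], [5, 6, 7, 8, 9]]
loader = DataLoader(ToyDataset(), batch_sampler=batches)
for batch in loader:
    print(batch)  # tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6, 7, 8, 9])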

[changed file 3 of 3]

@@ -74,6 +74,53 @@ class OrderedSampler(Sampler):
         self.position = state["position"]
         self.length = state["length"]

+# Like the above, but groups indices into batches capped by total duration
+# (and optionally by utterance count)
+class BatchedOrderedSampler(Sampler):
+    def __init__( self, buckets, max_duration=0, max_batch_size=0 ):
+        self.position = 0
+        self.batches = []
+
+        assert max_duration != 0 or max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
+
+        current_batch = []
+        current_size = 0
+        current_index = 0
+        for key, bucket in buckets.items():
+            for path, duration in bucket:
+                # flush the batch if adding this utterance would exceed either cap
+                should_flush = False
+                if max_duration > 0 and current_size + duration > max_duration:
+                    should_flush = True
+                elif max_batch_size > 0 and len(current_batch) >= max_batch_size:
+                    should_flush = True
+
+                if should_flush and len(current_batch) > 0:
+                    self.batches.append( current_batch )
+                    current_batch = []
+                    current_size = 0
+
+                current_batch.append( current_index )
+                current_index += 1
+                current_size += duration
+
+        # flush any remaining partial batch
+        if len(current_batch) > 0:
+            self.batches.append( current_batch )
+
+    def __len__(self):
+        return len(self.batches)
+
+    def __iter__(self):
+        if self.position >= len(self.batches):
+            self.position = 0
+
+        while self.position < len(self.batches):
+            yield self.batches[self.position]
+            self.position += 1
+
+    def get_state(self):
+        return { "position": self.position, "batches": self.batches }
+
+    def set_state(self, state):
+        self.position = state["position"]
+        self.batches = state["batches"]
+
 # Randomly samples indices from a given sequence from 0 to length
 # Allows saving and loading state
 class RandomSampler(Sampler):
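A quick usage sketch for the new sampler. The bucket contents are made up, and the import location is inferred from the `from .utils.sampler import ...` line in the second file, so it may differ in practice:

from pathlib import Path
# from <package>.utils.sampler import BatchedOrderedSampler  # inferred location

buckets = {
    1: [(Path("spk_a/1.wav"), 1.2), (Path("spk_b/1.wav"), 0.9)],
    2: [(Path("spk_a/2.wav"), 2.1), (Path("spk_b/2.wav"), 1.8)],
}
# cap each batch at 4 seconds of audio and at most 8 utterances
sampler = BatchedOrderedSampler(buckets, max_duration=4.0, max_batch_size=8)
print(len(sampler))  # 2
for batch in sampler:
    print(batch)     # [0, 1] then [2, 3]; indices are global across buckets

Note that when this sampler is active, `cfg.hyperparameters.batch_size` is passed in as `max_batch_size`, so it becomes an upper bound on utterances per batch rather than a fixed batch size.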