From 43d85d97aaf36cbba34fa69e11bbc0d36b4b246b Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 28 Jun 2024 22:29:42 -0500
Subject: [PATCH] backported additions from e-c-k-e-r/vall-e (paths
 sorted-by-duration and batched sampling)

---
 tortoise_tts/config.py        |  2 ++
 tortoise_tts/data.py          | 35 ++++++++++++++++++----------
 tortoise_tts/utils/sampler.py | 47 +++++++++++++++++++++++++++++++++++
 3 files changed, 74 insertions(+), 10 deletions(-)

diff --git a/tortoise_tts/config.py b/tortoise_tts/config.py
index 18f10ca..305022a 100755
--- a/tortoise_tts/config.py
+++ b/tortoise_tts/config.py
@@ -163,6 +163,8 @@ class Dataset:
     sample_type: str = "path" # path | speaker
     sample_order: str = "shuffle" # duration
 
+    sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batch, 0 to disable
+
     tasks_list: list[str] = field(default_factory=lambda: ["tts"])

diff --git a/tortoise_tts/data.py b/tortoise_tts/data.py
index 130a3fa..1e58c65 100755
--- a/tortoise_tts/data.py
+++ b/tortoise_tts/data.py
@@ -12,7 +12,7 @@ import itertools
 
 from .config import cfg
 from .emb.mel import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
-from .utils.sampler import PoolSampler, OrderedSampler, RandomSampler
+from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size
 
 from collections import defaultdict
@@ -270,23 +270,29 @@ class Dataset(_Dataset):
             if self.sampler_order != "duration":
                 continue
 
-            bucket = str(int(round(duration)))
+            bucket = int(round(duration))
             if bucket not in self.duration_buckets:
                 self.duration_buckets[bucket] = []
             self.duration_buckets[bucket].append( ( Path(path), duration ) )
 
+        # ensure they're ordered
+        self.duration_buckets = dict(sorted(self.duration_buckets.items()))
+
         # sort by duration
         if self.sampler_order == "duration":
+            flattened = {}
             # sort and interleave
             for bucket in self.duration_buckets:
                 # sort by duration
                 self.duration_buckets[bucket].sort( key=lambda x: x[1] )
+                # split to retain tuples
+                flattened[bucket] = self.duration_buckets[bucket]
                 # replace with path
-                self.duration_buckets[bucket] = [ x[0] for x in self.duration_buckets[bucket] ]
+                flattened[bucket] = [ x[0] for x in flattened[bucket] ]
                 # flatten by paths
-                self.duration_buckets[bucket] = [*_interleaved_reorder(self.duration_buckets[bucket], self.get_speaker)]
+                flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
             # flatten paths
-            self.paths = list(itertools.chain.from_iterable(self.duration_buckets.values()))
+            self.paths = list(itertools.chain.from_iterable(flattened.values()))
         elif self.sampler_order == "shuffle":
             # just interleave
             self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
@@ -328,7 +334,10 @@ class Dataset(_Dataset):
             sampler_path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
 
         if self.sampler_type == "path":
-            self.sampler = OrderedSampler( len(self) )
+            if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
+                self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size )
+            else:
+                self.sampler = OrderedSampler( len(self) )
             self.samplers = {}
             self.spkr_samplers = {}
         else:
@@ -564,17 +573,23 @@ def _create_dataloader(dataset, training):
         shuffle = False
     """
 
+    kwargs = dict(
+        shuffle=dataset.shuffle,
+        batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
+        drop_last=training,
+        sampler=dataset.sampler,
+    ) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
+        batch_sampler=dataset.sampler,
+    )
+
     return DataLoader(
         dataset=dataset,
-        batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
-        shuffle=dataset.shuffle,
-        drop_last=training,
         num_workers=cfg.dataset.workers,
         collate_fn=collate_fn,
         persistent_workers=cfg.dataset.workers > 1,
         pin_memory=False, # True,
         worker_init_fn=_seed_worker,
-        sampler=dataset.sampler,
+        **kwargs,
     )
 
 def create_datasets():

diff --git a/tortoise_tts/utils/sampler.py b/tortoise_tts/utils/sampler.py
index b584bd5..d9ebe37 100644
--- a/tortoise_tts/utils/sampler.py
+++ b/tortoise_tts/utils/sampler.py
@@ -74,6 +74,53 @@ class OrderedSampler(Sampler):
         self.position = state["position"]
         self.length = state["length"]
 
+# Like the above, but batches based on total duration
+class BatchedOrderedSampler(Sampler):
+    def __init__( self, buckets, max_duration=0, max_batch_size=0 ):
+        self.position = 0
+        self.batches = []
+
+        assert max_duration != 0 or max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
+
+        current_batch = []
+        current_size = 0
+        current_index = 0
+        for key, bucket in buckets.items():
+            for path, duration in bucket:
+                # flush if adding this utterance would exceed either limit
+                should_flush = False
+                if max_duration > 0 and current_size + duration > max_duration:
+                    should_flush = True
+                elif max_batch_size > 0 and len(current_batch) >= max_batch_size:
+                    should_flush = True
+
+                if should_flush and len(current_batch) > 0:
+                    self.batches.append( current_batch )
+                    current_batch = []
+                    current_size = 0
+
+                current_batch.append( current_index )
+                current_index += 1
+                current_size += duration
+
+    def __len__(self):
+        return len(self.batches)
+
+    def __iter__(self):
+        if self.position >= len(self.batches):
+            self.position = 0
+
+        while self.position < len(self.batches):
+            yield self.batches[self.position]
+            self.position += 1
+
+    def get_state(self):
+        return { "position": self.position, "batches": self.batches }
+
+    def set_state(self, state):
+        self.position = state["position"]
+        self.batches = state["batches"]
+
 # Randomly samples indices from a given sequence from 0 to length
 # Allows saving and loading state
 class RandomSampler(Sampler):
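
A minimal usage sketch of the backported sampler, assuming only what the
patch itself establishes: buckets map a rounded duration to a list of
(path, duration) tuples and are sorted by key, the way data.py builds them.
ToyDataset, the utt_* names, and the durations below are hypothetical,
purely for illustration:

    from torch.utils.data import DataLoader, Dataset

    # available after applying this patch
    from tortoise_tts.utils.sampler import BatchedOrderedSampler

    class ToyDataset(Dataset):
        def __init__(self, durations):
            self.durations = durations
        def __len__(self):
            return len(self.durations)
        def __getitem__(self, index):
            # a real dataset would load audio here; the indices follow the
            # flattened bucket order, mirroring self.paths in data.py
            return index

    durations = [1.2, 1.4, 2.7, 3.1, 3.3, 9.8] # already sorted by duration
    buckets = {}
    for i, duration in enumerate(durations):
        buckets.setdefault(int(round(duration)), []).append((f"utt_{i}", duration))
    buckets = dict(sorted(buckets.items()))

    # cap each batch at ~6 seconds of audio and at most 4 utterances
    sampler = BatchedOrderedSampler(buckets, max_duration=6, max_batch_size=4)

    # a batch_sampler yields whole index lists, so batch_size/shuffle/drop_last
    # must not be passed -- the same split _create_dataloader() makes above
    dataloader = DataLoader(ToyDataset(durations), batch_sampler=sampler, collate_fn=list)
    for batch in dataloader:
        print(batch) # [0, 1, 2] (~5.3s), then [3], then [4]

Because a batch is only flushed when the next utterance would overflow a
limit, whatever is left in current_batch when the buckets run out is never
yielded (index 5 above), which lines up with the drop_last=training
behaviour of the non-batched training path.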