backported additions from e-c-k-e-r/vall-e (paths sorted-by-duration and batched sampling)

parent e0a93a6400
commit 43d85d97aa
@@ -163,6 +163,8 @@ class Dataset:
 	sample_type: str = "path" # path | speaker
 	sample_order: str = "shuffle" # duration
+	sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batch, 0 to disable
+
 	tasks_list: list[str] = field(default_factory=lambda: ["tts"])
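For orientation, a minimal sketch of how the new knob composes with the existing fields (a trimmed, hypothetical mirror of the config dataclass; only the fields shown in the hunk come from the source):

    from dataclasses import dataclass, field

    @dataclass
    class Dataset:
        sample_type: str = "path"               # path | speaker
        sample_order: str = "shuffle"           # or "duration"
        sample_max_duration_batch: float = 0.0  # seconds per batch; 0 keeps fixed-size batches
        tasks_list: list[str] = field(default_factory=lambda: ["tts"])

    # Opt in to duration-ordered paths with batches capped at ~120 seconds of audio:
    cfg_dataset = Dataset(sample_order="duration", sample_max_duration_batch=120.0)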
@@ -12,7 +12,7 @@ import itertools

 from .config import cfg
 from .emb.mel import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
-from .utils.sampler import PoolSampler, OrderedSampler, RandomSampler
+from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size

 from collections import defaultdict
@@ -270,23 +270,29 @@ class Dataset(_Dataset):
 			if self.sampler_order != "duration":
 				continue

-			bucket = str(int(round(duration)))
+			bucket = int(round(duration))
 			if bucket not in self.duration_buckets:
 				self.duration_buckets[bucket] = []
 			self.duration_buckets[bucket].append( ( Path(path), duration ) )

+		# ensure they're ordered
+		self.duration_buckets = dict(sorted(self.duration_buckets.items()))
+
 		# sort by duration
 		if self.sampler_order == "duration":
+			flattened = {}
 			# sort and interleave
 			for bucket in self.duration_buckets:
 				# sort by duration
 				self.duration_buckets[bucket].sort( key=lambda x: x[1] )
+				# split to retain tuples
+				flattened[bucket] = self.duration_buckets[bucket]
 				# replace with path
-				self.duration_buckets[bucket] = [ x[0] for x in self.duration_buckets[bucket] ]
+				flattened[bucket] = [ x[0] for x in flattened[bucket] ]
 				# flatten by paths
-				self.duration_buckets[bucket] = [*_interleaved_reorder(self.duration_buckets[bucket], self.get_speaker)]
+				flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
 			# flatten paths
-			self.paths = list(itertools.chain.from_iterable(self.duration_buckets.values()))
+			self.paths = list(itertools.chain.from_iterable(flattened.values()))
 		elif self.sampler_order == "shuffle":
 			# just interleave
 			self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
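In effect, utterances are grouped into integer-second buckets, each bucket is sorted and reduced to bare paths, and the buckets are chained shortest-first, while duration_buckets keeps its (path, duration) tuples for the batched sampler below. A self-contained toy sketch of that flow (invented data; the speaker interleaving step is omitted for brevity):

    from pathlib import Path
    from collections import defaultdict
    import itertools

    entries = [ ("a/1.wav", 2.4), ("b/1.wav", 2.6), ("a/2.wav", 5.1), ("b/2.wav", 4.9) ]

    # Bucket by rounded duration with integer keys, kept in ascending order.
    buckets = defaultdict(list)
    for path, duration in entries:
        buckets[int(round(duration))].append( ( Path(path), duration ) )
    buckets = dict(sorted(buckets.items()))

    # Sort within each bucket, strip durations, then chain the buckets together.
    flattened = { b: [ p for p, d in sorted(v, key=lambda x: x[1]) ] for b, v in buckets.items() }
    paths = list(itertools.chain.from_iterable(flattened.values()))
    print(paths)  # shortest utterances first: a/1.wav, b/1.wav, b/2.wav, a/2.wav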
@@ -328,6 +334,9 @@ class Dataset(_Dataset):
 		sampler_path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"

 		if self.sampler_type == "path":
-			self.sampler = OrderedSampler( len(self) )
+			if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
+				self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size )
+			else:
+				self.sampler = OrderedSampler( len(self) )
 			self.samplers = {}
 			self.spkr_samplers = {}
@@ -564,17 +573,23 @@ def _create_dataloader(dataset, training):
 		shuffle = False
 	"""

+	kwargs = dict(
+		shuffle=dataset.shuffle,
+		batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
+		drop_last=training,
+		sampler=dataset.sampler,
+	) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
+		batch_sampler=dataset.sampler,
+	)
+
 	return DataLoader(
 		dataset=dataset,
-		batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
-		shuffle=dataset.shuffle,
-		drop_last=training,
 		num_workers=cfg.dataset.workers,
 		collate_fn=collate_fn,
 		persistent_workers=cfg.dataset.workers > 1,
 		pin_memory=False, # True,
 		worker_init_fn=_seed_worker,
-		sampler=dataset.sampler,
+		**kwargs,
 	)

 def create_datasets():
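The split is necessary because PyTorch's DataLoader treats batch_sampler as mutually exclusive with batch_size, shuffle, sampler, and drop_last: a batch sampler yields complete lists of indices, so the batch size can vary per step. A standalone illustration with a toy dataset (not the project's):

    import torch
    from torch.utils.data import DataLoader, Dataset

    class Toy(Dataset):
        def __len__(self):
            return 6
        def __getitem__(self, i):
            return torch.tensor(i)

    # Each inner list is one full batch of indices; sizes may differ per batch.
    batches = [[0, 1, 2], [3, 4], [5]]
    loader = DataLoader(Toy(), batch_sampler=batches)

    for batch in loader:
        print(batch.shape)  # torch.Size([3]), torch.Size([2]), torch.Size([1])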
@@ -74,6 +74,57 @@ class OrderedSampler(Sampler):
 		self.position = state["position"]
 		self.length = state["length"]

+# Like the above, but batches based on total duration
+class BatchedOrderedSampler(Sampler):
+	def __init__( self, buckets, max_duration=0, max_batch_size=0 ):
+		self.position = 0
+		self.batches = []
+
+		assert max_duration != 0 or max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
+
+		current_batch = []
+		current_size = 0
+		current_index = 0
+		for key, bucket in buckets.items():
+			for path, duration in bucket:
+				# flush if adding this utterance would exceed either cap
+				should_flush = False
+				if max_duration > 0 and current_size + duration > max_duration:
+					should_flush = True
+				elif max_batch_size > 0 and len(current_batch) >= max_batch_size:
+					should_flush = True
+
+				if should_flush and len(current_batch) > 0:
+					self.batches.append( current_batch )
+					current_batch = []
+					current_size = 0
+
+				current_batch.append( current_index )
+				current_index += 1
+				current_size += duration
+
+		# flush whatever remains as the final batch
+		if len(current_batch) > 0:
+			self.batches.append( current_batch )
+
+	def __len__(self):
+		return len(self.batches)
+
+	def __iter__(self):
+		if self.position >= len(self.batches):
+			self.position = 0
+
+		while self.position < len(self.batches):
+			yield self.batches[self.position]
+			self.position += 1
+
+	def get_state(self):
+		return { "position": self.position, "batches": self.batches }
+
+	def set_state(self, state):
+		self.position = state["position"]
+		self.batches = state["batches"]
+

 # Randomly samples indices from a given sequence from 0 to length
 # Allows saving and loading state
 class RandomSampler(Sampler):
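A quick usage sketch, assuming the class above and toy buckets shaped like Dataset.duration_buckets (a bucket key mapping to (path, duration) tuples; all values invented):

    from pathlib import Path

    buckets = {
        2: [ (Path("a/1.wav"), 2.0), (Path("b/1.wav"), 2.4) ],
        5: [ (Path("a/2.wav"), 4.8), (Path("b/2.wav"), 5.2) ],
    }

    # Cap each batch at ten seconds of audio and at most four utterances.
    sampler = BatchedOrderedSampler( buckets, max_duration=10, max_batch_size=4 )
    print(len(sampler))  # 2
    for batch in sampler:
        print(batch)     # [0, 1, 2] (2.0 + 2.4 + 4.8 <= 10), then [3]

The yielded index lists line up with the duration-sorted self.paths built earlier, which is why _create_dataloader hands the whole object to DataLoader as a batch_sampler rather than a per-index sampler.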