From 29e45be0b4952617c0c1518e9b7b7152b5fdf5fe Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 13 Nov 2024 11:09:24 -0600 Subject: [PATCH] tweaks to bucket sampling --- vall_e/utils/sampler.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/vall_e/utils/sampler.py b/vall_e/utils/sampler.py index 908c0b8..4060fcd 100644 --- a/vall_e/utils/sampler.py +++ b/vall_e/utils/sampler.py @@ -87,7 +87,7 @@ class OrderedSampler(Sampler): # Like the above, but will batch based on token count class BatchedOrderedSampler(Sampler): - def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ): + def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False, drop_last=True, use_max_size=True ): self.position = 0 self.batches = [] self.shuffle = shuffle @@ -95,13 +95,14 @@ class BatchedOrderedSampler(Sampler): assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0" current_batch = [] - current_size = 0 current_index = 0 + current_duration = 0 + for key, bucket in buckets.items(): for path, duration in bucket: # flush should_flush = False - if max_duration > 0 and current_size + duration > max_duration: + if max_duration > 0 and current_duration + duration > max_duration: should_flush = True elif max_batch_size > 0 and len(current_batch) >= max_batch_size: should_flush = True @@ -109,11 +110,18 @@ class BatchedOrderedSampler(Sampler): if should_flush and len(current_batch) > 0: self.batches.append( current_batch ) current_batch = [] - current_size = 0 + current_duration = 0 current_batch.append( current_index ) current_index += 1 - current_size += duration + # as long as durations are ordered, this assertion is always true + if use_max_size: + current_duration = duration * len(current_batch) + else: + current_duration += duration + + if not drop_last and current_batch: + self.batches.append( current_batch ) if self.shuffle: random.shuffle(self.batches)