tweaks to bucket sampling

This commit is contained in:
mrq 2024-11-13 11:09:24 -06:00
parent b2eca271a8
commit 29e45be0b4

View File

@ -87,7 +87,7 @@ class OrderedSampler(Sampler):
# Like the above, but will batch based on token count
class BatchedOrderedSampler(Sampler):
def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False, drop_last=True, use_max_size=True ):
self.position = 0
self.batches = []
self.shuffle = shuffle
@ -95,13 +95,14 @@ class BatchedOrderedSampler(Sampler):
assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
current_batch = []
current_size = 0
current_index = 0
current_duration = 0
for key, bucket in buckets.items():
for path, duration in bucket:
# flush
should_flush = False
if max_duration > 0 and current_size + duration > max_duration:
if max_duration > 0 and current_duration + duration > max_duration:
should_flush = True
elif max_batch_size > 0 and len(current_batch) >= max_batch_size:
should_flush = True
@ -109,11 +110,18 @@ class BatchedOrderedSampler(Sampler):
if should_flush and len(current_batch) > 0:
self.batches.append( current_batch )
current_batch = []
current_size = 0
current_duration = 0
current_batch.append( current_index )
current_index += 1
current_size += duration
# as long as durations are ordered, this assertion is always true
if use_max_size:
current_duration = duration * len(current_batch)
else:
current_duration += duration
if not drop_last and current_batch:
self.batches.append( current_batch )
if self.shuffle:
random.shuffle(self.batches)