tweaks to bucket sampling
This commit is contained in:
parent
b2eca271a8
commit
29e45be0b4
|
@ -87,7 +87,7 @@ class OrderedSampler(Sampler):
|
|||
|
||||
# Like the above, but will batch based on token count
|
||||
class BatchedOrderedSampler(Sampler):
|
||||
def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
|
||||
def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False, drop_last=True, use_max_size=True ):
|
||||
self.position = 0
|
||||
self.batches = []
|
||||
self.shuffle = shuffle
|
||||
|
@ -95,13 +95,14 @@ class BatchedOrderedSampler(Sampler):
|
|||
assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
|
||||
|
||||
current_batch = []
|
||||
current_size = 0
|
||||
current_index = 0
|
||||
current_duration = 0
|
||||
|
||||
for key, bucket in buckets.items():
|
||||
for path, duration in bucket:
|
||||
# flush
|
||||
should_flush = False
|
||||
if max_duration > 0 and current_size + duration > max_duration:
|
||||
if max_duration > 0 and current_duration + duration > max_duration:
|
||||
should_flush = True
|
||||
elif max_batch_size > 0 and len(current_batch) >= max_batch_size:
|
||||
should_flush = True
|
||||
|
@ -109,11 +110,18 @@ class BatchedOrderedSampler(Sampler):
|
|||
if should_flush and len(current_batch) > 0:
|
||||
self.batches.append( current_batch )
|
||||
current_batch = []
|
||||
current_size = 0
|
||||
current_duration = 0
|
||||
|
||||
current_batch.append( current_index )
|
||||
current_index += 1
|
||||
current_size += duration
|
||||
# as long as durations are ordered, this assertion is always true
|
||||
if use_max_size:
|
||||
current_duration = duration * len(current_batch)
|
||||
else:
|
||||
current_duration += duration
|
||||
|
||||
if not drop_last and current_batch:
|
||||
self.batches.append( current_batch )
|
||||
|
||||
if self.shuffle:
|
||||
random.shuffle(self.batches)
|
||||
|
|
Loading…
Reference in New Issue
Block a user