from dataclasses import dataclass
from typing import Any
import random

import torch
from torch.utils.data import Sampler

from .distributed import global_rank, local_rank, world_size


# Randomly draws indices (or the pooled values they map to) without
# replacement, refilling the pool once every index has been drawn.
class PoolSampler():
    """Draw indices from ``range(len(pool))`` without replacement.

    Each call to :meth:`sample` removes one randomly chosen index from the
    current pool; once the pool is exhausted it is refilled via :meth:`reset`.

    Args:
        pool: Sequence whose length sets the index range. It is only retained
            (so that ``sample`` returns ``pool[index]``) when ``keep_all`` is
            True. Defaults to an empty pool.
        keep_all: Retain ``pool`` so samples map to its values instead of
            raw indices.
        shuffle: Shuffle the refilled pool on every reset.
    """

    def __init__(
        self,
        pool=None,
        keep_all=False,
        shuffle=False,
    ):
        # Fresh list per instance rather than a shared mutable default.
        if pool is None:
            pool = []

        self.length = len(pool)
        self.shuffle = shuffle
        self.global_pool = pool if keep_all else None
        self.global_indices = list(range(self.length))

        self.reset()

    def reset(self):
        """Refill the current pool with every index, optionally shuffled."""
        self.current_pool = list(self.global_indices)
        if self.shuffle:
            random.shuffle(self.current_pool)

    def sample(self, pool=None):
        """Remove and return one random entry from the current pool.

        Args:
            pool: Optional sequence to map the drawn index through. Falls back
                to the pool retained at construction (``keep_all=True``).

        Returns:
            ``pool[index]`` when a pool is available, otherwise the raw index.

        Raises:
            IndexError: If the sampler was built from an empty pool.
        """
        if pool is None:
            pool = self.global_pool

        index = random.choice(self.current_pool)
        self.current_pool.remove(index)

        # Start a new pass as soon as this one is exhausted.
        if not self.current_pool:
            self.reset()

        return pool[index] if pool is not None else index

    def __len__(self):
        return self.length

    def __iter__(self):
        # Yield the rest of the current pass. A ``while len(self.current_pool)
        # > 0`` loop never terminates, because ``sample`` refills the pool the
        # moment it drains.
        for _ in range(len(self.current_pool)):
            yield self.sample()

    def __call__(self, *args, **kwargs):
        return self.sample(*args, **kwargs)

    def get_state(self):
        """Return a snapshot of the sampler state.

        ``current_pool`` is copied: ``sample`` mutates it in place, which would
        otherwise silently rewrite a snapshot the caller is still holding.
        """
        return {
            "length": self.length,
            "global_pool": self.global_pool,
            "global_indices": self.global_indices,
            "current_pool": list(self.current_pool),
        }

    def set_state(self, state):
        """Restore a state produced by :meth:`get_state`.

        ``current_pool`` is copied so later sampling does not mutate ``state``.
        """
        self.length = state["length"]
        self.global_pool = state["global_pool"]
        self.global_indices = state["global_indices"]
        self.current_pool = list(state["current_pool"])


# "Samples" through a fixed sequence from 0 to length - 1.
# Necessary for our "shuffle + sort by duration + interleave" sampling method.
# Supports saving and loading state.
class OrderedSampler(Sampler):
    """Yield ``0, 1, ..., length - 1`` in order, resumable via get/set_state."""

    def __init__(self, length):
        self.position = 0
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        # A finished pass starts over; a partially consumed pass (e.g. one
        # restored with set_state) resumes where it stopped.
        if self.position >= self.length:
            self.position = 0

        while self.position < self.length:
            yield self.position
            self.position += 1

    def get_state(self):
        return {
            "position": self.position,
            "length": self.length,
        }

    def set_state(self, state):
        self.position = state["position"]
        self.length = state["length"]


# Like OrderedSampler, but yields lists of indices, packing consecutive
# samples into batches capped by total duration and/or item count.
class BatchedOrderedSampler(Sampler):
    """Yield batches of sequential sample indices.

    Indices are assigned in iteration order over ``buckets`` (every sample of
    the first bucket, then the next, ...). They are presumably expected to
    match the order in which the dataset enumerates the same samples; verify
    this against the caller. Consecutive indices are packed into one batch
    until adding the next sample would push the summed duration past
    ``max_duration``, or the batch already holds ``max_batch_size`` items.
    A single sample longer than ``max_duration`` ends up alone in its batch.
    Batches may span bucket boundaries.

    Args:
        buckets: Mapping whose values are iterables of ``(path, duration)``
            pairs; keys and paths are not used here.
        max_duration: Cap on the summed duration per batch; ``<= 0`` disables it.
        max_batch_size: Cap on items per batch; ``<= 0`` disables it.
        shuffle: Shuffle the batch order at construction and at the start of
            every new pass.

    Raises:
        ValueError: If neither limit is enabled.
    """

    def __init__(
        self,
        buckets,
        max_duration=0,
        max_batch_size=0,
        shuffle=False,
    ):
        # At least one limit must be active; with neither, everything would
        # collapse into a single batch. (Previously both were required, which
        # contradicted this message and the single-limit logic below.)
        if max_duration <= 0 and max_batch_size <= 0:
            raise ValueError("max_duration and max_batch_size cannot both be 0")

        self.position = 0
        self.batches = []
        self.shuffle = shuffle

        current_batch = []
        current_size = 0
        current_index = 0

        for bucket in buckets.values():
            for _path, duration in bucket:
                over_duration = max_duration > 0 and current_size + duration > max_duration
                over_count = max_batch_size > 0 and len(current_batch) >= max_batch_size

                # Never emit an empty batch, even if one sample alone exceeds
                # max_duration.
                if (over_duration or over_count) and current_batch:
                    self.batches.append(current_batch)
                    current_batch = []
                    current_size = 0

                current_batch.append(current_index)
                current_index += 1
                current_size += duration

        # Flush the trailing partial batch; otherwise the final samples are
        # silently dropped.
        if current_batch:
            self.batches.append(current_batch)

        if self.shuffle:
            random.shuffle(self.batches)

    def __len__(self):
        return len(self.batches)

    def __iter__(self):
        # Only reshuffle once a pass has completed, so a pass restored via
        # set_state resumes with the same batch order.
        if self.position >= len(self.batches):
            self.position = 0
            if self.shuffle:
                random.shuffle(self.batches)

        while self.position < len(self.batches):
            yield self.batches[self.position]
            self.position += 1

    def get_state(self):
        # Copy the outer list: it is shuffled in place at the start of each
        # new pass, which would otherwise reorder a held snapshot.
        return {
            "position": self.position,
            "batches": list(self.batches),
        }

    def set_state(self, state):
        # Copy so later in-place shuffles do not rewrite the caller's state.
        self.position = state["position"]
        self.batches = list(state["batches"])


# Randomly samples indices from the sequence 0 .. length - 1, drawing a new
# permutation for each pass. Supports saving and loading state, including
# the generator's RNG state.
class RandomSampler(Sampler):
    """Yield ``0 .. length - 1`` in random order, resumable via get/set_state.

    Args:
        length: Number of indices to permute.
        seed: Optional seed for the private generator. When omitted, the
            generator keeps torch's default initial seed.
    """

    def __init__(self, length, seed=None):
        self.position = 0
        self.length = length
        # NOTE(review): an unseeded torch.Generator starts from a fixed default
        # seed, so without ``seed`` every instance produces the same sequence
        # of permutations. Confirm that this is intended.
        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)
        self.perm = torch.randperm(self.length, generator=self.generator)

    def __len__(self):
        return self.length

    def __iter__(self):
        # Only draw a fresh permutation once the previous pass has finished,
        # so a pass restored via set_state resumes mid-permutation.
        if self.position >= self.length:
            self.position = 0
            self.perm = torch.randperm(self.length, generator=self.generator)

        while self.position < self.length:
            # Plain Python ints (as OrderedSampler yields), not 0-d tensors.
            yield int(self.perm[self.position])
            self.position += 1

    def get_state(self):
        return {
            "position": self.position,
            "length": self.length,
            "perm": self.perm,
            "generator": self.generator.get_state(),
        }

    def set_state(self, state):
        self.position = state["position"]
        self.length = state["length"]
        self.perm = state["perm"]
        self.generator.set_state(state["generator"])