torchscale/examples/fairseq/tasks/data/basic_loader.py
2022-11-23 08:36:55 -08:00

81 lines
2.0 KiB
Python

# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import math
import re
import sys
import time
import torch
from infinibatch.iterators import CheckpointableIterator
from . import utils
class BaseBatchGen(CheckpointableIterator):
"""
This is a base class for batch generators that use infinibatch
"""
def __init__(self):
self._iter = None
self.epoch = 1
self.next_epoch_idx = 1
self.sharded_checkpoint = True
self.should_close_after_finished = True
def _build_iter(self):
"""
Build infinibatch iterator and assign to self._iter
"""
raise NotImplementedError()
def _move_to_tensor(self, batch):
def to_tensor(x):
return torch.tensor(x)
return utils.apply_to_sample(to_tensor, batch)
@property
def iterator(self):
if self._iter is None:
raise NotImplementedError("_build_iter() must called first")
return self._iter
def __iter__(self):
if self._iter is None:
raise NotImplementedError("_build_iter() must called first")
return self._iter
def __next__(self):
return next(self._iter)
def setstate(self, value):
self._iter.setstate(value)
def getstate(self):
return self._iter.getstate()
def close(self):
self._iter.close()
def __len__(self) -> int:
return 819200000
def next_epoch_itr(
self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
):
return self
def end_of_epoch(self) -> bool:
return False
def state_dict(self):
"""Returns a dictionary containing a whole state of the iterator."""
return self.getstate()
def load_state_dict(self, state_dict):
"""Copies the state of the iterator from the given *state_dict*."""
self.setstate(state_dict)
@property
def first_batch(self):
return "DUMMY"