torchscale/examples/fairseq/tasks/data/basic_loader.py

79 lines
1.9 KiB
Python
Raw Normal View History

2022-11-23 16:36:55 +00:00
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
2022-11-23 16:21:58 +00:00
import torch
from infinibatch.iterators import CheckpointableIterator
2022-11-26 17:01:02 +00:00
2022-11-23 16:21:58 +00:00
from . import utils
2022-11-26 16:10:15 +00:00
2022-11-23 16:21:58 +00:00
class BaseBatchGen(CheckpointableIterator):
    """
    This is a base class for batch generators that use infinibatch.

    Subclasses implement ``_build_iter()``, which must create the underlying
    iterator and store it in ``self._iter``. Iteration, checkpointing
    (``getstate``/``setstate``) and ``close`` are all delegated to that
    iterator. Accessing the iterator before it has been built raises
    ``NotImplementedError``.
    """

    def __init__(self):
        # Underlying iterator; stays None until _build_iter() assigns it.
        self._iter = None
        # The attributes below are only initialised here and never read in
        # this class; presumably they are consumed by the training loop that
        # drives the generator -- TODO confirm against the caller.
        self.epoch = 1
        self.next_epoch_idx = 1
        self.sharded_checkpoint = True
        self.should_close_after_finished = True

    def _build_iter(self):
        """
        Build infinibatch iterator and assign to self._iter
        """
        raise NotImplementedError()

    def _checked_iter(self):
        """
        Return ``self._iter``.

        Raises:
            NotImplementedError: if ``_build_iter()`` has not assigned
                ``self._iter`` yet.
        """
        if self._iter is None:
            raise NotImplementedError("_build_iter() must be called first")
        return self._iter

    def _move_to_tensor(self, batch):
        """
        Convert the values in *batch* with ``torch.tensor``.

        The traversal is delegated to ``utils.apply_to_sample``; presumably
        it applies the conversion to every leaf of a nested batch structure
        -- verify against ``utils``.
        """

        def to_tensor(x):
            return torch.tensor(x)

        return utils.apply_to_sample(to_tensor, batch)

    @property
    def iterator(self):
        """The underlying iterator (raises if it has not been built)."""
        return self._checked_iter()

    def __iter__(self):
        """Return the underlying iterator (raises if it has not been built)."""
        return self._checked_iter()

    def __next__(self):
        """Return the next item produced by the underlying iterator."""
        return next(self._checked_iter())

    def setstate(self, value):
        """Restore the underlying iterator's checkpoint state from *value*."""
        self._checked_iter().setstate(value)

    def getstate(self):
        """Return the underlying iterator's checkpoint state."""
        return self._checked_iter().getstate()

    def close(self):
        """Close the underlying iterator; a no-op if it was never built."""
        # Nothing has been acquired before _build_iter() runs, so there is
        # nothing to release and closing must not fail.
        if self._iter is not None:
            self._iter.close()

    def __len__(self) -> int:
        """Return a fixed placeholder length (not derived from the data)."""
        return 819200000

    def next_epoch_itr(
        self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
    ):
        """
        Return this generator itself.

        All three arguments are accepted for interface compatibility and
        are ignored.
        """
        return self

    def end_of_epoch(self) -> bool:
        """Always report that the epoch has not ended."""
        return False

    def state_dict(self):
        """Returns a dictionary containing a whole state of the iterator."""
        return self.getstate()

    def load_state_dict(self, state_dict):
        """Copies the state of the iterator from the given *state_dict*."""
        self.setstate(state_dict)

    @property
    def first_batch(self):
        """Return the placeholder string ``"DUMMY"`` instead of a real batch."""
        return "DUMMY"