# torchscale/examples/fairseq/tasks/data/mlm_loader.py
import glob
import os
import torch
import numpy as np
import time
import json
import random
import itertools
import copy

from infinibatch import iterators

from .basic_loader import BaseBatchGen
from .utils import NativeCheckpointableIterator, WeightIterator
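
# MLMLoader builds an infinibatch-based data pipeline for masked language model
# pretraining. Each source document yields a 4-tuple of (masked-LM input,
# masked-LM labels, span-corruption source, span-corruption target); batches are
# padded with the dictionary's pad id and returned as fairseq-style dicts with
# 'net_input', 'target', 'nsentences' and 'ntokens'.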
class MLMLoader(BaseBatchGen):
    def __init__(
        self,
        args,
        dataset,
        dictionary,
        tokenizer,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
    ):
        super().__init__()
        self.args = args
        self.data = dataset.data
        self.data_dir = dataset.data_dir
        self.shuffle = dataset.shuffle
        self.dictionary = dictionary
        self.tokenizer = tokenizer

        self.max_tokens = max_tokens
        self.max_sentences = max_sentences
        self.max_positions = max_positions
        self.tokens_per_sample = args.tokens_per_sample
        self.sample_break_mode = args.sample_break_mode
        self.ignore_invalid_inputs = ignore_invalid_inputs
        self.required_batch_size_multiple = required_batch_size_multiple
        self.seed = str(seed)
        self.num_shards = num_shards
        self.shard_id = shard_id

        self.batch_read_ahead = args.batch_read_ahead

        self._build_iter()
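
    # _build_iter chains the infinibatch stages: tokenized documents from
    # _multilingual_tokenize are grouped into padded batches, prefetched in a
    # background buffer (empty-buffer warnings are logged only on shard 0), and
    # finally mapped onto torch tensors via the inherited _move_to_tensor helper.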
    def _build_iter(self):
        tokenized_lines = self._multilingual_tokenize()
        self.padded_batches = self._batchify(tokenized_lines)

        prefetch_batches = iterators.PrefetchIterator(
            self.padded_batches,
            buffer_size=10000,
            buffer_in_main_process=True,
            log_empty_buffer_warning=True and self.shard_id == 0,
        )

        prefetch_batches = iterators.MapIterator(
            prefetch_batches, self._move_to_tensor
        )

        self._iter = prefetch_batches
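
    # _multilingual_tokenize builds one tokenized-document iterator per corpus
    # in self.data and, when there is more than one, multiplexes between them
    # with sampling weights taken from each corpus entry ('weight' if present,
    # otherwise its raw 'count').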
    def _multilingual_tokenize(self):
        multilingual_iters = []
        weights = []

        for data in self.data:
            multilingual_iters.append(
                self._tokenize(data)
            )
            if 'weight' in data:
                weights.append(float(data['weight']))
            else:
                weights.append(int(data['count']))

        if len(multilingual_iters) == 1:
            return multilingual_iters[0]

        sampling_iterator = WeightIterator(weights)
        control_iterator = NativeCheckpointableIterator(sampling_iterator)
        tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters)

        return tokenized_lines
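
    # _tokenize turns one corpus description into an iterator of prepared
    # documents: shard-aware iteration over its source files (infinitely
    # permuted when shuffling), reading and packing lines with _read_from_files,
    # then applying the MLM / span-corruption preparation in _prepare with the
    # iterator's own checkpointable random state.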
    def _tokenize(self, data):
        '''
        data:
        {
            'source': list[Path],
            'source_lang': str,
            'count': int,
            'weight': float,
            'name': str,
        }
        '''
        dataset = list(
            zip(
                data['source'],
                itertools.repeat(data['source_lang']),
            )
        )

        if self.shuffle:
            chunk_files = \
                iterators.InfinitePermutationSourceIterator(
                    dataset,
                    seed=self.seed,
                    shuffle=self.shuffle,
                    num_instances=self.num_shards,
                    instance_rank=self.shard_id,
                )
        else:
            chunk_files = \
                iterators.ChunkedSourceIterator(
                    dataset,
                    num_instances=self.num_shards,
                    instance_rank=self.shard_id,
                )

        tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files))
        tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed)

        return tokenized_lines
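
    # _batchify groups prepared examples into batches. With a fixed
    # max_sentences it optionally shuffles blockwise and emits fixed-size
    # batches; otherwise it uses a read-ahead bucketing iterator whose batch
    # size is derived from max_tokens, the longest sequence in the bucket, and
    # required_batch_size_multiple. collate then pads every batch to its
    # per-field maximum length with the dictionary's pad id.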
    def _batchify(self, lines):

        if self.max_sentences is not None:
            if self.batch_read_ahead > 0:
                lines = iterators.BlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed)
            batches = iterators.FixedBatchIterator(lines, self.max_sentences)
        else:

            def dynamic_batch_size(sample):
                lengths = [len(x) for x in sample]
                batch_size = (
                    self.max_tokens
                    // max(lengths)
                    // self.required_batch_size_multiple
                    * self.required_batch_size_multiple
                )
                return max(1, batch_size)

            batches = iterators.BucketedReadaheadBatchIterator(
                lines,
                read_ahead=self.batch_read_ahead,
                key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
                batch_size=dynamic_batch_size,
                shuffle=self.shuffle,
                seed=self.seed,
            )

        def collate(batch):
            batch_size = len(batch)
            mlm_source_max_length = max([len(x[0]) for x in batch])
            mlm_target_max_length = max([len(x[1]) for x in batch])
            s2s_source_max_length = max([len(x[2]) for x in batch])
            s2s_target_max_length = max([len(x[3]) for x in batch])

            mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length - 1), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length - 1), dtype=np.int32,
                                         fill_value=self.dictionary.pad())

            for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch):
                mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids
                mlm_target_ids[i, :len(mlm_label_ids)] = mlm_label_ids
                s2s_source_ids[i, :len(s2s_input_ids)] = s2s_input_ids
                # Teacher-forcing shift: targets drop the first token, the
                # decoder's previous-input drops the last.
                s2s_target_ids[i, :len(s2s_label_ids) - 1] = s2s_label_ids[1:]
                s2s_prev_input_ids[i, :len(s2s_label_ids) - 1] = s2s_label_ids[:-1]

            # Note: only the MLM fields are returned below; the padded s2s
            # arrays are built but not included in the batch dict.
            ret_batch = {
                'net_input': {
                    'src_tokens': mlm_source_ids.astype(np.int64),
                },
                'target': mlm_target_ids.astype(np.int64),
                'nsentences': batch_size,
                'ntokens': sum([len(x[0]) for x in batch]),
            }

            return ret_batch

        padded_batches = iterators.MapIterator(
            batches, collate
        )

        return padded_batches
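
    # _prepare is applied per document and returns a 4-tuple: the masked-LM
    # input and labels plus the span-corruption source and target built from
    # the same document.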
    def _prepare(self, _random, doc):
        nonmasked_tokens, masked_tokens = self._mask_lm(_random, doc)
        nonnoise_spans, noise_spans = self._span_corruption(_random, doc)
        return nonmasked_tokens, masked_tokens, nonnoise_spans, noise_spans
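
    # _mask_lm implements BERT-style token masking: roughly mask_prob of the
    # positions (at least 1, at most length - 1) are replaced by the <mask>
    # symbol in the input, while the label sequence carries the original token
    # at masked positions and the pad id everywhere else.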
    def _mask_lm(self, _random, doc):
        def mask_tokens():
            return "<mask>"

        length = len(doc)
        mask_tokens_num = int(length * self.args.mask_prob)
        mask_tokens_num = min(max(mask_tokens_num, 1), length - 1)
        possible_mask_positions = _random.sample(range(length), k=mask_tokens_num)
        possible_mask_positions = sorted(possible_mask_positions)

        nonmasked_tokens = copy.deepcopy(doc)
        masked_tokens = [self.dictionary.pad() for _ in range(len(doc))]

        for position in possible_mask_positions:
            masked_tokens[position] = nonmasked_tokens[position]
            nonmasked_tokens[position] = self.dictionary.indices[mask_tokens()]

        return nonmasked_tokens, masked_tokens
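
    # _span_corruption implements T5-style span corruption: roughly mask_prob
    # of the tokens are removed in noise_spans_num contiguous spans of average
    # length span_length. In the source, each removed span is replaced by a
    # numbered sentinel <mask_i>; the target lists each sentinel followed by
    # the tokens it replaced. The remove_source_sentinel / remove_target_sentinel
    # flags drop the sentinels on either side when set.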
    def _span_corruption(self, _random, doc):

        def mask_tokens(i):
            return f"<mask_{i}>"

        length = len(doc)
        noise_tokens_num = int(length * self.args.mask_prob)
        noise_tokens_num = min(max(noise_tokens_num, 1), length - 1)
        noise_spans_num = int(noise_tokens_num / self.args.span_length)
        noise_spans_num = max(noise_spans_num, 1)
        nonnoise_tokens_num = length - noise_tokens_num

        if noise_spans_num == 1:
            noise_split_positions = [0, noise_tokens_num]
        else:
            possible_split_positions = list(range(1, noise_tokens_num))
            _random.shuffle(possible_split_positions)
            noise_split_positions = sorted(possible_split_positions[:noise_spans_num - 1])
            noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]

        possible_insert_positions = list(range(nonnoise_tokens_num))
        _random.shuffle(possible_insert_positions)
        noise_insert_positions = sorted(possible_insert_positions[:noise_spans_num])

        nonnoise_spans, noise_spans = [], []
        last_end = 0
        for i in range(noise_spans_num):
            start_pos = noise_insert_positions[i] + noise_split_positions[i]
            end_pos = noise_insert_positions[i] + noise_split_positions[i + 1]
            mask_id = self.dictionary.indices[mask_tokens(i)]

            if getattr(self.args, "remove_target_sentinel", False):
                noise_spans.append(doc[start_pos:end_pos])
            else:
                noise_spans.append([mask_id] + doc[start_pos:end_pos])

            if getattr(self.args, "remove_source_sentinel", False):
                nonnoise_spans.extend(doc[last_end:start_pos])
            else:
                nonnoise_spans.extend(doc[last_end:start_pos] + [mask_id])

            last_end = end_pos

        nonnoise_spans.extend(doc[last_end:])
        noise_spans = sum(noise_spans, [])

        return nonnoise_spans, noise_spans
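
    # _read_from_files streams one source file and packs its tokenized lines
    # (tokenizer.EncodeAsPieces, i.e. SentencePiece pieces mapped through the
    # dictionary) into documents of at most tokens_per_sample ids, each starting
    # with <bos> and with <eos> appended to every line. Empty lines end a
    # document when sample_break_mode is 'complete_doc'; overlong lines are
    # skipped; missing files are logged and skipped.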
    def _read_from_files(self, source_file, source_lang):
        file_path = os.path.join(self.data_dir, source_file)

        if not os.path.exists(file_path):
            print('| file {} does not exist'.format(file_path), flush=True)
            return iter([])  # skip bad file

        with open(file_path, 'r', encoding='utf8') as f:
            lines = f.read().strip().split('\n')

        doc = [self.dictionary.bos()]
        for line in lines:
            if line == "":
                if self.sample_break_mode == 'complete_doc':
                    yield doc
                    doc = [self.dictionary.bos()]
                continue

            tokenized_line = self.tokenizer.EncodeAsPieces(line)
            tokenized_id = [self.dictionary.index(token) for token in tokenized_line] + [self.dictionary.eos_index]

            if len(tokenized_id) > self.tokens_per_sample:
                continue
            if len(doc) + len(tokenized_id) > self.tokens_per_sample:
                yield doc
                doc = [self.dictionary.bos()]
            doc.extend(tokenized_id)

        if len(doc) > 1 and len(doc) <= self.tokens_per_sample:
            yield doc
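

# Minimal usage sketch (hypothetical values; in practice the fairseq pretraining
# task constructs the loader from its parsed args, dataset description, fairseq
# Dictionary, and SentencePiece tokenizer, and BaseBatchGen is assumed to expose
# the built iterator for consumption):
#
#     loader = MLMLoader(
#         args,             # needs tokens_per_sample, sample_break_mode,
#                           # batch_read_ahead, mask_prob, span_length
#         dataset,          # exposes .data, .data_dir, .shuffle
#         dictionary,       # fairseq Dictionary containing <mask> / <mask_i> symbols
#         tokenizer,        # SentencePiece processor (EncodeAsPieces)
#         max_tokens=2048,
#         max_sentences=None,
#         num_shards=world_size,   # hypothetical distributed settings
#         shard_id=rank,
#     )
#     batch = next(iter(loader))   # dict with 'net_input', 'target', 'nsentences', 'ntokens'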