# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import glob
import os
import torch
import numpy as np
import time
import json
import random
import itertools
import copy

from infinibatch import iterators

from .basic_loader import BaseBatchGen
from .utils import NativeCheckpointableIterator, WeightIterator


class MLMLoader(BaseBatchGen):

    def __init__(
        self,
        args,
        dataset,
        dictionary,
        tokenizer,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
    ):
        super().__init__()
        self.args = args
        self.data = dataset.data
        self.data_dir = dataset.data_dir
        self.shuffle = dataset.shuffle
        self.dictionary = dictionary
        self.tokenizer = tokenizer

        self.max_tokens = max_tokens
        self.max_sentences = max_sentences
        self.max_positions = max_positions
        self.tokens_per_sample = args.tokens_per_sample
        self.sample_break_mode = args.sample_break_mode
        self.ignore_invalid_inputs = ignore_invalid_inputs
        self.required_batch_size_multiple = required_batch_size_multiple
        self.seed = str(seed)
        self.num_shards = num_shards
        self.shard_id = shard_id

        self.batch_read_ahead = args.batch_read_ahead

        self._build_iter()

    def _build_iter(self):
        tokenized_lines = self._multilingual_tokenize()
        self.padded_batches = self._batchify(tokenized_lines)

        prefetch_batches = iterators.PrefetchIterator(
            self.padded_batches,
            buffer_size=10000,
            buffer_in_main_process=True,
            log_empty_buffer_warning=True and self.shard_id == 0,
        )

        prefetch_batches = iterators.MapIterator(
            prefetch_batches, self._move_to_tensor
        )

        self._iter = prefetch_batches

    def _multilingual_tokenize(self):
        multilingual_iters = []
        weights = []

        for data in self.data:
            multilingual_iters.append(
                self._tokenize(data)
            )
            if 'weight' in data:
                weights.append(float(data['weight']))
            else:
                weights.append(int(data['count']))

        if len(multilingual_iters) == 1:
            return multilingual_iters[0]

        sampling_iterator = WeightIterator(weights)
        control_iterator = NativeCheckpointableIterator(sampling_iterator)
        tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters)

        return tokenized_lines

    def _tokenize(self, data):
        '''
        data:
        {
            'source': list[Path],
            'source_lang': str,
            'count': int,
            'weight': float,
            'name': str,
        }
        '''
        dataset = list(
            zip(
                data['source'],
                itertools.repeat(data['source_lang']),
            )
        )

        if self.shuffle:
            chunk_files = \
                iterators.InfinitePermutationSourceIterator(
                    dataset,
                    seed=self.seed,
                    shuffle=self.shuffle,
                    num_instances=self.num_shards,
                    instance_rank=self.shard_id,
                )
        else:
            chunk_files = \
                iterators.ChunkedSourceIterator(
                    dataset,
                    num_instances=self.num_shards,
                    instance_rank=self.shard_id,
                )

        tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files))
        tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed)

        return tokenized_lines

    def _batchify(self, lines):

        if self.max_sentences is not None:
            if self.batch_read_ahead > 0:
                lines = iterators.BlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed)
            batches = iterators.FixedBatchIterator(lines, self.max_sentences)
        else:
            def dynamic_batch_size(sample):
                lengths = [len(x) for x in sample]
                batch_size = (
                    self.max_tokens // max(lengths)
                    // self.required_batch_size_multiple
                    * self.required_batch_size_multiple
                )
                return max(1, batch_size)

            batches = iterators.BucketedReadaheadBatchIterator(
                lines,
                read_ahead=self.batch_read_ahead,
                key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
                batch_size=dynamic_batch_size,
                shuffle=self.shuffle,
                seed=self.seed,
            )

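        # Batching note (illustrative numbers): in the dynamic branch above,
        # max_tokens=4096, a longest sample of 510 tokens, and
        # required_batch_size_multiple=8 give 4096 // 510 // 8 * 8 = 8
        # sentences per batch. collate() below pads every field to the longest
        # entry in the batch with dictionary.pad(); only the MLM fields are
        # placed into ret_batch, while the s2s tensors are built but not
        # returned.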
        def collate(batch):
            batch_size = len(batch)
            mlm_source_max_length = max([len(x[0]) for x in batch])
            mlm_target_max_length = max([len(x[1]) for x in batch])
            s2s_source_max_length = max([len(x[2]) for x in batch])
            s2s_target_max_length = max([len(x[3]) for x in batch])

            mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length - 1), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length - 1), dtype=np.int32,
                                         fill_value=self.dictionary.pad())

            for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch):
                mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids
                mlm_target_ids[i, :len(mlm_label_ids)] = mlm_label_ids
                s2s_source_ids[i, :len(s2s_input_ids)] = s2s_input_ids
                # shift the s2s labels to build teacher-forcing inputs and targets
                s2s_target_ids[i, :len(s2s_label_ids) - 1] = s2s_label_ids[1:]
                s2s_prev_input_ids[i, :len(s2s_label_ids) - 1] = s2s_label_ids[:-1]

            ret_batch = {
                'net_input': {
                    'src_tokens': mlm_source_ids.astype(np.int64),
                },
                'target': mlm_target_ids.astype(np.int64),
                'nsentences': batch_size,
                'ntokens': sum([len(x[0]) for x in batch]),
            }

            return ret_batch

        padded_batches = iterators.MapIterator(
            batches, collate
        )

        return padded_batches

    def _prepare(self, _random, doc):
        # tuple order matches what collate() unpacks:
        # (mlm source, mlm target, s2s source, s2s target)
        nonmasked_tokens, masked_tokens = self._mask_lm(_random, doc)
        nonnoise_spans, noise_spans = self._span_corruption(_random, doc)
        return nonmasked_tokens, masked_tokens, nonnoise_spans, noise_spans

    def _mask_lm(self, _random, doc):
        def mask_tokens():
            return f"<mask>"

        length = len(doc)
        mask_tokens_num = int(length * self.args.mask_prob)
        mask_tokens_num = min(max(mask_tokens_num, 1), length - 1)
        possible_mask_positions = _random.sample(range(length), k=mask_tokens_num)
        possible_mask_positions = sorted(possible_mask_positions)

        nonmasked_tokens = copy.deepcopy(doc)
        masked_tokens = [self.dictionary.pad() for _ in range(len(doc))]

        for position in possible_mask_positions:
            # masked_tokens.append(nonmasked_tokens[position])
            masked_tokens[position] = nonmasked_tokens[position]
            nonmasked_tokens[position] = self.dictionary.indices[mask_tokens()]

        return nonmasked_tokens, masked_tokens

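    # _span_corruption below implements a T5-style denoising objective:
    # roughly mask_prob * len(doc) tokens are removed in noise spans of about
    # span_length tokens each; every removed span is replaced by a sentinel
    # token in the source, and the target lists each sentinel followed by the
    # tokens it replaced (the remove_source_sentinel / remove_target_sentinel
    # flags drop the sentinels on either side). Illustrative example:
    # doc = [a, b, c, d, e, f, g, h] with noise spans [c, d] and [g] becomes
    #     source: [a, b, <mask_0>, e, f, <mask_1>, h]
    #     target: [<mask_0>, c, d, <mask_1>, g]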
    def _span_corruption(self, _random, doc):

        def mask_tokens(i):
            return f"<mask_{i}>"

        length = len(doc)
        noise_tokens_num = int(length * self.args.mask_prob)
        noise_tokens_num = min(max(noise_tokens_num, 1), length - 1)
        noise_spans_num = int(noise_tokens_num / self.args.span_length)
        noise_spans_num = max(noise_spans_num, 1)
        nonnoise_tokens_num = length - noise_tokens_num

        if noise_spans_num == 1:
            noise_split_positions = [0, noise_tokens_num]
        else:
            possible_split_positions = list(range(1, noise_tokens_num))
            _random.shuffle(possible_split_positions)
            noise_split_positions = sorted(possible_split_positions[:noise_spans_num - 1])
            noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]

        possible_insert_positions = list(range(nonnoise_tokens_num))
        _random.shuffle(possible_insert_positions)
        noise_insert_positions = sorted(possible_insert_positions[:noise_spans_num])

        nonnoise_spans, noise_spans = [], []
        last_end = 0
        for i in range(noise_spans_num):
            start_pos = noise_insert_positions[i] + noise_split_positions[i]
            end_pos = noise_insert_positions[i] + noise_split_positions[i + 1]
            mask_id = self.dictionary.indices[mask_tokens(i)]

            if getattr(self.args, "remove_target_sentinel", False):
                noise_spans.append(doc[start_pos:end_pos])
            else:
                noise_spans.append([mask_id] + doc[start_pos:end_pos])

            if getattr(self.args, "remove_source_sentinel", False):
                nonnoise_spans.extend(doc[last_end:start_pos])
            else:
                nonnoise_spans.extend(doc[last_end:start_pos] + [mask_id])

            last_end = end_pos

        nonnoise_spans.extend(doc[last_end:])
        noise_spans = sum(noise_spans, [])

        return nonnoise_spans, noise_spans

    def _read_from_files(self, source_file, source_lang):
        # data = []
        file_path = os.path.join(self.data_dir, source_file)

        if not os.path.exists(file_path):
            print('| file {} does not exist'.format(file_path), flush=True)
            return iter([])  # skip missing file

        with open(file_path, 'r', encoding='utf8') as f:
            lines = f.read().strip().split('\n')

        doc = [self.dictionary.bos()]
        for line in lines:
            if line == "":
                if self.sample_break_mode == 'complete_doc':
                    # data.append(doc)
                    yield doc
                    doc = [self.dictionary.bos()]
                continue

            tokenized_line = self.tokenizer.EncodeAsPieces(line)
            tokenized_id = [self.dictionary.index(token) for token in tokenized_line] + [self.dictionary.eos_index]

            if len(tokenized_id) > self.tokens_per_sample:
                continue

            if len(doc) + len(tokenized_id) > self.tokens_per_sample:
                # data.append(doc)
                yield doc
                doc = [self.dictionary.bos()]
            doc.extend(tokenized_id)

        if len(doc) > 1 and len(doc) <= self.tokens_per_sample:
            # data.append(doc)
            yield doc

        # return data
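

if __name__ == "__main__":
    # Minimal smoke test for the span-corruption arithmetic only: a sketch
    # with illustrative mask_prob/span_length values and made-up sentinel ids
    # (1000 + i stands in for real dictionary indices). This is not how the
    # loader is used in training, where the task constructs it with a real
    # dictionary and tokenizer. Because of the relative imports above, run
    # this as a module from the package root rather than as a plain script.
    import types

    loader = MLMLoader.__new__(MLMLoader)  # bypass __init__; only _span_corruption is exercised
    loader.args = types.SimpleNamespace(mask_prob=0.15, span_length=3)
    loader.dictionary = types.SimpleNamespace(
        indices={f"<mask_{i}>": 1000 + i for i in range(8)}
    )

    toy_doc = list(range(1, 41))  # a 40-token toy document
    source, target = loader._span_corruption(random.Random(0), toy_doc)
    print("corrupted source:", source)
    print("denoising target:", target)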