# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import copy
import itertools
import os

import numpy as np
from infinibatch import iterators

from .basic_loader import BaseBatchGen
from .utils import NativeCheckpointableIterator, WeightIterator


class MLMLoader(BaseBatchGen):
    def __init__(
        self,
        args,
        dataset,
        dictionary,
        tokenizer,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
    ):
        """Store the loader configuration and build the batch iterator.

        Args:
            args: option namespace; ``tokens_per_sample``,
                ``sample_break_mode`` and ``batch_read_ahead`` are read here,
                other fields are read later by the masking helpers.
            dataset: object exposing ``data``, ``data_dir`` and ``shuffle``.
            dictionary: vocabulary used to map tokens to ids.
            tokenizer: tokenizer used to split raw text lines.
            max_tokens: token budget per batch (used when ``max_sentences``
                is None).
            max_sentences: fixed number of samples per batch, if given.
            max_positions: stored as-is.
            ignore_invalid_inputs: stored as-is.
            required_batch_size_multiple: dynamic batch sizes are rounded
                down to a multiple of this value.
            seed: random seed; kept as a string.
            num_shards: total number of data-parallel readers.
            shard_id: index of this reader among ``num_shards``.
        """
        super().__init__()

        # Collaborators handed in by the caller.
        self.args = args
        self.dictionary = dictionary
        self.tokenizer = tokenizer

        # Corpus description comes from the dataset object.
        self.data = dataset.data
        self.data_dir = dataset.data_dir
        self.shuffle = dataset.shuffle

        # Batching limits.
        self.max_tokens = max_tokens
        self.max_sentences = max_sentences
        self.max_positions = max_positions
        self.ignore_invalid_inputs = ignore_invalid_inputs
        self.required_batch_size_multiple = required_batch_size_multiple

        # Settings read from the option namespace.
        self.tokens_per_sample = args.tokens_per_sample
        self.sample_break_mode = args.sample_break_mode
        self.batch_read_ahead = args.batch_read_ahead

        # Sharding and randomness; the seed is stored in string form.
        self.seed = str(seed)
        self.num_shards = num_shards
        self.shard_id = shard_id

        # Everything is configured: assemble the iterator pipeline.
        self._build_iter()

    def _build_iter(self):
        """Assemble the pipeline and store its final stage in ``self._iter``.

        Stages, in order: tokenized samples -> padded batches (also kept in
        ``self.padded_batches``) -> prefetch buffer -> ``self._move_to_tensor``
        applied to every batch.
        """
        self.padded_batches = self._batchify(self._multilingual_tokenize())

        # The empty-buffer warning is only enabled on shard 0.
        buffered_batches = iterators.PrefetchIterator(
            self.padded_batches,
            buffer_size=10000,
            buffer_in_main_process=True,
            log_empty_buffer_warning=self.shard_id == 0,
        )

        # NOTE(review): _move_to_tensor is not defined in this class;
        # presumably inherited from BaseBatchGen.
        self._iter = iterators.MapIterator(buffered_batches, self._move_to_tensor)

    def _multilingual_tokenize(self):
        """Build one sample stream per corpus and combine them.

        Each entry of ``self.data`` gets its own stream from ``_tokenize``.
        Its sampling weight is ``float(entry["weight"])`` when that key is
        present, otherwise ``int(entry["count"])``.

        Returns:
            The single stream itself when there is exactly one corpus;
            otherwise a ``MultiplexIterator`` over all streams driven by a
            ``WeightIterator`` built from the collected weights.
        """
        streams = []
        sampling_weights = []

        for entry in self.data:
            streams.append(self._tokenize(entry))
            has_explicit_weight = "weight" in entry
            sampling_weights.append(
                float(entry["weight"]) if has_explicit_weight else int(entry["count"])
            )

        # A lone corpus needs no multiplexing.
        if len(streams) == 1:
            return streams[0]

        selector = NativeCheckpointableIterator(WeightIterator(sampling_weights))
        return iterators.MultiplexIterator(selector, streams)

    def _tokenize(self, data):
        """Create the stream of prepared samples for one corpus.

        ``data`` has the layout::

            {
                'source': list[Path],
                'source_lang': str,
                'count': int,
                'weight': float,
                'name': str,
            }

        Only ``source`` and ``source_lang`` are read here. Every source file
        is paired with the language tag, the pairs are distributed over the
        shards, each pair is expanded through ``_read_from_files`` and every
        resulting document is passed through ``_prepare``.
        """
        sources = data["source"]
        lang = data["source_lang"]
        file_lang_pairs = [(path, lang) for path in sources]

        shard_kwargs = {
            "num_instances": self.num_shards,
            "instance_rank": self.shard_id,
        }
        if not self.shuffle:
            file_stream = iterators.ChunkedSourceIterator(
                file_lang_pairs, **shard_kwargs
            )
        else:
            file_stream = iterators.InfinitePermutationSourceIterator(
                file_lang_pairs,
                seed=self.seed,
                shuffle=self.shuffle,
                **shard_kwargs,
            )

        def expand(pair):
            # pair is (source_file, source_lang)
            return self._read_from_files(*pair)

        documents = iterators.SelectManyIterator(file_stream, expand)
        return iterators.SamplingRandomMapIterator(documents, self._prepare, self.seed)

    def _batchify(self, lines):
        """Group prepared samples into batches and pad them into arrays.

        Args:
            lines: iterator of samples as produced by ``_prepare``; only the
                first two fields of each sample (MLM input ids and MLM label
                ids) end up in the emitted batch.

        Returns:
            An iterator of dicts with keys ``net_input`` (holding
            ``src_tokens``), ``target``, ``nsentences`` and ``ntokens``.
            ``src_tokens`` and ``target`` are int64 arrays padded with
            ``self.dictionary.pad()``.

        Batching strategy:
            * ``max_sentences`` set: fixed-size batches, preceded by a
              blockwise shuffle when ``batch_read_ahead > 0``.
            * otherwise: bucketed batches whose size is derived from
              ``max_tokens``.
        """
        if self.max_sentences is not None:
            if self.batch_read_ahead > 0:
                lines = iterators.BlockwiseShuffleIterator(
                    lines, self.batch_read_ahead, self.seed
                )
            batches = iterators.FixedBatchIterator(lines, self.max_sentences)
        else:

            def dynamic_batch_size(sample):
                # Size the batch by the longest field of the sample, rounded
                # down to the required multiple, but never below one.
                # NOTE(review): presumably called with a single sample tuple
                # by BucketedReadaheadBatchIterator - confirm.
                longest = max(len(x) for x in sample)
                batch_size = self.max_tokens // longest
                batch_size = (
                    batch_size
                    // self.required_batch_size_multiple
                    * self.required_batch_size_multiple
                )
                return max(1, batch_size)

            batches = iterators.BucketedReadaheadBatchIterator(
                lines,
                read_ahead=self.batch_read_ahead,
                key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
                batch_size=dynamic_batch_size,
                shuffle=self.shuffle,
                seed=self.seed,
            )

        def pad_rows(rows):
            # Right-pad variable-length id lists into one int64 matrix.
            padded = np.full(
                shape=(len(rows), max(len(row) for row in rows)),
                dtype=np.int64,
                fill_value=self.dictionary.pad(),
            )
            for i, row in enumerate(rows):
                padded[i, : len(row)] = row
            return padded

        def collate(batch):
            # Only the MLM fields are part of the emitted batch, so the
            # span-corruption fields (x[2], x[3]) are deliberately not
            # materialised here.
            mlm_inputs = [x[0] for x in batch]
            mlm_labels = [x[1] for x in batch]

            return {
                "net_input": {
                    "src_tokens": pad_rows(mlm_inputs),
                },
                "target": pad_rows(mlm_labels),
                "nsentences": len(batch),
                "ntokens": sum(len(ids) for ids in mlm_inputs),
            }

        return iterators.MapIterator(batches, collate)

    def _prepare(self, _random, doc):
        nonmasked_tokens, masked_tokens = self._mask_lm(_random, doc)
        nonnoise_spans, noise_spans = self._span_corruption(_random, doc)
        return nonmasked_tokens, masked_tokens, nonnoise_spans, noise_spans

    def _mask_lm(self, _random, doc):
        def mask_tokens():
            return "<mask>"

        length = len(doc)
        mask_tokens_num = int(length * self.args.mask_prob)
        mask_tokens_num = min(max(mask_tokens_num, 1), length - 1)
        possible_mask_positions = _random.sample(range(length), k=mask_tokens_num)
        possible_mask_positions = sorted(possible_mask_positions)

        nonmasked_tokens = copy.deepcopy(doc)
        masked_tokens = [self.dictionary.pad() for _ in range(len(doc))]

        for position in possible_mask_positions:
            # masked_tokens.append(nonmasked_tokens[position])
            masked_tokens[position] = nonmasked_tokens[position]
            nonmasked_tokens[position] = self.dictionary.indices[mask_tokens()]

        return nonmasked_tokens, masked_tokens

    def _span_corruption(self, _random, doc):
        def mask_tokens(i):
            return f"<mask_{i}>"

        length = len(doc)
        noise_tokens_num = int(length * self.args.mask_prob)
        noise_tokens_num = min(max(noise_tokens_num, 1), length - 1)
        noise_spans_num = int(noise_tokens_num / self.args.span_length)
        noise_spans_num = max(noise_spans_num, 1)
        nonnoise_tokens_num = length - noise_tokens_num

        if noise_spans_num == 1:
            noise_split_positions = [0, noise_tokens_num]
        else:
            possible_split_positions = list(range(1, noise_tokens_num))
            _random.shuffle(possible_split_positions)
            noise_split_positions = sorted(
                possible_split_positions[: noise_spans_num - 1]
            )
            noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]

        possible_insert_positions = list(range(nonnoise_tokens_num))
        _random.shuffle(possible_insert_positions)
        noise_insert_positions = sorted(possible_insert_positions[:noise_spans_num])

        nonnoise_spans, noise_spans = [], []
        last_end = 0
        for i in range(noise_spans_num):
            start_pos = noise_insert_positions[i] + noise_split_positions[i]
            end_pos = noise_insert_positions[i] + noise_split_positions[i + 1]
            mask_id = self.dictionary.indices[mask_tokens(i)]

            if getattr(self.args, "remove_target_sentinel", False):
                noise_spans.append(doc[start_pos:end_pos])
            else:
                noise_spans.append([mask_id] + doc[start_pos:end_pos])

            if getattr(self.args, "remove_source_sentinel", False):
                nonnoise_spans.extend(doc[last_end:start_pos])
            else:
                nonnoise_spans.extend(doc[last_end:start_pos] + [mask_id])

            last_end = end_pos

        nonnoise_spans.extend(doc[last_end:])
        noise_spans = sum(noise_spans, [])

        return nonnoise_spans, noise_spans

    def _read_from_files(self, source_file, source_lang):
        # data = []
        file_path = os.path.join(self.data_dir, source_file)

        if not os.path.exists(file_path):
            print("| file {} not exists".format(file_path), flush=True)
            return iter([])  # skip bad file

        with open(file_path, "r", encoding="utf8") as f:
            lines = f.read().strip().split("\n")

        doc = [self.dictionary.bos()]
        for line in lines:
            if line == "":
                if self.sample_break_mode == "complete_doc":
                    # data.append(doc)
                    yield doc
                    doc = [self.dictionary.bos()]
                continue

            tokenized_line = self.tokenizer.EncodeAsPieces(line)
            tokenized_id = [
                self.dictionary.index(token) for token in tokenized_line
            ] + [self.dictionary.eos_index]

            if len(tokenized_id) > self.tokens_per_sample:
                continue
            if len(doc) + len(tokenized_id) > self.tokens_per_sample:
                # data.append(doc)
                yield doc
                doc = [self.dictionary.bos()]
            doc.extend(tokenized_id)

        if len(doc) > 1 and len(doc) <= self.tokens_per_sample:
            # data.append(doc)
            yield doc

        # return data