import sys
from os import path

# Put the directory two levels above this file's directory (presumably the
# repository root) on sys.path BEFORE any project-local import below, so that
# `preprocess` and `utils` both resolve regardless of the working directory.
sys.path.append(
    path.dirname(path.abspath(path.dirname(path.abspath(path.dirname(__file__)))))
)

import os.path as op
import random, json
import torch
import _pickle as cPickle

from preprocess.lmdb_readers import LmdbReader
from utils.misc import (
        load_from_yaml_file, find_file_path_in_yaml)

def build_dataset(yaml_file, tokenizer, args):
    """Create the LMDB-backed VQA dataset for one split.

    Args:
        yaml_file: path to the split's YAML config.
        tokenizer: text tokenizer handed to the dataset.
        args: option namespace; the tensorizer/dataset options below are read
            from it by name, and the namespace itself is forwarded as ``args``.

    Returns:
        A ``CaptionLmdbDataset`` instance.
    """
    # Options copied verbatim from ``args`` onto the constructor's keywords,
    # read in the same order as the constructor's parameter list.
    forwarded_option_names = (
        "add_od_labels",
        "max_img_seq_length",
        "max_seq_length",
        "max_seq_a_length",
        "MLM_train",
        "mask_prob",
        "max_masked_tokens",
    )
    options = {name: getattr(args, name) for name in forwarded_option_names}
    return CaptionLmdbDataset(
        yaml_file=yaml_file,
        tokenizer=tokenizer,
        args=args,
        **options
    )

class CaptionLmdbDataset(LmdbReader):
    """VQA dataset whose image features are read through ``LmdbReader``.

    Questions (and, when the YAML config has an ``answer`` entry, answer
    annotations) come from JSON files referenced by the YAML config. Each item
    is tensorized by ``CaptionTensorizer`` and returned as
    ``(question_id, example)``.
    """

    def __init__(
            self,
            yaml_file,
            tokenizer=None,
            add_od_labels=True,
            max_img_seq_length=50,
            max_seq_length=70,
            max_seq_a_length=40,
            MLM_train=False,
            mask_prob=0.15,
            max_masked_tokens=3,
            args=None,
            **kwargs
    ):
        """Constructor.
        Args:
            yaml_file: YAML config with all required data paths (feature,
                question and, for train/val, answer).
            tokenizer: tokenizer for text processing.
            add_od_labels: whether to feed object tags to BERT as text_b.
            max_img_seq_length: max image sequence length.
            max_seq_length: max text sequence length.
            max_seq_a_length: max question sequence length.
            MLM_train: train MLM or QA.
            mask_prob: probability to mask an input token.
            max_masked_tokens: maximum number of tokens to be masked in one sentence.
            args: option namespace; ``data_dir``, ``label2ans_file``,
                ``ans2label_file``, ``add_prefix``, ``num_prefix`` and
                ``tag_entire_set`` are read here, and it is forwarded to the
                tensorizer.
            kwargs: other arguments (stored, not used here).
        """
        self.yaml_file = yaml_file
        self.cfg = load_from_yaml_file(yaml_file)
        self.root = args.data_dir
        self.feat_file = find_file_path_in_yaml(self.cfg['feature'], self.root)
        self.question_file = find_file_path_in_yaml(self.cfg['question'], self.root)

        assert op.isfile(self.feat_file)
        assert op.isfile(self.question_file)

        super(CaptionLmdbDataset, self).__init__(lmdb_path=self.feat_file)

        # test-dev and test splits have no answer file; self.answer is only
        # defined (and only read in __getitem__) when cfg['answer'] is set.
        if self.cfg['answer'] is not None:
            self.answer_file = find_file_path_in_yaml(self.cfg['answer'], self.root)
            assert op.isfile(self.answer_file)

            self.answer = []
            if self.answer_file and op.isfile(self.answer_file):
                with open(self.answer_file, 'r') as f:
                    self.answer = json.load(f)["annotations"]

        self.questions = []
        if self.question_file and op.isfile(self.question_file):
            with open(self.question_file, 'r') as f:
                self.questions = json.load(f)["questions"]

        self.tokenizer = tokenizer
        self.tensorizer = CaptionTensorizer(self.tokenizer, max_img_seq_length,
                max_seq_length, max_seq_a_length, mask_prob, max_masked_tokens,
                MLM_train=MLM_train, args=args)
        self.add_od_labels = add_od_labels
        self.MLM_train = MLM_train
        self.kwargs = kwargs

        # NOTE: pickle can execute arbitrary code; these files must be trusted.
        with open(args.label2ans_file, 'rb') as f:
            self.label2ans = cPickle.load(f)
        with open(args.ans2label_file, 'rb') as f:
            self.ans2label = cPickle.load(f)

        # Tag source selection: prefix tokens, the image's own tags, or tags
        # sampled from a global tag set (only when prefixes are disabled).
        self.prefix_token = args.add_prefix
        self.num_prefix = args.num_prefix
        self.tag_from_set = False
        if self.prefix_token is False and args.tag_entire_set is not None:
            self.tag_from_set = True
            self.tag_num = max_seq_length - max_seq_a_length
            with open(args.tag_entire_set) as f:
                self.tag_set = json.load(f)

    def get_tag_set(self):
        """Return ``tag_num`` tags sampled (with replacement) from the tag set.

        Returns:
            The sampled tags joined by single spaces, or ``None`` when
            ``add_od_labels`` is falsy.
        """
        if not self.add_od_labels:
            return None
        sampled = [random.choice(self.tag_set) for _ in range(self.tag_num)]
        return ' '.join(map(str, sampled))

    def get_label(self, answer_list: list):
        """Compute soft VQA scores and a dense label vector for one question.

        Args:
            answer_list: annotator answer dicts, each with an ``"answer"`` key.

        Returns:
            ``(scores, label)``. ``scores`` maps every distinct answer string
            to ``min(count / 3, 1.0)``. ``label`` is a list of length
            ``len(self.ans2label)`` holding that score at each answer's index;
            answers missing from ``ans2label`` are left out of ``label``.
        """
        answers = [ans["answer"] for ans in answer_list]
        scores = {}
        label = [0] * len(self.ans2label)

        for ans in set(answers):
            # Based on the VQA dataset provider's evaluation method.
            # Accumulate (previously the dict was rebound on every iteration,
            # so only the last answer's score survived).
            scores[ans] = min(answers.count(ans) / 3, 1.0)
            if ans in self.ans2label:
                label[self.ans2label[ans]] = scores[ans]

        return scores, label

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        """Return ``(question_id, example)`` for the idx-th question.

        ``example`` is the tensorizer's output dict; when the config has an
        answer file it also carries ``labels``, a float tensor of soft scores.
        """
        question = self.questions[idx]

        img_id = question["image_id"]
        # assumes get_image_feature returns a dict with "feature" and "class"
        # entries (as used below) — defined in LmdbReader, not visible here.
        feature = self.get_image_feature(f"{img_id}")

        if self.tag_from_set is False:
            if self.prefix_token is True:
                # Inter Modality Token.
                # NOTE(review): copies are concatenated without separators;
                # confirm the tokenizer splits "[prefix][prefix]..." as intended.
                od_labels = "[prefix]" * self.num_prefix
            elif self.add_od_labels is True:
                od_labels = ' '.join(feature["class"])  # Tag Original
            else:
                # Truthy placeholder: the tensorizer still pads the question
                # to max_seq_a_length even though no tag tokens result.
                od_labels = ' '
        else:
            od_labels = self.get_tag_set()  # Tag from Train/Val Tag Sets

        example = self.tensorizer.tensorize_example(question["question"],
                                                    torch.tensor(feature["feature"]),
                                                    text_b=od_labels)

        # For train and valid sets
        if self.cfg['answer'] is not None:
            answers = self.answer[idx]

            # check that the answer record belongs to this question
            assert question["question_id"] == answers["question_id"]

            _, label = self.get_label(answers["answers"])
            example.update({"labels": torch.tensor(label, dtype=torch.float)})

        return question["question_id"], example

class CaptionTensorizer(object):
    """Turn (question text, optional tag text, region features) into model inputs.

    Produces token ids, token type ids, a padded image-feature matrix and a
    square ``(max_seq_len + max_img_seq_len)`` attention mask, plus masked-LM
    targets when ``MLM_train`` is enabled.
    """

    def __init__(
            self,
            tokenizer,
            max_img_seq_length=50,
            max_seq_length=70,
            max_seq_a_length=40,
            mask_prob=0.15,
            max_masked_tokens=3,
            MLM_train=False,
            args=None
    ):
        """Constructor.
        Args:
            tokenizer: tokenizer for text processing.
            max_img_seq_length: max image sequence length.
            max_seq_length: max text sequence length (question + tags).
            max_seq_a_length: max question (text_a) sequence length.
            mask_prob: probability to mask an input token.
            max_masked_tokens: maximum number of tokens to be masked in one sentence.
            MLM_train: train MLM or QA.
            args: namespace with the attention switches ``mask_c_t``,
                ``mask_c_i``, ``mask_t_t``, ``mask_t_i``, ``mask_i_t``,
                ``un_mask_t_c``, ``un_mask_i_c`` and ``tag_sep_token``;
                ``mask_i_i`` and ``mask_inter_prefix`` are optional
                (default False).
        """
        self.tokenizer = tokenizer
        self.MLM_train = MLM_train
        self.max_img_seq_len = max_img_seq_length
        self.max_seq_len = max_seq_length
        self.max_seq_a_len = max_seq_a_length
        self.mask_prob = mask_prob
        self.max_masked_tokens = max_masked_tokens
        # Causal mask kept as an attribute; tensorize_example does not use it
        # (question tokens get full self-attention there).
        self._triangle_mask = torch.tril(torch.ones((self.max_seq_len,
            self.max_seq_len), dtype=torch.long))

        self.mask_c_t = args.mask_c_t
        self.mask_c_i = args.mask_c_i
        self.mask_t_t = args.mask_t_t
        self.mask_t_i = args.mask_t_i
        self.mask_i_t = args.mask_i_t
        # Was mistakenly copied from args.mask_i_t. Default to False when the
        # option is absent, matching the mask_inter_prefix convention.
        self.mask_i_i = getattr(args, 'mask_i_i', False)

        self.mask_inter_prefix = getattr(args, 'mask_inter_prefix', False)

        self.un_mask_t_c = args.un_mask_t_c
        self.un_mask_i_c = args.un_mask_i_c

        self.tag_add_sep = args.tag_sep_token

    def tensorize_example(
            self,
            text_a,
            img_feat,
            text_b=None,
            cls_token_segment_id=0,
            pad_token_segment_id=0,
            sequence_a_segment_id=0,
            sequence_b_segment_id=1
    ):
        """Tensorize one example.

        Args:
            text_a: question text.
            img_feat: 2-D tensor of region features (rows are regions); rows
                beyond ``max_img_seq_len`` are dropped, missing rows zero-padded.
            text_b: optional tag/prefix text. When truthy, text_a is padded to
                ``max_seq_a_len`` and text_b tokens follow it.
            cls_token_segment_id, pad_token_segment_id,
            sequence_a_segment_id, sequence_b_segment_id: token type ids for
                [CLS], padding, text_a and text_b tokens.

        Returns:
            dict with ``input_ids``, ``attention_mask``, ``token_type_ids``,
            ``img_feats``, ``masked_pos`` and, when ``MLM_train``, ``masked_ids``.
        """
        tokens_a = self.tokenizer.tokenize(text_a)
        # Reserve two slots for [CLS] and [SEP].
        if len(tokens_a) > self.max_seq_a_len - 2:
            tokens_a = tokens_a[:(self.max_seq_a_len - 2)]

        tokens = [self.tokenizer.cls_token] + tokens_a + [self.tokenizer.sep_token]
        seq_a_len = len(tokens)
        segment_ids = [cls_token_segment_id] + [sequence_a_segment_id] * (seq_a_len - 1)

        if text_b:  # image object tags / prefix tokens
            # Pad text_a to a fixed length so tags always start at max_seq_a_len.
            padding_a_len = self.max_seq_a_len - seq_a_len
            tokens += [self.tokenizer.pad_token] * padding_a_len
            segment_ids += [pad_token_segment_id] * padding_a_len

            tokens_b = self.tokenizer.tokenize(text_b)
            if len(tokens_b) > self.max_seq_len - len(tokens):
                tokens_b = tokens_b[: (self.max_seq_len - len(tokens))]

            if self.tag_add_sep is True:
                # The last tag token is replaced by [SEP].
                tag_tokens = tokens_b[:-1] + [self.tokenizer.sep_token]
            else:
                tag_tokens = tokens_b
            tokens += tag_tokens
            # Size by what was actually appended: with no tag tokens and
            # tag_add_sep, the lone [SEP] still needs a segment id, otherwise
            # token_type_ids ends up shorter than input_ids.
            segment_ids += [sequence_b_segment_id] * len(tag_tokens)  # 0: caption, 1: tag

        masked_pos = torch.zeros(self.max_seq_len, dtype=torch.int)
        if self.MLM_train:
            # Randomly pick positions to predict; only text_a, never [CLS].
            candidate_masked_idx = list(range(1, seq_a_len))
            random.shuffle(candidate_masked_idx)
            num_masked = int(min(max(round(self.mask_prob * seq_a_len), 1),
                                 self.max_masked_tokens))

            masked_idx = sorted(candidate_masked_idx[:num_masked])
            masked_token = [tokens[i] for i in masked_idx]

            real_mask_pos = []
            for i, pos in enumerate(masked_idx):
                if random.random() <= 0.8:
                    # 80% chance to be a [MASK] token
                    tokens[pos] = self.tokenizer.mask_token
                    real_mask_pos.append(pos)
                elif random.random() <= 0.5:
                    # 10% chance to be a random word ((1-0.8)*0.5).
                    # Ids below 999 are unused or special tokens. randint is
                    # inclusive, so the upper bound is the last valid id.
                    rand_id = random.randint(999, len(self.tokenizer.vocab) - 1)
                    tokens[pos] = self.tokenizer._convert_id_to_token(rand_id)
                # else: 10% chance to remain unchanged.

                # un_mask_t_c: stop at the first real [MASK], replace the
                # question tokens after it (including its [SEP]) with padding,
                # and shorten seq_a_len accordingly.
                if self.un_mask_t_c and len(real_mask_pos) > 0:
                    num_masked = i + 1
                    masked_idx = masked_idx[:num_masked]
                    masked_token = masked_token[:num_masked]
                    num_paded = seq_a_len - (real_mask_pos[0] + 1)
                    tokens[real_mask_pos[0] + 1:seq_a_len] = [self.tokenizer.pad_token] * num_paded
                    seq_a_len = real_mask_pos[0] + 1
                    break

            masked_pos[masked_idx] = 1
            # Pad masked tokens to a fixed length.
            if num_masked < self.max_masked_tokens:
                masked_token = masked_token + ([self.tokenizer.pad_token] *
                        (self.max_masked_tokens - num_masked))
            masked_ids = self.tokenizer.convert_tokens_to_ids(masked_token)

        # Pad text on the right up to max_seq_len.
        seq_len = len(tokens)
        padding_len = self.max_seq_len - seq_len
        tokens = tokens + ([self.tokenizer.pad_token] * padding_len)
        segment_ids += ([pad_token_segment_id] * padding_len)
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # Image features: truncate or zero-pad to max_img_seq_len rows.
        img_len = img_feat.shape[0]
        if img_len > self.max_img_seq_len:
            img_feat = img_feat[0:self.max_img_seq_len, ]
            img_len = img_feat.shape[0]
        else:
            padding_matrix = torch.zeros((self.max_img_seq_len - img_len,
                                          img_feat.shape[1]))
            img_feat = torch.cat((img_feat, padding_matrix), 0)

        attention_mask = self._build_attention_mask(seq_a_len, seq_len, img_len)

        input_ids = torch.tensor(input_ids, dtype=torch.long)
        segment_ids = torch.tensor(segment_ids, dtype=torch.long)

        example = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': segment_ids,
            'img_feats': img_feat,
            'masked_pos': masked_pos,
        }
        if self.MLM_train:
            if self.un_mask_t_c:
                assert str(tokens).count('[MASK]') in [0, 1]
            example['masked_ids'] = torch.tensor(masked_ids, dtype=torch.long)
        return example

    def _build_attention_mask(self, seq_a_len, seq_len, img_len):
        """Build the square long attention mask over text + image positions.

        Following the original "X to Y" comments, entry [row, col] = 1 lets
        the row position attend to the column position. Valid spans are:
        question [0, seq_a_len), tags [max_seq_a_len, seq_len), and image
        [max_seq_len, max_seq_len + img_len). Padding stays 0 unless a
        deletion/eye switch below rewrites it.
        """
        max_len = self.max_seq_len + self.max_img_seq_len
        attention_mask = torch.zeros((max_len, max_len), dtype=torch.long)
        # C: caption/question, L: label/tag, R: image region
        c_start, c_end = 0, seq_a_len
        l_start, l_end = self.max_seq_a_len, seq_len
        r_start, r_end = self.max_seq_len, self.max_seq_len + img_len

        # Full attention among all valid C, L and R positions (empty spans,
        # e.g. no tags, produce empty slices).
        spans = ((c_start, c_end), (l_start, l_end), (r_start, r_end))
        for q_start, q_end in spans:
            for k_start, k_end in spans:
                attention_mask[q_start:q_end, k_start:k_end] = 1

        # Activation switches. They are currently redundant because the
        # regions are already 1 above; kept so the flags keep their meaning.
        if self.un_mask_t_c is True:  # tag -> caption
            attention_mask[l_start:l_end, c_start:c_end] = 1
        if self.un_mask_i_c is True:  # image -> caption
            attention_mask[r_start:r_end, c_start:c_end] = 1

        # Deletion switches; slices span whole segments, padding included.
        if self.mask_c_t is True:  # caption -> tag
            attention_mask[:l_start, l_start:r_start] = 0
        if self.mask_c_i is True:  # caption -> image
            attention_mask[:l_start, r_start:] = 0
        if self.mask_t_t is True:  # tag -> tag
            attention_mask[l_start:r_start, l_start:r_start] = 0
        if self.mask_t_i is True:  # tag -> image
            attention_mask[l_start:r_start, r_start:] = 0
        if self.mask_i_t is True:  # image -> tag
            attention_mask[r_start:, l_start:r_start] = 0
        if self.mask_i_i is True:  # image -> image
            attention_mask[r_start:, r_start:] = 0

        # Delete inter-prefix attention: each tag position attends only to itself.
        if self.mask_inter_prefix:
            attention_mask[l_start:r_start, l_start:r_start] = torch.eye(r_start - l_start)

        return attention_mask