# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import csv
import json
import glob
import logging
import pandas as pd

import numpy as np
import torch
from torch.utils.data import Dataset

from ._image_features_reader import ImageFeaturesH5Reader

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


def assert_eq(real, expected):
    """Raise AssertionError unless ``real == expected``.

    The error message reports both values. Like a plain ``assert``
    statement, the check is skipped entirely when Python runs with
    optimizations enabled (``-O``).
    """
    if __debug__:
        if not real == expected:
            raise AssertionError(
                "%s (true) vs %s (expected)" % (real, expected)
            )


def _create_entry(question, answer):
    answer.pop("image_id")
    answer.pop("question_id")
    entry = {
        "question_id": question["question_id"],
        "image_id": question["image_id"],
        "question": question["question"],
        "answer": answer,
    }
    return entry


def _load_dataset(dataroot, name, label_type):
    """Load the entries of one split from its JSON-lines file.

    Args:
        dataroot: directory containing the ``evidX_intact_*.jsonl`` files.
        name: split to load: 'train', 'val' or 'test'.
        label_type: name of the column whose values become each entry's
            labels.

    Returns:
        A list with one dict per row of the file, with keys
        ``question_id`` (the DataFrame row index), ``image_id``,
        ``question`` (the row's ``caption``) and ``answer``
        (``{"labels": <value of the label_type column>}``).

    Raises:
        KeyError: if ``name`` is not a known split, or a required column
            is missing from the file.
        AssertionError: if the split file does not exist.
    """
    data_name_dict = {
        "train": "evidX_intact_train.jsonl",
        "val": "evidX_intact_dev.jsonl",
        "test": "evidX_intact_test.jsonl",
    }
    data_path = os.path.join(dataroot, data_name_dict[name])
    assert os.path.exists(data_path), "Dataset file not found: %s" % data_path

    df = pd.read_json(data_path, lines=True)
    logging.getLogger(__name__).debug(
        "Loaded %d rows from %s", len(df), data_path
    )

    ids = df.index.tolist()
    captions = df["caption"].tolist()
    labels = df[label_type].tolist()
    image_ids = df["image_id"].tolist()

    entries = []
    for row_id, caption, label, image_id in zip(ids, captions, labels, image_ids):
        question = {
            "question_id": row_id,
            "image_id": image_id,
            "question": caption,
        }
        answer = {
            "question_id": row_id,
            "image_id": image_id,
            "labels": label,
        }
        # Sanity-check the pairing before merging. Because both records are
        # built from the same values, this only trips for ids that do not
        # compare equal to themselves (e.g. NaN).
        assert_eq(question["question_id"], answer["question_id"])
        assert_eq(question["image_id"], answer["image_id"])
        entries.append(_create_entry(question, answer))

    return entries


class IntactClassificationDataset(Dataset):
    """Caption + image-region dataset for Intact classification.

    Each item pairs a tokenized caption with the region features of its
    image. Captions and labels come from the split's JSON-lines file (via
    ``_load_dataset``); region features are read from per-image ``*.npy``
    files, all of which are loaded into memory at construction time.
    """

    def __init__(
        self,
        task,
        dataroot,
        annotations_jsonpath,
        split,
        image_features_reader,
        gt_image_features_reader,
        tokenizer,
        bert_model,
        clean_datasets,
        padding_index=0,
        max_seq_length=16,
        max_region_num=101,
        label_type=None,
    ):
        """Load, tokenize and tensorize one split and cache its image features.

        Args:
            task: accepted for interface compatibility; not used here.
            dataroot: directory holding the split's ``.jsonl`` file.
            annotations_jsonpath: despite the name, used as the directory
                that contains the per-image ``*.npy`` feature files.
            split: 'train', 'val' or 'test'.
            image_features_reader: accepted but not used here.
            gt_image_features_reader: accepted but not used here.
            tokenizer: object providing ``encode`` and
                ``add_special_tokens_single_sentence``.
            bert_model: accepted but not used here.
            clean_datasets: accepted but not used here.
            padding_index: token id used to pad captions.
            max_seq_length: length every tokenized caption is padded or
                truncated to.
            max_region_num: number of region slots every item is padded or
                truncated to.
            label_type: one of 'p_meth_label', 'i_meth_label', 'p_meth',
                'i_meth'; selects the label column and ``num_labels``.

        Raises:
            AssertionError: if ``label_type`` is not one of the known types,
                or two feature files map to the same image id.
        """
        super().__init__()
        self.split = split
        image_features_path = annotations_jsonpath
        self.image_features_path = image_features_path
        self._max_region_num = max_region_num
        self._max_seq_length = max_seq_length
        self._tokenizer = tokenizer
        self._padding_index = padding_index
        self.label_type = label_type

        # Value stored as self.num_labels for each supported label type.
        label_dict = {
            "p_meth_label": 8,
            "i_meth_label": 18,
            "p_meth": 48,
            "i_meth": 122,
        }
        assert label_type in label_dict, (
            "Unsupported label_type %r; expected one of %s"
            % (label_type, sorted(label_dict))
        )
        self.num_labels = label_dict[label_type]

        print ("[INFO] Data Root: {}".format(dataroot))
        print ("[INFO] Image Features Path: {}".format(image_features_path))
        print ("[INFO] Label Type: {}".format(label_type))
        print ("[INFO] Split: {}".format(split))

        self.entries = _load_dataset(dataroot, split, label_type)

        self.tokenize(max_seq_length)
        self.tensorize()

        # Every feature file is loaded eagerly and keyed by image id.
        self.all_feats = {}
        for npy in glob.glob(os.path.join(image_features_path, "*.npy")):
            # NOTE(security): allow_pickle=True unpickles arbitrary objects;
            # only point this at feature files from a trusted source.
            # Each file stores a single pickled dict, hence .item().
            curr_data = np.load(npy, allow_pickle=True).item()
            # Image id = file name up to its first dot.
            image_id = os.path.basename(npy).split(".")[0]
            curr_data["image_id"] = image_id
            assert image_id not in self.all_feats
            self.all_feats[image_id] = curr_data

        print ("[INFO] Loading {} Intact Dataset... Done!".format(split))
        print ("[INFO] Number of {} Intact Data Instances: {}\n".format(split, len(self.entries)))

    def _image_features_reader(self, image_id):
        """Return the cached features and normalized box geometry of an image.

        Returns:
            ``(features, num_boxes, boxes, None)`` where ``boxes`` is a
            float32 array of shape ``(num_boxes, 5)``: the four box
            coordinates scaled by image width/height, followed by the box
            area as a fraction of the image area.

        Raises:
            KeyError: if no feature file was loaded for ``image_id``.
        """
        record = self.all_feats[image_id]
        features = record["features"]
        num_boxes = record["num_boxes"]
        # TODO: process the bboxes
        boxes = record["bbox"]
        image_w = float(record["image_width"])
        image_h = float(record["image_height"])

        image_location = np.zeros((num_boxes, 5), dtype=np.float32)
        image_location[:num_boxes, :4] = boxes
        # Columns 0/2 are scaled by width and 1/3 by height, so boxes are
        # presumably stored as (x1, y1, x2, y2) in pixels.
        image_location[:, 4] = (
            (boxes[:, 3] - boxes[:, 1])
            * (boxes[:, 2] - boxes[:, 0])
            / (image_w * image_h)
        )
        image_location[:, 0] = image_location[:, 0] / image_w
        image_location[:, 1] = image_location[:, 1] / image_h
        image_location[:, 2] = image_location[:, 2] / image_w
        image_location[:, 3] = image_location[:, 3] / image_h
        return features, num_boxes, image_location, None

    def tokenize(self, max_length=16):
        """Tokenize every entry's caption to exactly ``max_length`` ids.

        Adds ``q_token``, ``q_input_mask`` and ``q_segment_ids`` (plain
        lists) to each entry. Captions are truncated to ``max_length - 2``
        tokens before the tokenizer's special tokens are added, then padded
        at the end with ``padding_index``. Padded positions get mask 0 and
        segment id 0 regardless of the padding token id.
        """
        for entry in self.entries:
            tokens = self._tokenizer.encode(entry["question"])
            # Leave room for the special tokens added below.
            tokens = tokens[: max_length - 2]
            tokens = self._tokenizer.add_special_tokens_single_sentence(tokens)

            segment_ids = [0] * len(tokens)
            input_mask = [1] * len(tokens)

            pad_len = max_length - len(tokens)
            if pad_len > 0:
                # Padding is appended after the sentence.
                tokens = tokens + [self._padding_index] * pad_len
                # Mask and segment ids must be 0 on padding; reusing the
                # padding token id here would be wrong whenever it is not 0.
                input_mask += [0] * pad_len
                segment_ids += [0] * pad_len

            assert len(tokens) == max_length, "%s (true) vs %s (expected)" % (
                len(tokens),
                max_length,
            )
            entry["q_token"] = tokens
            entry["q_input_mask"] = input_mask
            entry["q_segment_ids"] = segment_ids

    def tensorize(self):
        """Convert each entry's token lists and labels to torch tensors in place."""
        for entry in self.entries:
            for key in ("q_token", "q_input_mask", "q_segment_ids"):
                entry[key] = torch.from_numpy(np.array(entry[key]))

            answer = entry["answer"]
            answer["labels"] = torch.from_numpy(np.array(answer["labels"]))

    def __getitem__(self, index):
        """Return the model inputs for one entry.

        Returns:
            A tuple ``(features, spatials, image_mask, question, target,
            input_mask, segment_ids, co_attention_mask, question_id)``.
            The image tensors are padded with zeros (or truncated) to
            ``max_region_num`` rows; ``image_mask`` is 1 for real regions
            and 0 for padding; ``co_attention_mask`` is all zeros with
            shape ``(max_region_num, max_seq_length)``.
        """
        entry = self.entries[index]
        image_id = entry["image_id"]
        question_id = entry["question_id"]
        features, num_boxes, boxes, _ = self._image_features_reader(image_id)

        mix_num_boxes = min(int(num_boxes), self._max_region_num)
        mix_boxes_pad = np.zeros((self._max_region_num, 5))
        # Region features are assumed to be 2048-dimensional.
        mix_features_pad = np.zeros((self._max_region_num, 2048))

        image_mask = [1] * mix_num_boxes + [0] * (
            self._max_region_num - mix_num_boxes
        )

        mix_boxes_pad[:mix_num_boxes] = boxes[:mix_num_boxes]
        mix_features_pad[:mix_num_boxes] = features[:mix_num_boxes]

        features = torch.tensor(mix_features_pad).float()
        image_mask = torch.tensor(image_mask).long()
        spatials = torch.tensor(mix_boxes_pad).float()

        question = entry["q_token"]
        input_mask = entry["q_input_mask"]
        segment_ids = entry["q_segment_ids"]

        co_attention_mask = torch.zeros((self._max_region_num, self._max_seq_length))

        target = entry["answer"]["labels"]

        return (
            features,
            spatials,
            image_mask,
            question,
            target,
            input_mask,
            segment_ids,
            co_attention_mask,
            # image_id,
            question_id,
        )

    def __len__(self):
        """Return the number of entries in this split."""
        return len(self.entries)
