

import logging, os
from milie.util import write_json_to_file
from milie.dataset_handlers.dataset_bitext import BitextHandler, GenExample
from milie.util import sublist_start_index
import numpy as np
from milie.evals.carb_evaluator import Evaluate
import json, operator
# Logger for this module, named after the module itself.
LOGGER = logging.getLogger(name=__name__)
from tqdm import tqdm
from collections import defaultdict

class MilieHandler(BitextHandler):
    """
    Handles the MILIE OpenIE dataset (OPIEC-style open information extraction
    and relation extraction data).

    Targets are encoded per token with the ids of ``_text2id_tok``:
    2 = first token of a phrase (B), 1 = continuation (I), 0 = outside (O);
    -1 marks positions that are not labelled (e.g. padding, or a head that
    the example does not supervise).
    """
    # pylint: disable=too-many-instance-attributes

    # Special token ids assumed by the encoders/decoders below ([SEP] and
    # [PAD] in BERT-style vocabularies).
    # NOTE(review): confirm they match the tokenizer configured in BitextHandler.
    _SEP_ID = 102
    _PAD_ID = 0

    # Order of the token-classification heads returned by
    # get_token_classification_ids.
    _HEAD_ORDER = ('subject', 'predicate', 'object', 'arguments')

    # Detokenisation fixes applied, in this order, to every extracted phrase
    # (see _clean_phrase). The order matters: later rules see earlier output.
    _PHRASE_FIXUPS = (
        (' . ', '.'), (" ' ", " '"), (" In ", " in "), (". ", "."), (". ", ""),
        ('In ', 'in '), (' - ', '-'), (" .", "."), (" / ", "/"),
    )

    def __init__(self, milie_args):
        """
        Set up the label vocabularies, output writers and dependency labels.

        Attributes initialised here:

        - ``examples``: list of :py:class:MilieExample, filled by
          :py:meth:read_examples / :py:meth:process_inputs
        - ``features``: list of :py:class:GenInputFeature
        - ``_text2id``: one dict per [CLS]-level classification head, mapping
          label text to id (``"False"`` -> 0, ``"True"`` -> 1)
        - ``_text2id_tok``: one dict per token-classification head (subject,
          predicate, object, arguments), mapping BIO tags to ids
        - ``dep_labels``: dependency relation name -> id; not used by the
          methods of this class at the moment
        - ``output_dir``: where :py:meth:evaluate writes ``all_results.tsv``

        :param milie_args: run configuration; must provide ``output_dir`` in
            addition to whatever :py:class:BitextHandler reads
        """
        super().__init__(milie_args)
        self.examples = []
        self.features = []
        self._text2id = [{"False": 0, "True": 1}]
        self._text2id_tok = [{"B": 2, "I": 1, "O": 0},
                             {"B": 2, "I": 1, "O": 0},
                             {"B": 2, "I": 1, "O": 0},
                             {"B": 2, "I": 1, "O": 0}]
        self.write_predictions = write_json_to_file
        self.write_eval = write_json_to_file
        self.write_tok_predictions = write_json_to_file

        self.dep_labels = {'ROOT': 1, 'acl': 2, 'advcl': 3, 'advmod': 4, 'amod': 5, 'appos': 6, 'aux': 7,
                           'case': 8, 'cc': 9, 'ccomp': 10, 'compound': 11, 'conj': 12, 'csubj': 13,
                           'cop': 14, 'dep': 15, 'det': 16, 'expl': 17, 'fixed': 18, 'flat': 19,
                           'iobj': 20, 'mark': 21, 'nmod': 22, 'nsubj': 23, 'nummod': 24, 'obj': 25,
                           'obl': 26, 'parataxis': 27, 'punct': 28, 'xcomp': 29}
        self.output_dir = milie_args.output_dir

    def process_inputs(self, inputs):
        '''
        Converts raw prediction inputs to :py:class:MilieExample objects.

        :param inputs: iterable of ``(sentence, dep_tags)`` pairs
        :return: 0 on success
        '''
        self.examples = []
        for count, (sent, dep) in enumerate(inputs):
            # No gold data at prediction time: empty target list and empty
            # head, exactly like the prediction branch of read_examples.
            example = MilieHandler.MilieExample(count, sent, [], "", dep_tags=dep)
            self.examples.append(example)
        self.subj_tags, self.pred_tags, self.obj_tags = self._get_tags()
        return 0

    def read_examples(self, is_training=True):
        """
        Read a JSON OpenIE dataset into ``self.examples``.

        The file holds a JSON list of entries with keys ``sentence`` and
        ``dep`` and, for training, ``targets`` and ``head``. Prediction entries
        may also carry ``pred_elem`` / ``pred_map``. Entries without a sentence
        are skipped, as are training entries whose head is ``CLS`` or that
        have no targets.

        :param is_training: truthy to read ``self.train_file`` as training
            data, falsy to read ``self.predict_file`` as prediction data
        :return: 0 on success
        """
        # Use the same truthiness test for choosing the file and the parsing
        # mode so the two can never disagree.
        input_file = self.train_file if is_training else self.predict_file
        self.examples = []
        with open(input_file, encoding='utf8') as f:
            entries = json.load(f)
        for count, ex in enumerate(tqdm(entries, desc="Reading File")):
            if ex['sentence'] is None:
                continue
            if is_training:
                if ex['head'] == 'CLS' or len(ex['targets']) <= 0:
                    continue
                example = MilieHandler.MilieExample(count, ex['sentence'], ex['targets'], ex['head'],
                                                    dep_tags=ex['dep'])
            else:
                example = MilieHandler.MilieExample(count, ex['sentence'], [], "", dep_tags=ex['dep'],
                                                    pred_elem=ex.get('pred_elem', None),
                                                    pred_map=ex.get('pred_map', None))
            self.examples.append(example)

        self.subj_tags, self.pred_tags, self.obj_tags = self._get_tags()
        return 0

    def _get_start_ids(self, text, input_ids, tokenizer):
        """
        Find every non-overlapping occurrence of ``text`` in ``input_ids``.

        :param text: phrase to look for; it is tokenized with ``tokenizer``
        :param input_ids: token ids to search in
        :param tokenizer: object providing ``tokenize`` and
            ``convert_tokens_to_ids``
        :return: ``(tokens, start_ids)``: ``tokens`` holds the phrase's tokens
            once per match, ``start_ids`` holds one
            ``(absolute_start_index, phrase_length)`` pair per match
        """
        start_ids, tokens = [], []
        toks = tokenizer.tokenize(text)
        ids = tokenizer.convert_tokens_to_ids(toks)
        if not ids:
            # An empty phrase has nothing to match. Without this guard, an
            # empty pattern reported as found would never advance the search.
            return tokens, start_ids
        offset = 0
        while True:
            start_id = sublist_start_index(ids, input_ids)
            if start_id is None:
                break
            tokens.append(toks)
            start_ids.append((start_id + offset, len(ids)))
            # Resume the search right after this match; ``offset`` keeps the
            # reported indices relative to the original input_ids.
            consumed = start_id + len(ids)
            offset += consumed
            input_ids = input_ids[consumed:]
        return tokens, start_ids

    def _get_token_classify_ids(self, targets, input_ids, tokenizer):
        """
        Locate every target phrase in ``input_ids``.

        :return: one list of ``(start, length)`` matches per target, in the
            order of ``targets``
        """
        return [self._get_start_ids(target, input_ids, tokenizer)[1] for target in targets]

    def encode_words(self, example, input_ids):
        """
        Encode word-level BIO tags of ``example`` as token-level label ids.

        Each word of ``example.tokens`` is tokenized and located in
        ``input_ids``. All of its sub-tokens receive the id of the word's tag
        (looked up in ``_text2id_tok[0]``). A repeated word is searched for
        after its previous match, and words that cannot be found are skipped.
        Position 0 and everything from the first [SEP] on are set to -1.

        NOTE(review): ``example.tokens`` / ``example.tags`` are not set by
        :py:class:MilieExample; presumably they come from GenExample or a
        caller. Confirm.

        :return: int64 array with one label per entry of ``input_ids``
        """
        classify_id_tokens = np.zeros(len(input_ids), dtype='int64')
        seen_words = {}
        for word, tag in zip(example.tokens, example.tags):
            label = self._text2id_tok[0][tag]
            ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(word))
            start_id = sublist_start_index(ids, input_ids)
            if start_id is None:
                continue
            if word in seen_words:
                # Look for the next occurrence after the previous match.
                offset = seen_words[word] + 1
                start_id = sublist_start_index(ids, input_ids[offset:])
                if start_id is None:
                    continue
                start_id += offset
            assert input_ids[start_id] == ids[0]
            seen_words[word] = start_id
            classify_id_tokens[start_id: start_id + len(ids)] = label
        # Mask [SEP] and everything after it, plus the first position.
        if self._SEP_ID in input_ids:
            classify_id_tokens[input_ids.index(self._SEP_ID):] = -1
        classify_id_tokens[0] = -1
        return classify_id_tokens

    def encode_chars(self, example, input_ids):
        """
        Encode ``example.targets`` as token labels over ``input_ids``.

        For every occurrence of every target phrase, its first token gets 2
        and the remaining tokens get 1. Everything from the first padding id
        on is set to -1.

        :return: int64 array with one label per entry of ``input_ids``
        """
        classify_id_tokens = np.zeros(len(input_ids), dtype='int64')
        all_matches = self._get_token_classify_ids(example.targets, input_ids, self.tokenizer)
        for matches in all_matches:
            for start, length in matches:
                classify_id_tokens[start:start + length] = 1
                classify_id_tokens[start] = 2
        if self._PAD_ID in input_ids:
            classify_id_tokens[input_ids.index(self._PAD_ID):] = -1
        return classify_id_tokens

    def get_token_classification_ids(self, example, input_ids):
        """
        Build the token labels for all four token-classification heads.

        The head named by ``example.head`` receives the labels from
        :py:meth:encode_chars. Every other head receives an all ``-1`` array.
        For ``example.head == 'CLS'`` all four arrays are ``-1``.

        :return: list of four int64 arrays ordered subject, predicate, object,
            arguments
        :raises RuntimeError: if ``example.head`` is not a known head
        """
        classify_id_tokens = self.encode_chars(example, input_ids)
        heads = [np.full(len(input_ids), -1, dtype='int64') for _ in range(len(self._HEAD_ORDER))]
        if example.head == 'CLS':
            return heads
        for position, name in enumerate(self._HEAD_ORDER):
            if example.head == name:
                heads[position] = classify_id_tokens
                return heads
        raise RuntimeError(f"Unknown head: {example.head!r}")

    def get_segment_ids(self, example, input_ids):
        """
        Return the segment ids for ``input_ids``: all zeros.

        Dependency-label segment tagging is not enabled in this handler, so
        every position gets segment 0 and ``example`` is not used.
        """
        return np.zeros(len(input_ids), dtype='int64')

    def _get_tags(self):
        """Return the token ids of the subject, predicate and object markers
        ``<A0>``, ``<P>`` and ``<A1>``."""
        def to_ids(marker):
            return [self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.tokenize(marker)]
        return to_ids('<A0>'), to_ids('<P>'), to_ids('<A1>')

    def _ids_tokens(self, ids):
        """Convert token ids back to tokens."""
        return self.tokenizer.convert_ids_to_tokens(ids)

    def _sent_ids(self, sent):
        """Tokenize ``sent`` and return its token ids."""
        tokens = self.tokenizer.tokenize(sent)
        return self.tokenizer.convert_tokens_to_ids(tokens)

    @classmethod
    def _clean_phrase(cls, phrase):
        """Strip ``phrase`` and apply ``_PHRASE_FIXUPS`` in order."""
        phrase = phrase.strip()
        for old, new in cls._PHRASE_FIXUPS:
            phrase = phrase.replace(old, new)
        return phrase

    def token_to_words(self, input_ids, classification_tokens):
        """
        Turn per-token tags back into phrases.

        A phrase starts at a token tagged 2 and extends over the directly
        following tokens tagged 1. Tokens starting with '#' (WordPiece-style
        continuations) are glued to the preceding piece. Only the part before
        the first [SEP] is considered.

        :param input_ids: token ids of one sequence (a list, or an
            array/tensor providing ``tolist``)
        :param classification_tokens: one predicted tag per entry of ``input_ids``
        :return: ``(phrases, phrase_ids, tokens)``: the cleaned phrase strings,
            the token ids of each phrase, and the sequence's tokens without the
            first position
        """
        def extract_word(tokens, ids):
            # Collect the phrase starting at tokens[0]: the first token plus
            # every directly following token tagged 1.
            count = 0
            ext_words = []
            pred_ids = []
            while count < len(tokens):
                tok, tag = tokens[count]
                if tag != 1 and count > 0:
                    break
                if tok.startswith('#'):
                    piece = tok.replace('#', '')
                    if ext_words:
                        ext_words[-1].append(piece)
                    else:
                        ext_words.append([piece])
                else:
                    ext_words.append([tok])
                pred_ids.append(ids[count])
                count += 1
            return ' '.join(''.join(w) for w in ext_words), pred_ids, count

        if hasattr(input_ids, 'tolist'):
            # Arrays/tensors lack ``index``; plain ints also hash correctly
            # for the ids_to_tokens lookup below.
            input_ids = input_ids.tolist()
        if self._SEP_ID in input_ids:
            idx = input_ids.index(self._SEP_ID)
            input_ids = input_ids[:idx]
            classification_tokens = classification_tokens[:idx]
        tokens = [(self.tokenizer.ids_to_tokens[tok_id], tag)
                  for tok_id, tag in zip(input_ids, classification_tokens)]
        tagged_words, all_pred_ids, index = [], [], 0
        while index < len(tokens):
            if tokens[index][1] != 2:
                index += 1
                continue
            tagged_word, pred_ids, consumed = extract_word(tokens[index:], input_ids[index:])
            tagged_words.append(self._clean_phrase(tagged_word))
            all_pred_ids.append(pred_ids)
            # ``consumed`` is at least 1: the starting token is always taken.
            index += consumed
        n_tokens = [tok for tok, _ in tokens[1:]]
        return tagged_words, all_pred_ids, n_tokens

    def decode_words(self, current_example, classification_tokens, input_ids):
        """
        Decode token-level tags into word-level phrases of ``current_example``.

        The first position and everything from the first [SEP] on are
        dropped. Each word of ``current_example.tokens`` then gets one tag from
        its sub-tokens. Words with no sub-tokens left (e.g. truncated input) get
        0. Phrases are a word tagged 2 followed by a run of words tagged 1.

        :return: ``(phrases, phrase_ids)``: phrase strings and their token ids
        """
        def majority(arr):
            # Despite the name this is a priority vote: any 2 wins, then any
            # 1, otherwise 0. A single-element group keeps its own value.
            if len(arr) <= 0:
                return 0
            if len(arr) == 1:
                return arr[0]
            elems = set(arr)
            if 2 in elems:
                return 2
            if 1 in elems:
                return 1
            return 0

        if hasattr(input_ids, 'tolist'):
            input_ids = input_ids.tolist()
        if self._SEP_ID in input_ids:
            classification_tokens = classification_tokens[1:input_ids.index(self._SEP_ID)]
        else:
            classification_tokens = classification_tokens[1:]

        words = current_example.tokens
        assert len(words) == len(current_example.tags)
        # Number of sub-tokens each word is split into.
        boundaries = [len(self._sent_ids(word)) for word in words]

        # Exactly one tag per word, so the tags stay aligned with ``words``
        # even when the token sequence was truncated.
        predicted_tags = []
        start_id = 0
        for bound in boundaries:
            group = classification_tokens[start_id:start_id + bound]
            predicted_tags.append(int(majority(group)) if len(group) > 0 else 0)
            start_id += bound
        assert len(predicted_tags) == len(words)

        all_elems = []
        for index, tag in enumerate(predicted_tags):
            if tag != 2:
                continue
            elem = [words[index]]
            for follow in range(index + 1, len(predicted_tags)):
                if predicted_tags[follow] != 1:
                    break
                elem.append(words[follow])
            all_elems.append(' '.join(elem))
        pred_ids = [self._sent_ids(elem) for elem in all_elems]
        return all_elems, pred_ids

    def arrange_token_classify_output(self, current_example, classification_tokens, input_ids):
        """
        Convert token classification output into phrases via :py:meth:token_to_words.

        :param current_example: the current example (not used here)
        :param classification_tokens: the classification labels for all tokens
        :param input_ids: token ids; a length-1 container is unwrapped first
        :return: ``(phrases, phrase_ids, tokens)`` as returned by
            :py:meth:token_to_words
        """
        if len(input_ids) == 1:
            input_ids = input_ids[0]
        return self.token_to_words(input_ids, classification_tokens)

    def select_deciding_score(self, results):
        """
        Return the score used for model selection: ``results['f1']`` as a
        float, or 0.0 when it is missing or None.
        """
        f1 = results.get('f1')
        return float(f1) if f1 is not None else 0.0

    def evaluate(self, output_prediction_file, valid_gold, mode='token'):
        """
        Score predictions with the CaRB evaluator.

        If per-example results are returned, they are written one per line to
        ``<output_dir>/all_results.tsv``.

        :param output_prediction_file: file with the model's predictions
        :param valid_gold: passed unchanged to the evaluator (presumably the
            gold annotations)
        :param mode: unused; kept for interface compatibility
        :return: the evaluator's summary results
        """
        evaluator = Evaluate()
        results, all_results = evaluator.evaluate(output_prediction_file, valid_gold)
        if all_results is not None:
            out_path = os.path.join(self.output_dir, 'all_results.tsv')
            with open(out_path, 'w', encoding='utf8') as f:
                f.write('\n'.join(str(x) for x in all_results))
        return results

    class MilieExample(GenExample):
        """A single training/test example from the Milie dataset.

        ``part_a`` is the sentence and ``part_b`` is empty, so each example is
        a single-segment input.
        """

        def __init__(self, example_index, sentence, targets, head,
                     dep_tags=None, pred_elem=None, pred_map=None):
            """
            :param example_index: position of the entry in the input data
            :param sentence: the input sentence
            :param targets: target phrases ([] at prediction time)
            :param head: head the targets belong to ('subject', 'predicate',
                'object', 'arguments' or 'CLS'; "" at prediction time)
            :param dep_tags: dependency annotation of the sentence, as found in
                the ``dep`` field of the data
            :param pred_elem: optional data passed through from the input
            :param pred_map: optional data passed through from the input
            """
            super().__init__()
            self.example_index = example_index
            self.head = head
            self.targets = targets
            self.sentence = sentence
            self.part_a = sentence
            self.part_b = ""
            self.pred_elem = pred_elem
            self.pred_map = pred_map
            self.dep_tags = dep_tags

        def __str__(self):
            return self.__repr__()

        def __repr__(self):
            return "example_index: %s" % self.example_index
