# get belief and sentiment links from an annotation file
#    (and format them consistently)

# Authors: Vlad Niculae <vlad@vene.ro>, Rishi Bommasani <rishibommasani@gmail.com>
# License: BSD 3-clause

import os
import warnings
import json
from collections import ChainMap
from pathlib import Path

import numpy as np
from sklearn.preprocessing import LabelEncoder

from best.best_evaluator import (read_best_xml, read_ere_xml,
                            get_private_state_tuples)
from best.deft_best import Belief, Sentiment, BeStAnnotations


# Source id standing for "no explicit source entity": _src_match compares the
# requested source id against it when an opinion's source is None.
NO_SRC_ID = None

# NOTE(review): LabelEncoder presumably numbers classes in sorted label order,
# which would differ from the hand-written index maps in belief_encoder /
# sentiment_encoder below -- confirm which encoding callers rely on.
sentiment_lbl_enc = LabelEncoder().fit(['pos', 'neg', 'none'])
belief_lbl_enc = LabelEncoder().fit(['cb', 'ncb', 'rob', 'na'])

def belief_encoder(ys):
    """Convert belief type labels into integer class indices.

    The mapping is cb -> 0, ncb -> 1, rob -> 2, na -> 3.  Returns a new list
    with one index per element of ``ys``; an unknown label raises KeyError.
    """
    index_of = {'cb': 0, 'ncb': 1, 'rob': 2, 'na': 3}
    encoded = []
    for label in ys:
        encoded.append(index_of[label])
    return encoded

def belief_encoder_inverter(ys):
    """Convert integer class indices back into belief type labels.

    Inverse of the mapping cb -> 0, ncb -> 1, rob -> 2, na -> 3.  Returns a
    new list; an index outside 0..3 raises KeyError.
    """
    label_of = {0: 'cb', 1: 'ncb', 2: 'rob', 3: 'na'}
    return list(map(label_of.__getitem__, ys))

def sentiment_encoder(ys):
    """Convert sentiment polarity labels into integer class indices.

    The mapping is pos -> 0, neg -> 1, none -> 2.  Returns a new list with one
    index per element of ``ys``; an unknown label raises KeyError.
    """
    index_of = {'pos': 0, 'neg': 1, 'none': 2}
    encoded = []
    for label in ys:
        encoded.append(index_of[label])
    return encoded

def sentiment_encoder_inverter(ys):
    """Convert integer class indices back into sentiment polarity labels.

    Inverse of the mapping pos -> 0, neg -> 1, none -> 2.  Returns a new list;
    an index outside 0..2 raises KeyError.
    """
    label_of = {0: 'pos', 1: 'neg', 2: 'none'}
    return list(map(label_of.__getitem__, ys))


def _src_match(src_id, opinion):
    if opinion.source is None:
        return src_id == NO_SRC_ID
    else:
        return opinion.source.entity.entity_id == src_id


class _BaseBestFile(object):
    def __init__(self, doc_id, data_root, gold_ere=True):
        """Load everything known about one document.

        Parameters
        ----------
        doc_id : str
            Document identifier used to build all file names.
        data_root : str or path-like
            Root directory of the dataset; stored as a ``Path``.
        gold_ere : bool, default True
            Read the gold ERE file when true, the predicted one otherwise.

        Reads the source text, the ERE annotation, the BeSt annotation (when
        the subclass provides a file name for it) and, if present, the
        tokenization / dependency parse files.  Missing parse files only
        trigger a warning; ``self.tokenized`` / ``self.conll`` are then left
        unset.
        """
        self.doc_id = doc_id
        self.data_root = Path(data_root)
        self.gold_ere = gold_ere

        with open(self._build_source_fname(), encoding='utf8') as f:
            self.source = f.read()

        # load entity, relation, event data
        self.evaluator_ere = read_ere_xml(self._build_ere_fname())

        # single lookup over all target mention kinds (entity mentions win on
        # an id clash, then relation mentions, then event mentions)
        self.mentions = ChainMap(self.evaluator_ere.entity_mentions,
                                 self.evaluator_ere.relation_mentions,
                                 self.evaluator_ere.event_mentions)

        # preconstruct list of (src_entity, trg_mention) pairs
        (self.pairs_entity,
         self.pairs_relation,
         self.pairs_event) = self._get_pairs()

        ann_fname = self._build_annotation_fname()
        if ann_fname:
            # load belief and sentiment annotation
            self.evaluator_best = read_best_xml(self.evaluator_ere, ann_fname)
        else:
            # Subclasses without annotation return a falsy file name; keep the
            # attribute defined so later reads do not raise AttributeError.
            self.evaluator_best = None

        # cached attributes
        self._token_begin_ix = None
        self._token_end_ix = None
        self._evaluator_ere = None
        self._evaluator_best = None

        # read tokenization etc
        tok_json_fname, parsed_fname = self._build_source_ann_fnames()
        try:
            with open(tok_json_fname, encoding='utf8') as f:
                self.tokenized = json.load(f)

            with open(parsed_fname, encoding='utf8') as f:
                # sentences are separated by blank lines; the chunk after the
                # last separator is discarded
                sentences = f.read().split('\n\n')[:-1]

            # one list per sentence, one list of tab-separated attributes per
            # token
            self.conll = [[tok.split('\t') for tok in sent.splitlines()]
                          for sent in sentences]

        except FileNotFoundError:
            warnings.warn(
                'Parsed files not found.  Searching in e.g. {} & {}'.format(
                    tok_json_fname, parsed_fname))

    # base functions for constructing file paths from a root
    # implemented this way because the conventions change in the data

    def _build_source_fname(self, dir="source"):
        """Return the path of the document's source text file.

        Abstract here: always raises NotImplementedError.  ``dir`` names the
        directory to look in; _build_source_ann_fnames passes a different
        value to locate the parsed files.
        """
        raise NotImplementedError

    def _build_ere_fname(self):
        """Return the ERE file path selected by ``self.gold_ere``.

        A truthy ``gold_ere`` selects the gold file, anything else the
        predicted one.
        """
        builder = (self._build_gold_ere_fname if self.gold_ere
                   else self._build_predicted_ere_fname)
        return builder()

    def _build_gold_ere_fname(self):
        """Return ``<data_root>/ere/<doc_id>.rich_ere.xml`` as a string."""
        fname = '{}.rich_ere.xml'.format(self.doc_id)
        return os.path.join(self.data_root, 'ere', fname)

    def _build_predicted_ere_fname(self):
        """Return ``<data_root>/predicted_ere/<doc_id>.predicted.rich_ere.xml``."""
        fname = '{}.predicted.rich_ere.xml'.format(self.doc_id)
        return os.path.join(self.data_root, 'predicted_ere', fname)

    def _build_annotation_fname(self):
        """Return ``<data_root>/annotation/<doc_id>.best.xml`` as a string.

        __init__ skips loading the BeSt annotation when this returns a falsy
        value.
        """
        fname = '{}.best.xml'.format(self.doc_id)
        return os.path.join(self.data_root, 'annotation', fname)

    def _build_source_ann_fnames(self):
        """Return the (tokenization JSON, CoNLL parse) file names.

        Both are derived from the source file name built for the
        ``source_ann/parsed`` directory, with ``.json`` and
        ``.conll8.parsed`` appended respectively.
        """
        parsed_dir = os.path.join("source_ann", "parsed")
        stem = self._build_source_fname(dir=parsed_dir)
        tok_json_fname = stem + ".json"
        parsed_fname = stem + ".conll8.parsed"
        return tok_json_fname, parsed_fname

    # If word isn't in conll, left and right will be equal
    def offset_to_flat_tokens(self, offset, length, include_boundaries=False):
        """Map a character span onto a range of flat token indices.

        Returns ``(first, last)`` where ``first`` is the number of tokens whose
        end offset is <= ``offset`` and ``last`` is the number of tokens whose
        begin offset is < ``offset + length``.  The two are presumably equal
        when no token overlaps the span (assumes token offsets are sorted).

        ``include_boundaries=True`` is not supported and raises
        NotImplementedError.
        """
        if include_boundaries:
            raise NotImplementedError

        span_end = offset + length
        first = np.searchsorted(self.token_end_ix, offset, side='right')
        last = np.searchsorted(self.token_begin_ix, span_end)
        return first, last

    def relation_mention_to_span(self, relation_mention):
        """Return ``(offset, length)`` of the text covered by a relation mention.

        The span stretches from the smallest start to the largest end among
        both arguments (the filler entity for filler arguments, otherwise the
        entity mention when there is one) and the trigger, if any.  Raises
        ValueError when none of them contributes a span.
        """
        boundaries = []

        def add_span(anchor):
            boundaries.append(anchor.offset)
            boundaries.append(anchor.offset + anchor.length)

        for arg in (relation_mention.rel_arg1, relation_mention.rel_arg2):
            if arg.is_filler:
                add_span(arg.entity)
            elif arg.entity_mention:
                add_span(arg.entity_mention)

        if relation_mention.trigger:
            add_span(relation_mention.trigger)

        offset = min(boundaries)
        # NOTE(review): the collected ends are already offset + length, so the
        # extra 1 presumably makes the span one character longer than the
        # covered text -- confirm whether that is intended.
        return offset, max(boundaries) - offset + 1

    def event_mention_to_span(self, event_mention):
        """Return ``(offset, length)`` of the text covered by an event mention.

        The span stretches from the smallest start to the largest end among
        all arguments (the filler entity for filler arguments, otherwise the
        entity mention when there is one) and the trigger, if any.  Raises
        ValueError when none of them contributes a span.
        """
        starts, ends = [], []

        def add_span(anchor):
            starts.append(anchor.offset)
            ends.append(anchor.offset + anchor.length)

        arguments = event_mention.arguments
        for mention_id in arguments:
            arg = arguments[mention_id]
            if arg.is_filler:
                add_span(arg.entity)
            elif arg.entity_mention:
                add_span(arg.entity_mention)

        if event_mention.trigger:
            add_span(event_mention.trigger)

        offset = min(starts + ends)
        end = max(starts + ends)
        # NOTE(review): `end` is already exclusive (offset + length), so the
        # extra 1 presumably over-extends the span by one character --
        # confirm whether that is intended.
        return offset, end - offset + 1

    def get_post_id(self, offset):
        """Return the id of the post containing character ``offset``.

        This base implementation ignores ``offset`` and always returns an
        empty string.
        """
        return ""

    def get_author(self, offset):
        """Return the author of the text at character ``offset``.

        This base implementation ignores ``offset`` and always returns an
        empty string.
        """
        return ""

    def path_to_root(self, sent_id, tok_id, guard=0):
        """Return the dependency path from the sentence root down to a token.

        ``self.conll[sent_id]`` holds one attribute list per token; column 6
        is read as the 1-based index of the head, with 0 marking the root.
        The result lists token attribute rows, root first and the token at
        ``tok_id`` last.

        ``guard`` counts the steps already taken.  Once it exceeds 100 the
        walk stops, "max depth" is printed and only the tokens visited so far
        are returned, which keeps cyclic head links from looping forever.
        """
        visited = []
        current = tok_id
        depth = guard

        while True:
            if depth > 100:
                print("max depth")
                break

            tok = self.conll[sent_id][current]
            visited.append(tok)

            head = int(tok[6])
            if head == 0:
                break

            # heads are 1-based, token positions 0-based
            current = head - 1
            depth += 1

        # collected token-to-root; callers expect root-to-token
        visited.reverse()
        return visited

    @property
    def token_begin_ix(self):
        """Flat list of ``characterOffsetBegin`` values of all tokens.

        Built on first access from ``self.tokenized['sentences']`` in
        sentence, then token, order and cached in ``self._token_begin_ix``.
        """
        cached = self._token_begin_ix
        if cached is None:
            cached = []
            for sent in self.tokenized['sentences']:
                for tok in sent['tokens']:
                    cached.append(tok['characterOffsetBegin'])
            self._token_begin_ix = cached
        return cached

    @property
    def token_end_ix(self):
        """Flat list of ``characterOffsetEnd`` values of all tokens.

        Built on first access from ``self.tokenized['sentences']`` in
        sentence, then token, order and cached in ``self._token_end_ix``.
        """
        cached = self._token_end_ix
        if cached is None:
            cached = []
            for sent in self.tokenized['sentences']:
                for tok in sent['tokens']:
                    cached.append(tok['characterOffsetEnd'])
            self._token_end_ix = cached
        return cached

    # Methods useful for training models

    def _get_pairs(self):
        """Enumerate every candidate (source entity id, target mention id) pair.

        Returns three lists, for entity, relation and event mention targets
        respectively.  Targets are visited in sorted id order; for each target
        the sources are ``NO_SRC_ID`` followed by all entity ids in sorted
        order.
        """
        ere = self.evaluator_ere

        # The candidate sources are the same for every target, so sort the
        # entity ids once instead of once per target mention.
        sources = [NO_SRC_ID] + sorted(ere.entities)

        def pairs_toward(mentions):
            return [(src_entity_id, trg_mention_id)
                    for trg_mention_id in sorted(mentions)
                    for src_entity_id in sources]

        toward_entity = pairs_toward(ere.entity_mentions)
        toward_relation = pairs_toward(ere.relation_mentions)
        toward_event = pairs_toward(ere.event_mentions)

        return toward_entity, toward_relation, toward_event

    def sentiment_labels(self, pairs):
        """Return the encoded gold sentiment label of each (source, target) pair.

        Parameters
        ----------
        pairs : sequence of (src_entity_id, trg_mention_id)
            Targets are looked up in ``self.mentions``.

        Returns
        -------
        numpy.ndarray
            One integer per pair, as produced by ``sentiment_encoder``.  A
            pair is labelled with the polarity of the first sentiment on the
            target that is not 'none' and whose source matches; otherwise
            'none'.  Empty ``pairs`` gives an empty integer array.
        """
        if not pairs:
            # `np.long` is missing from NumPy 1.24-1.26; the builtin `int`
            # yields the default integer dtype, the same one the non-empty
            # path below gets from a list of Python ints.
            return np.array([], dtype=int)

        y = []
        for src_ix, trg_ix in pairs:
            trg = self.mentions[trg_ix]
            this_y = 'none'
            for sentiment in trg.sentiments:
                if (sentiment.polarity != 'none'
                        and _src_match(src_ix, sentiment)):
                    this_y = sentiment.polarity
                    break  # TODO: can we ever have more than one?
            y.append(this_y)
        return np.array(sentiment_encoder(y))

    def belief_labels(self, pairs):
        """Return the encoded gold belief label of each (source, target) pair.

        Parameters
        ----------
        pairs : sequence of (src_entity_id, trg_mention_id)
            Targets are looked up in ``self.mentions``.

        Returns
        -------
        numpy.ndarray
            One integer per pair, as produced by ``belief_encoder``.  A pair
            is labelled with the type of the first belief on the target that
            is not 'na' and whose source matches; otherwise 'na'.  Empty
            ``pairs`` gives an empty integer array.
        """
        if not pairs:
            # `np.long` is missing from NumPy 1.24-1.26; the builtin `int`
            # yields the default integer dtype, the same one the non-empty
            # path below gets from a list of Python ints.
            return np.array([], dtype=int)

        y = []
        for src_ix, trg_ix in pairs:
            trg = self.mentions[trg_ix]
            this_y = 'na'
            for belief in trg.beliefs:
                if belief.belief_type != 'na' and _src_match(src_ix, belief):
                    this_y = belief.belief_type
                    break  # TODO: can we ever have more than one?
            y.append(this_y)
        return np.array(belief_encoder(y))

    # Code for evaluation: impedance matching with the evaluator's PSTs

    def state_tuples(self, predictions=None, sentiment=True, belief=True,
                     include_null_src=True):
        """Return the evaluator's private state tuples for this document.

        ``predictions`` defaults to the annotation loaded for the document
        (``self.evaluator_best``).  ``sentiment`` and ``belief`` are passed
        through unchanged; ``include_null_src`` is passed on negated, as the
        evaluator's ``null_source_flag``.
        """
        annotations = (self.evaluator_best if predictions is None
                       else predictions)
        null_source_flag = not include_null_src
        return get_private_state_tuples(self.evaluator_ere,
                                        annotations,
                                        sentiment=sentiment,
                                        belief=belief,
                                        null_source_flag=null_source_flag)

    def build_pseudo_annotations(self, beliefs=None, sentiments=None):
        """Wrap predicted labels in a ``BeStAnnotations`` object.

        ``beliefs`` and ``sentiments`` are iterables of
        ``(src_entity_id or a falsy value, trg_mention_id, label)``; ``None``
        means no items.  A truthy source id is replaced by the first mention
        of that entity, a falsy one by ``None``.  Beliefs are created with
        polarity 'none' and sarcasm 'no', sentiments with sarcasm 'no'.
        """
        ere = self.evaluator_ere

        # plain dict over all target kinds; on an id clash event mentions
        # override relation mentions, which override entity mentions
        targets = dict(ere.entity_mentions)
        for extra in (ere.relation_mentions, ere.event_mentions):
            targets.update(extra)

        def first_mention(entity_id):
            if not entity_id:
                return None
            return ere.entities[entity_id].mentions[0]

        belief_items = [] if beliefs is None else beliefs
        sentiment_items = [] if sentiments is None else sentiments

        ann = BeStAnnotations(self.evaluator_ere)

        for src, trg, btype in belief_items:
            ann.beliefs.append(Belief(first_mention(src), targets[trg],
                                      btype=btype, polarity='none',
                                      sarcasm='no'))

        for src, trg, polarity in sentiment_items:
            ann.sentiments.append(Sentiment(first_mention(src), targets[trg],
                                            polarity=polarity, sarcasm="no"))

        return ann


class _BaseForumFile(_BaseBestFile):
    """Discussion-forum document whose source text contains ``<post ...>`` tags."""

    data_type = 'forum'

    def _get_post_line(self, offset):
        """Return the ``<post ...>`` opening line that governs ``offset``.

        When a ``<headline>`` tag occurs before ``offset`` and a
        ``</headline>`` tag at or after it, the first post at or after
        ``offset`` is used; otherwise the last post starting before
        ``offset``.  The line runs up to, not including, the next newline (or
        the end of the text).  Returns an empty string when no post is found.
        """
        source = self.source

        # bounded searches instead of slicing, to avoid copying the text
        in_headline = (source.find('<headline>', 0, offset) != -1 and
                       source.find('</headline>', offset) != -1)
        if in_headline:
            post_line_start = source.find("<post ", offset)
        else:
            post_line_start = source.rfind("<post ", 0, offset)

        if post_line_start == -1:
            return ""

        post_line_end = source.find("\n", post_line_start)
        if post_line_end == -1:
            # post tag on the last line of a text without trailing newline:
            # keep the whole remainder rather than dropping its last character
            post_line_end = len(source)
        return source[post_line_start:post_line_end]

    @staticmethod
    def _get_post_attribute(post_line, name):
        """Return the quoted value of attribute ``name`` in ``post_line``.

        Accepts single- or double-quoted values.  Returns None when the
        attribute is absent or its value is not terminated.
        """
        marker = name + '='
        search_from = 0
        while True:
            found = post_line.find(marker, search_from)
            if found == -1:
                return None
            search_from = found + len(marker)

            # ignore matches that are the tail of a longer attribute name
            if found > 0 and not post_line[found - 1].isspace():
                continue

            quote = post_line[search_from:search_from + 1]
            if quote not in ('"', "'"):
                continue

            value_end = post_line.find(quote, search_from + 1)
            if value_end == -1:
                return None
            return post_line[search_from + 1:value_end]

    def get_post_id(self, offset):
        """Return the ``id`` attribute of the post governing ``offset``.

        Returns an empty string when there is no post or it has no id.
        """
        post_id = self._get_post_attribute(self._get_post_line(offset), 'id')
        return "" if post_id is None else post_id

    # docs with a post that doesn't have an author (that are forums)
    # ENG_DF_001487_20131215_G00A0AW5V
    # ENG_DF_001487_20140127_G00A0AW5X
    def get_author(self, offset):
        """Return the ``author`` attribute of the post governing ``offset``.

        Returns None when there is no post or it has no author.
        """
        return self._get_post_attribute(self._get_post_line(offset), 'author')


class BestOldForumFile(_BaseForumFile):
    """Old-style forum files, identified by a hexadecimal doc_id."""

    def _build_source_fname(self, dir="source"):
        """Return ``<data_root>/<dir>/<doc_id>.cmp.txt`` as a string."""
        fname = '{}.cmp.txt'.format(self.doc_id)
        return os.path.join(self.data_root, dir, fname)


class BestEvalOldForumFile(_BaseForumFile):
    """Old forum files (hexadecimal doc_id) laid out under a ``df`` directory.

    No BeSt annotation file is associated with these documents.
    """

    def _build_source_fname(self, dir=Path("df/source")):
        """Return ``<data_root>/<dir>/<doc_id>.xml`` as a string."""
        fname = '{}.xml'.format(self.doc_id)
        return os.path.join(self.data_root, dir, fname)

    def _build_gold_ere_fname(self):
        """Return ``<data_root>/df/ere/<doc_id>.rich_ere.xml``."""
        fname = '{}.rich_ere.xml'.format(self.doc_id)
        return os.path.join(self.data_root, 'df', 'ere', fname)

    def _build_predicted_ere_fname(self):
        """Return ``<data_root>/df/predicted_ere/<doc_id>.predicted.rich_ere.xml``."""
        fname = '{}.predicted.rich_ere.xml'.format(self.doc_id)
        return os.path.join(self.data_root, 'df', 'predicted_ere', fname)

    def _build_source_ann_fnames(self):
        """Return the (tokenization JSON, CoNLL parse) names under ``df/parsed``."""
        parsed_dir = os.path.join("df", "parsed")
        stem = self._build_source_fname(dir=parsed_dir)
        tok_json_fname = stem + ".json"
        parsed_fname = stem + ".conll8.parsed"
        return tok_json_fname, parsed_fname

    def _build_annotation_fname(self):
        """Return None: there is no annotation file to load."""
        return None


class BestNewForumFile(_BaseForumFile):
    """New-style forum files with multiple annotations per source file"""

    def _build_source_fname(self, dir="source"):
        """Return ``<data_root>/<dir>/<source id>.xml`` as a string.

        A doc_id containing exactly five underscores has its last
        ``_``-separated part removed to obtain the source id; any other
        doc_id is used unchanged.
        """
        doc_id = self.doc_id
        src_id = doc_id.rsplit("_", 1)[0] if doc_id.count("_") == 5 else doc_id
        fname = '{}.xml'.format(src_id)
        return os.path.join(self.data_root, dir, fname)


class BestNewswireFile(_BaseBestFile):
    """Newswire data"""

    data_type = 'news'

    def _build_source_fname(self, dir="source"):
        """Return ``<data_root>/<dir>/<doc_id>.xml`` as a string."""
        fname = '{}.xml'.format(self.doc_id)
        return os.path.join(self.data_root, dir, fname)


# arguments (to events/relations) can be either entity mentions or fillers.

# This function checks whether an argument is an entity or a filler,
# and finds its offset and length


def argument_offset(arg, entity_mentions, fillers):
    """Return ``(offset, length)`` of an event/relation argument.

    ``arg.attrib`` decides where to look: an 'entity_mention_id' is resolved
    in ``entity_mentions`` (fields 'mention_offset' / 'mention_length'),
    otherwise a 'filler_id' is resolved in ``fillers`` (fields 'offset' /
    'length').  Values are converted with ``int``.  When neither attribute is
    present, the attributes are printed and None is returned.
    """
    attrs = arg.attrib

    if 'entity_mention_id' in attrs:
        record = entity_mentions[attrs['entity_mention_id']]
        return int(record['mention_offset']), int(record['mention_length'])

    if 'filler_id' in attrs:
        record = fillers[attrs['filler_id']]
        return int(record['offset']), int(record['length'])

    print(attrs)
    return None
