#
 #     MILIE: Modular & Iterative Multilingual Open Information Extraction
 #
 #
 #
 #     Authors: Deleted for purposes of anonymity
 #
 #     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 #
 # The software and its source code contain valuable trade secrets and shall be maintained in
 # confidence and treated as confidential information. The software may only be used for
 # evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 # license agreement or nondisclosure agreement with the proprietor of the software.
 # Any unauthorized publication, transfer to third parties, or duplication of the object or
 # source code---either totally or in part---is strictly prohibited.
 #
 #     Copyright (c) 2021 Proprietor: Deleted for purposes of anonymity
 #     All Rights Reserved.
 #
 # THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY
 # AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT
 # DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION.
 #
 # NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 # IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE
 # LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 # FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 # OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 # ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 # TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 # THE POSSIBILITY OF SUCH DAMAGES.
 #
 # For purposes of anonymity, the identity of the proprietor is not given herewith.
 # The identity of the proprietor will be given once the review of the
 # conference submission is completed.
 #
 # THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 #

import os, json, spacy
from collections import defaultdict
from spacy.matcher import Matcher
from tqdm import tqdm

def format_benchie_extraction(lang='ar', data_dir="C:\\code\\benchie\\",
                              file_name='predictions.tok', out_path=None):
    """Convert sentence-keyed predictions into tab-separated extraction lines.

    The predictions file is a JSON object mapping a sentence to a list of
    extractions (each a list of elements). Sentences are read, one per line,
    from ``sample100_{lang}.txt`` inside ``data_dir``; the 1-based line number
    of a sentence is used as its index in the output.

    For every extraction, empty-string elements are dropped first. Then:

    * more than four elements: one line is written for each element from the
      third up to (but excluding) the last one;
      NOTE(review): the last element is presumably not an argument (e.g. a
      score or marker) -- confirm against the producer of the predictions;
    * three or four elements: a single line with the first three elements;
    * fewer than three elements: the extraction is skipped (previously this
      raised IndexError).

    Args:
        lang: language code used in the sample and output file names.
        data_dir: directory holding the predictions and the sample file.
        file_name: name of the JSON predictions file inside ``data_dir``.
        out_path: output file; when None the original hard-coded location
            under the BenchIE ``oie_systems_explicit_extractions`` directory
            is used (it does not follow ``data_dir``).
    """
    with open(os.path.join(data_dir, file_name), encoding='utf8') as f:
        extractions = json.load(f)
    with open(os.path.join(data_dir, f'sample100_{lang}.txt'), encoding='utf8') as f:
        sents = [line.strip() for line in f]
    if out_path is None:
        out_path = f'C:\\code\\benchie\\data\\oie_systems_explicit_extractions\\milie_{lang}_explicit.txt'

    write_buff = []
    for sent_id, sent in enumerate(sents, start=1):
        for trp in extractions.get(sent, []):
            trp = [x for x in trp if x != '']
            if len(trp) < 3:
                # Not enough elements to form a triple; nothing to emit.
                continue
            objs = trp[2:-1] if len(trp) > 4 else trp[2:3]
            write_buff.extend(
                f"{sent_id}\t{trp[0]}\t{trp[1]}\t{obj}" for obj in objs
            )
    with open(out_path, 'w', encoding='utf8') as f:
        f.write('\n'.join(write_buff))



def remove_tokenization(x):
    """Return ``x`` with every space character removed.

    Values that are not strings are handed back untouched. Only the plain
    space character is removed; tabs and newlines are kept.
    """
    return x.replace(' ', '') if isinstance(x, str) else x

def muti2oie_to_benchie():
    """Write the 'zh' extractions as tab-separated lines, one per object.

    Loads the sentence -> extractions JSON and the sample sentence list from
    the hard-coded BenchIE directory. After dropping empty-string elements,
    every extraction with at least three elements yields one output line per
    element from the third onwards, each prefixed with the 1-based sentence
    index and the first two elements. Shorter extractions are ignored.
    """
    lang = 'zh'
    data_dir = "C:\\code\\benchie\\"

    pred_path = os.path.join(data_dir, f'extraction_{lang}_300.json')
    with open(pred_path, encoding='utf8') as pred_file:
        extractions = json.load(pred_file)
    with open(os.path.join(data_dir, f'sample100_{lang}.txt'), encoding='utf8') as sent_file:
        sents = [line.strip() for line in sent_file]

    rows = []
    for sent_no, sent in enumerate(sents, start=1):
        for raw in extractions.get(sent, []):
            fields = [x for x in raw if x != '']
            if len(fields) < 3:
                continue
            first, second = fields[0], fields[1]
            rows.extend(f"{sent_no}\t{first}\t{second}\t{obj}" for obj in fields[2:])

    out_path = f'C:\\code\\benchie\\data\\oie_systems_explicit_extractions\\multi2oie_explicit_{lang}.txt'
    with open(out_path, 'w', encoding='utf8') as out_file:
        out_file.write('\n'.join(rows))

def create_benchie_test():
    """Build the ``benchie_100_ar.json`` file of sentences with dependency tags.

    Reads the sample sentences, obtains per-sentence dependency tag groups
    from ``extract_dep`` and writes one ``{'sentence', 'dep'}`` record per
    distinct sentence (first occurrence order) as JSON.
    """
    file_lang = 'ar'
    with open(f'C:\\data\\milie\\en\\sample100_{file_lang}.txt', encoding='utf8') as src:
        sents = [line.strip() for line in src]

    # NOTE(review): the spaCy model and the dependency map are selected for
    # 'en' although the input and output files are named after 'ar' --
    # confirm this is intended.
    parser_lang = 'en'
    model = 'web' if parser_lang in {'en', 'zh'} else 'news'
    nlp = spacy.load(f'{parser_lang}_core_{model}_md', disable=['ner'])
    dep_tags = extract_dep(sents, nlp, f'dep_map_{parser_lang}.txt')

    dev_data = []
    seen_sents = set()
    for s in sents:
        if s not in seen_sents:
            dev_data.append({'sentence': s, 'dep': dep_tags[s]})
            seen_sents.add(s)

    print(len(seen_sents))
    out_path = os.path.join('C:\\data\\milie\\en\\', f'benchie_100_{file_lang}.json')
    with open(out_path, 'w', encoding='utf8') as out:
        json.dump(dev_data, out, indent=None, separators=(', \n', ': '), ensure_ascii=False)

def extract_dep(text,nlp,dep_map_file):
    """Map each sentence to its dependency tags, grouped by mapped label.

    Args:
        text: sequence of sentences to parse.
        nlp: loaded spaCy pipeline; its ``pipe`` method is used for parsing.
        dep_map_file: UTF-8 file with one ``label - replacement`` pair per
            line; every line must contain exactly one ``-``.

    Returns:
        The sentence -> tag-groups mapping produced by ``accumulate_tags``.
    """
    with open(dep_map_file, encoding='utf8') as map_file:
        pairs = (line.split('-') for line in map_file)
        dep_map = {raw.strip(): mapped.strip() for raw, mapped in pairs}

    def _tags_per_sentence():
        # Lazily parse the sentences and pair each token with its mapped
        # label; tokens with an empty dependency label are left out.
        for doc in nlp.pipe(text, batch_size=10000):
            yield [(token.text, dep_map[token.dep_])
                   for token in doc if len(token.dep_) > 0]

    return accumulate_tags(_tags_per_sentence(), text)

def accumulate_tags(ext_tags, text, order=None):
    """Group the tokens of each sentence by tag and key the result by sentence.

    Args:
        ext_tags: iterable yielding, for each sentence in ``text`` (same
            order), an iterable of ``(token, tag)`` pairs.
        text: sequence of sentences; must have as many entries as
            ``ext_tags`` yields.
        order: optional iterable of tags fixing which tags are emitted and in
            what order. When None, every tag found in a sentence is emitted in
            first-seen order for that sentence.

    Returns:
        dict mapping each sentence to a list of ``[tag, token, ...]`` lists.
        Tokens under a tag are de-duplicated; their relative order is
        arbitrary because they are collected in a set.
    """
    all_tags = []
    for tags in tqdm(ext_tags, total=len(text), desc='Extracting Tags'):
        tag_dict = defaultdict(set)
        for tok, tag in tags:
            tag_dict[tag].add(tok)
        # Resolve the key order per sentence. Rebinding ``order`` itself would
        # freeze the first sentence's tags and drop any tag that appears only
        # in later sentences.
        keys = tag_dict.keys() if order is None else order
        all_tags.append(
            [[key] + list(tag_dict[key]) for key in keys if key in tag_dict]
        )
    assert len(all_tags) == len(text)
    return {sent: all_tags[idx] for idx, sent in enumerate(text)}


def inject_vps():
    """Add the extracted verb phrases to every record of ``benchie_300.json``.

    Loads the records, runs ``extract_phrases`` over the distinct sentences
    longer than one character and stores the verb phrases of each record's
    sentence under ``'pred_elem'``. The result is written to
    ``benchie_300_pred.json`` in the same directory.
    """
    data_dir = "C:\\data\\milie\\en\\"
    with open(os.path.join(data_dir, 'benchie_300.json'), encoding='utf8') as f:
        data = json.load(f)
    sentences = list({ex['sentence'] for ex in data if len(ex['sentence']) > 1})
    nlp = spacy.load('en_core_web_sm', disable=["lemmatizer", 'ner'])
    vp_map, _ = extract_phrases(sentences, nlp)
    for ex in data:
        # Sentences of length <= 1 were filtered out before parsing and have
        # no entry in vp_map; give them an empty list instead of a KeyError.
        ex['pred_elem'] = vp_map.get(ex['sentence'], [])
    with open(os.path.join(data_dir, 'benchie_300_pred.json'), 'w', encoding='utf8') as f:
        json.dump(data, f, indent=None, separators=(', \n', ': '), ensure_ascii=False)


def extract_phrases(text, nlp):
    """Collect POS-pattern verb phrases and noun chunks for every sentence.

    Args:
        text: sequence of sentences.
        nlp: loaded spaCy pipeline used to parse the sentences.

    Returns:
        ``(vp_map, np_map)``: two dicts keyed by sentence. ``vp_map`` holds
        the distinct texts matched by the POS patterns below, ``np_map`` the
        distinct texts of ``doc.noun_chunks``. List order is arbitrary
        because duplicates are removed through a set.
    """
    # For each head tag: the head alone, ADV + head, head + ADV, head + ADP;
    # finally an AUX immediately followed by a VERB.
    patterns = []
    for head in ('VERB', 'AUX'):
        patterns.extend([
            [{'POS': head}],
            [{'POS': 'ADV'}, {'POS': head}],
            [{'POS': head}, {'POS': 'ADV'}],
            [{'POS': head}, {'POS': 'ADP'}],
        ])
    patterns.append([{'POS': 'AUX'}, {'POS': 'VERB'}])

    matcher = Matcher(nlp.vocab)
    for idx, pattern in enumerate(patterns):
        matcher.add(str(idx), [pattern])
    print(f"Num. Sentences {len(text)}")

    def _phrases_per_doc():
        # Lazily parse the sentences, yielding (matched texts, noun chunk
        # texts) for each document.
        for doc in nlp.pipe(text, batch_size=10000):
            spans = matcher(doc)
            matched = [doc[begin:stop].text for _, begin, stop in spans]
            chunks = [chunk.text for chunk in doc.noun_chunks]
            yield matched, chunks

    verb_phrases = []
    noun_phrases = []
    progress = tqdm(_phrases_per_doc(), desc='Extracting VPs and NPs', total=len(text))
    for matched, chunks in progress:
        verb_phrases.append(list(set(matched)))
        noun_phrases.append(list(set(chunks)))
    assert len(verb_phrases) == len(text)

    vp_map = {sent: verb_phrases[idx] for idx, sent in enumerate(text)}
    np_map = {sent: noun_phrases[idx] for idx, sent in enumerate(text)}
    return vp_map, np_map



def inject_clauseie_elements():
    """Attach extraction-derived ``pred_map`` / ``pred_elem`` to each record.

    Reads the records of ``benchie_300_zh.json``, the tab-separated
    extraction lines of ``pred_patt_zh_explicit.txt`` (first field: 1-based
    sentence index) and the sentence list ``sample300_zh.txt``. For every
    record the extractions of its sentence are used to build:

    * ``pred_map``: one entry keyed ``"<pos1>-<pos2>"`` for the
      object -> predicate grouping returned by ``extract_map``;
    * ``pred_elem``: for each configured element, its position mapped to the
      distinct values returned by ``extract_elem``.

    The result is written to ``benchie_300_zh_pp.json``.
    """
    elems = ['object']
    # Only the object -> predicate map is produced; the subject/object and
    # object/subject variants that existed here are disabled.
    map_2 = ['object', 'predicate']
    data_dir = "C:\\code\\benchie\\"

    with open(os.path.join(data_dir, 'benchie_300_zh.json'), encoding='utf8') as f:
        data = json.load(f)

    # Sentence index (as text) -> tab-joined remainder of every line.
    clauseie_data = defaultdict(list)
    ext_path = os.path.join(data_dir, 'data', 'oie_systems_explicit_extractions',
                            'pred_patt_zh_explicit.txt')
    with open(ext_path, encoding='utf8') as f:
        for line in f:
            sent_id, *rest = line.strip().split('\t')
            clauseie_data[sent_id].append('\t'.join(rest))

    with open(os.path.join(data_dir, 'sample300_zh.txt'), encoding='utf8') as f:
        benchie_sents = {line_no: line.strip() for line_no, line in enumerate(f, start=1)}

    # Re-key the extractions by sentence text.
    triples = {benchie_sents[int(sent_id)]: trps
               for sent_id, trps in clauseie_data.items()}

    lev1_pos, lev2_pos = get_elem_pos(map_2[0]), get_elem_pos(map_2[1])
    map_key = '-'.join([str(lev1_pos), str(lev2_pos)])
    for ex in data:
        trps = triples.get(ex['sentence'], [])
        ex['pred_map'] = {map_key: extract_map((lev1_pos, lev2_pos), trps)}
        pred_elem = defaultdict(list)
        for elem in elems:
            elem_pos = get_elem_pos(elem)
            pred_elem[int(elem_pos)] = extract_elem(elem_pos, trps)
        ex['pred_elem'] = pred_elem

    with open(os.path.join(data_dir, 'benchie_300_zh_pp.json'), 'w', encoding='utf8') as f:
        json.dump(data, f, indent=None, separators=(', \n', ': '), ensure_ascii=False)

def extract_map(key, trps):
    """Group one field of each tab-separated triple under another field.

    Args:
        key: pair of positions ``(group_by, collect)`` into the first three
            tab-separated fields of a triple.
        trps: iterable of tab-separated triple strings.

    Returns:
        ``defaultdict(list)`` mapping each ``group_by`` field value to the
        ``collect`` field values, in input order and with duplicates kept.
    """
    grouped = defaultdict(list)
    for fields in (trp.split('\t')[:3] for trp in trps):
        grouped[fields[key[0]]].append(fields[key[1]])
    return grouped

def extract_elem(elem_pos, trps):
    """Return the distinct values found at one position of the triples.

    Each triple is a tab-separated string; only its first three fields are
    considered. Double quotes are stripped from the selected field before
    de-duplication. The order of the returned list is arbitrary.
    """
    distinct = {trp.split('\t')[:3][elem_pos].replace('"', '') for trp in trps}
    return list(distinct)

def get_elem_pos(elem):
    """Return the position of a named element within a triple.

    Args:
        elem: one of ``'subject'``, ``'predicate'`` or ``'object'``.

    Returns:
        0 for subject, 1 for predicate, 2 for object.

    Raises:
        ValueError: if ``elem`` is not one of the three known names
            (previously this surfaced as an UnboundLocalError).
    """
    positions = {'subject': 0, 'predicate': 1, 'object': 2}
    try:
        return positions[elem]
    except (KeyError, TypeError):
        raise ValueError(
            f"unknown triple element {elem!r}; expected one of {sorted(positions)}"
        ) from None


def filter_m2oie_obj():
    """Keep the ``openie6`` lines whose last field also ends a ``clausie`` line.

    Both files are tab-separated with the sentence index in the first field.
    A line of ``openie6_explicit.txt`` is kept when its final field equals
    the final field of some ``clausie_explicit.txt`` line carrying the same
    index. Kept lines (stripped) are written to
    ``openie6_object_explicit.txt``.
    """
    data_dir = "C:\\code\\benchie\\data\\oie_systems_explicit_extractions"

    # Sentence index -> set of final fields seen in the reference file.
    reference_objects = defaultdict(set)
    with open(os.path.join(data_dir, 'clausie_explicit.txt'), encoding='utf8') as ref_file:
        for line in ref_file:
            fields = line.strip().split('\t')
            # A line holding only an index contributes the empty string.
            last = fields[-1] if len(fields) > 1 else ''
            reference_objects[fields[0]].add(last)

    kept = []
    with open(os.path.join(data_dir, 'openie6_explicit.txt'), encoding='utf8') as cand_file:
        for line in cand_file:
            stripped = line.strip()
            index, *triple = stripped.split('\t')
            if triple[-1] in reference_objects.get(index, ()):
                kept.append(stripped)

    with open(os.path.join(data_dir, 'openie6_object_explicit.txt'), 'w', encoding='utf8') as out_file:
        out_file.write('\n'.join(kept))

if __name__ == '__main__':
    # Script entry point: only one step is enabled at a time. The commented
    # calls below are the other steps defined in this file; uncomment the one
    # to run.
    #filter_m2oie_obj()
    format_benchie_extraction()
    #create_benchie_test()
    #inject_clauseie_elements()
    #create_benchie_test()
    #inject_vps()
    #muti2oie_to_benchie()