#
 #     MILIE: Modular & Iterative Multilingual Open Information Extraction
 #
 #
 #
 #     Authors: Deleted for purposes of anonymity
 #
 #     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 #
 # The software and its source code contain valuable trade secrets and shall be maintained in
 # confidence and treated as confidential information. The software may only be used for
 # evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 # license agreement or nondisclosure agreement with the proprietor of the software.
 # Any unauthorized publication, transfer to third parties, or duplication of the object or
 # source code---either totally or in part---is strictly prohibited.
 #
 #     Copyright (c) 2021 Proprietor: Deleted for purposes of anonymity
 #     All Rights Reserved.
 #
 # THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY
 # AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT
 # DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION.
 #
 # NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 # IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE
 # LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 # FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 # OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 # ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 # TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 # THE POSSIBILITY OF SUCH DAMAGES.
 #
 # For purposes of anonymity, the identity of the proprietor is not given herewith.
 # The identity of the proprietor will be given once the review of the
 # conference submission is completed.
 #
 # THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 #

import os, re
import json, pickle
import numpy as np
from tqdm import tqdm
from collections import defaultdict, OrderedDict
import spacy, operator
from spacy.matcher import Matcher


def readable_spanish():
    """Write a human-readable version of the Re-OIE2016 Portuguese extractions.

    Loads ``Re-OIE2016-Portuguese.json``, which is indexed by sentence and
    holds, per sentence, extraction dicts with ``arg0``..``arg3`` and ``pred``
    keys. The non-empty slots of every extraction are joined with ``-#-`` in
    the order arg0, pred, arg1, arg2, arg3, and the resulting
    ``sentence -> [joined extraction, ...]`` mapping is written to
    ``test_pt_readable.json``.

    NOTE(review): despite the function name, the input and output files
    are the Portuguese ones.
    """
    source_path = 'C:\\data\\milie\\Re-OIE2016-Portuguese.json'
    with open(source_path, encoding='utf8') as f:
        data = json.load(f)
    readable_data = {}
    for sent in data:
        rendered = []
        for ex in data[sent]:
            arg0, arg1, arg2, arg3 = (ex[slot] for slot in ('arg0', 'arg1', 'arg2', 'arg3'))
            # Report "gaps": a later argument is filled while an earlier one is empty.
            if arg1 == '' and not (arg2 == '' and arg3 == ''):
                print("Found")
            if arg2 == '' and arg3 != '':
                print("Found")
            slots = (arg0, ex['pred'], arg1, arg2, arg3)
            rendered.append("-#-".join(slot for slot in slots if slot != ''))
        readable_data[sent] = rendered
    target_path = os.path.join('C:\\data\\milie', 'test_pt_readable.json')
    with open(target_path, 'w', encoding='utf8') as f:
        json.dump(readable_data, f, indent=None, separators=(', \n', ': '), ensure_ascii=False)


def format_multilingual():
    """Build the Portuguese test file with per-sentence dependency tags.

    Loads the pickled sentences, tags them through ``extract_dep`` using the
    spaCy ``pt_core_news_lg`` pipeline, and writes one
    ``{'sentence', 'dep'}`` record per distinct sentence (first occurrence
    order) to ``test_pt.json``.
    """
    data_dir = 'C:\\data\\milie\\'
    pickle_path = os.path.join(data_dir, 're_oie2016_test_portuguese.pkl')
    with open(pickle_path, 'rb') as f:
        sents = pickle.load(f)
    nlp = spacy.load('pt_core_news_lg', disable=["lemmatizer", 'ner'])
    dep_dict = extract_dep(sents, nlp, 'dep_map_pt.txt')
    emitted = set()
    test_data = []
    for s in sents:
        # Keep only the first occurrence of every sentence.
        if s not in emitted:
            test_data.append({'sentence': s, 'dep': dep_dict[s]})
            emitted.add(s)
    with open(os.path.join(data_dir, 'test_pt.json'), 'w', encoding='utf8') as f:
        json.dump(test_data, f, indent=None, separators=(', \n', ': '), ensure_ascii=False)



def get_elem_pos(elem):
    """Return the index of a named element inside a (subject, predicate, object) triple.

    Args:
        elem: one of ``'subject'``, ``'predicate'`` or ``'object'``.

    Returns:
        0 for the subject, 1 for the predicate, 2 for the object.

    Raises:
        RuntimeError: if ``elem`` is not one of the three supported names
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if elem == 'subject':
        return 0
    if elem == 'predicate':
        return 1
    if elem == 'object':
        return 2
    raise RuntimeError(f'{elem} is not allowed.')

def inject_predictions():
    """Attach externally predicted predicates to the dev examples.

    The predictions file is tab-separated: column 0 is a sentence id and
    column 2 the predicate (double quotes are stripped). Sentence ids are
    resolved through the sentence file, where the id is the 1-based line
    number. Every example of ``dev.json`` receives a ``pred_elem`` list
    (empty when no prediction exists for its sentence) and the result is
    written to ``dev_pred_verb.json``.
    """
    file_name = 'C:\\code\\clausie\\carb_dev_clausie.txt'
    sent_file = 'C:\\code\\clausie\\CaRB_dev_sents.txt'
    data_dir = 'C:\\data\\milie\\en\\'

    # Step 1: collect the predicted predicates per sentence id.
    preds_by_id = defaultdict(set)
    with open(file_name, encoding='utf8') as f:
        for line in f:
            fields = line.strip().split('\t')
            preds_by_id[int(fields[0])].add(fields[2].replace("\"", ""))

    # Sentence ids are 1-based line numbers of the sentence file.
    sent_by_id = {}
    with open(sent_file, encoding='utf8') as f:
        for line_no, line in enumerate(f, start=1):
            sent_by_id[line_no] = line.strip()
    preds_by_sent = {sent_by_id[sent_id]: preds for sent_id, preds in preds_by_id.items()}

    # Step 2: read the validation data.
    with open(os.path.join(data_dir, 'dev.json'), encoding='utf8') as f:
        data = json.load(f)

    # Step 3: merge the predictions into the examples.
    for ex in data:
        ex['pred_elem'] = list(preds_by_sent.get(ex['sentence'], []))

    # Step 4: write the enriched file.
    with open(os.path.join(data_dir, 'dev_pred_verb.json'), 'w', encoding='utf8') as f:
        json.dump(data, f, indent=None, separators=(', \n', ': '), ensure_ascii=False)

def format_dev():
    """Build the English dev set (sentence + dependency tags) from ``CaRB_dev.tsv``.

    The first tab-separated column of every line is taken as the sentence.
    Dependency tags are computed with ``extract_dep`` (spaCy
    ``en_core_web_lg``). Two files are written to the data directory:

    * ``CaRB_dev_sents.txt`` - the distinct sentences, one per line;
    * ``dev.json`` - one ``{'sentence', 'dep'}`` record per distinct sentence.

    Both files list the sentences in first-occurrence order, so the output is
    reproducible across runs.
    """
    #data_dir = "/home/bkotnis/local/data/milie/en/"
    data_dir = "C:\\data\\milie\\en\\"
    sents = []
    # Read with an explicit encoding, like the other readers in this module,
    # instead of relying on the platform default.
    with open(os.path.join(data_dir, 'CaRB_dev.tsv'), 'r', encoding='utf8') as f:
        for line in f:
            sents.append(line.strip().split('\t')[0])
    nlp = spacy.load('en_core_web_lg', disable=["lemmatizer", 'ner'])
    dep_dict = extract_dep(sents, nlp, 'dep_map_en.txt')

    # dict keys keep insertion order: distinct sentences, first occurrence first.
    unique_sents = list(dict.fromkeys(sents))
    dev_data = [{'sentence': s, 'dep': dep_dict[s]} for s in unique_sents]

    with open(os.path.join(data_dir, 'CaRB_dev_sents.txt'), 'w', encoding='utf8') as f:
        f.write('\n'.join(unique_sents))

    with open(os.path.join(data_dir, 'dev.json'), 'w', encoding='utf8') as f:
        json.dump(dev_data, f, indent=None, separators=(', \n', ': '), ensure_ascii=False)


def train_statistics():
    """Print simple statistics about ``structured_data.json``.

    Two distributions are collected over all examples:

    * how many examples have a given number of tuples, printed as
      ``(tuple count, number of examples)`` pairs, most frequent first;
    * which element type has the most distinct values per example
      (argmax index 0 = subjects, 1 = predicates, 2 = objects). Tuples
      without any argument are ignored for this count.
    """
    #data_dir = "/home/bkotnis/local/data/milie/en/"
    data_dir = "C:\\data\\milie\\en\\"
    # Explicit encoding, consistent with analyze_train_split, which reads the same file.
    with open(os.path.join(data_dir, 'structured_data.json'), encoding='utf8') as f:
        data = json.load(f)
    length_distrib = defaultdict(int)
    elem_distrib = defaultdict(int)
    for ex in tqdm(data, desc='Analyzing data'):
        length_distrib[len(ex['tuples'])] += 1
        subjs, preds, objs = set(), set(), set()
        for trp in ex['tuples']:
            s, v = trp['arg0'], trp['relation']
            if len(trp['args']) <= 0:
                continue
            subjs.add(s)
            preds.add(v)
            # The first argument is treated as the object.
            objs.add(trp['args'][0])
        mfe = np.argmax([len(subjs), len(preds), len(objs)])
        elem_distrib[mfe] += 1
    sorted_lens = sorted(length_distrib.items(), key=operator.itemgetter(1), reverse=True)
    # The length distribution used to be computed and then silently dropped.
    print(sorted_lens)
    print(elem_distrib)


def format_train():
    """Create the shuffled training file of positive and negative examples.

    Steps:

    1. Read ``structured_data.json`` and keep, per sentence, the tuples whose
       subject and relation occur verbatim in the sentence and that have at
       least one argument. The first argument is the object, the remaining
       ones are kept as extra arguments.
    2. Generate positives (``gen_rand_positives``) and negatives
       (``gen_rand_negatives``) for every remaining sentence; negatives are
       shuffled and capped at ``num_negs``.
    3. Attach dependency tags, loaded from ``dependencies.pkl`` when that
       cache exists and otherwise computed with ``extract_dep`` and cached.
    4. Insert the element markers into the sentences (``mark_sentences``),
       shuffle everything and write ``train_dep_cls_<num_negs>.json``.
    """
    data_dir = "/home/bkotnis/local/data/milie/en/"
    #data_dir = "C:\\data\\milie\\en\\"
    with open(os.path.join(data_dir, 'structured_data.json'), encoding='utf8') as f:
        data = json.load(f)
    all_positives, negatives, marked_examples = [], [], []
    for ex in tqdm(data, desc='Formatting positives and negatives'):
        if ex['sentence'] is None or len(ex['sentence']) <= 0:
            continue
        new_ex = {'sentence': ex['sentence']}
        triples = []
        for trp in ex['tuples']:
            s, v = trp['arg0'], trp['relation']
            if s not in ex['sentence'] or v not in ex['sentence'] or len(trp['args']) <= 0:
                continue
            # (subject, relation, object, remaining arguments)
            triples.append((s, v, trp['args'][0], trp['args'][1:]))

        if len(triples) <= 0:
            continue
        new_ex['triples'] = triples
        new_ex['subject'], new_ex['predicate'], new_ex['object'] = [], [], []
        # The generators modify the example they receive, so each gets its own copy.
        new_ex_cp, new_ex_cp_2 = new_ex.copy(), new_ex.copy()
        marked_examples.append(new_ex_cp_2)
        all_positives.extend(gen_rand_positives(new_ex))
        negatives.extend(gen_rand_negatives(new_ex_cp))
    np.random.shuffle(negatives)
    num_negs = 1000000
    negatives = negatives[:num_negs]
    #cls_examples = create_cls_data(marked_examples, num_cls)

    dep_cache = os.path.join(data_dir, 'dependencies.pkl')
    if os.path.exists(dep_cache):
        with open(dep_cache, 'rb') as f:
            dep_tags = pickle.load(f)
    else:
        sentences = list(set([ex['sentence'] for ex in all_positives if len(ex['sentence']) > 1]))
        nlp = spacy.load('en_core_web_lg', disable=["lemmatizer", 'ner'])
        dep_tags = extract_dep(sentences, nlp, 'dep_map_en.txt')
        with open(dep_cache, 'wb') as f:
            # pickle.dump returns None: its result must not be assigned back to
            # dep_tags, otherwise the lookups below fail on a fresh run.
            pickle.dump(dep_tags, f)

    for ex in all_positives:
        ex['dep'] = dep_tags[ex['sentence']]
    for ex in negatives:
        ex['dep'] = dep_tags[ex['sentence']]
    #for ex in cls_examples:
        #ex['dep'] = dep_tags[ex['sentence']]

    all_positives = mark_sentences(all_positives)
    negatives = mark_sentences(negatives)
    frmt_data = all_positives + negatives# + cls_examples
    np.random.shuffle(frmt_data)
    print(f"Total Data {len(frmt_data)}, Num. Negatives {len(negatives)},  Num. Positives {len(all_positives)}")
    with open(os.path.join(data_dir, f'train_dep_cls_{str(num_negs)}.json'),'w', encoding='utf8') as f:
        json.dump(frmt_data,f, indent=None, separators=(', \n',': '), ensure_ascii=False)


def create_cls_data(data, num_cls):
    """Build shuffled classification examples from at most ``num_cls`` inputs.

    ``data`` is shuffled in place and its first ``num_cls`` entries are used.
    For each of them the original triples yield examples labelled 1
    (``mark_all``); the triples are then cut to three elements, corrupted by
    randomizing subject, object and predicate (in that order) and turned into
    examples labelled 0, whose elements are post-processed by
    ``remove_overlap``.

    Returns:
        A shuffled list containing the positive and the negative examples.
    """
    np.random.shuffle(data)
    selected = data[:num_cls]
    pos_examples, neg_examples = [], []
    for ex in tqdm(selected, desc='Creating CLS data'):
        pos_examples.extend(mark_all(ex, 1))
        ex['triples'] = [trp[:3] for trp in ex["triples"]]
        for corrupt in (randomize_subject, randomize_object, randomize_predicate):
            ex = corrupt(ex)
        corrupted = mark_all(ex, 0)
        for nex in corrupted:
            nex['subject'], nex['predicate'], nex['object'] = remove_overlap(nex)
        neg_examples.extend(corrupted)
    combined = pos_examples + neg_examples
    np.random.shuffle(combined)
    return combined

def mark_sentences(data):
    """Insert the ``<A0>``/``<P>``/``<A1>`` markers into every example's sentence.

    For each example the first subject and the first predicate are marked;
    the ``object`` field is handed to ``mark_sentence`` as stored. The marked
    text replaces ``ex['sentence']`` and the same ``data`` object is returned.
    """
    # (field, marker, whether only the first entry of the field is marked)
    plan = (
        ('subject', '<A0>', True),
        ('predicate', '<P>', True),
        ('object', '<A1>', False),
    )
    for ex in tqdm(data, 'Marking Sentences'):
        marked = ex['sentence']
        for field, marker, first_only in plan:
            value = ex[field]
            if len(value) > 0:
                marked = mark_sentence(marked, value[0] if first_only else value, marker)
        ex['sentence'] = marked
    return data

def extract_dep(text,nlp,dep_map_file):
    """Tag every sentence with (mapped) dependency labels.

    Args:
        text: the sentences to process.
        nlp: object whose ``pipe`` method yields the parsed documents.
        dep_map_file: UTF-8 file with one ``label - replacement`` pair per
            line; every line must contain exactly one ``-``.

    Returns:
        The mapping produced by ``accumulate_tags`` for the
        ``(token text, mapped label)`` pairs of each sentence. Tokens with an
        empty dependency label are skipped.
    """
    dep_map = {}
    with open(dep_map_file, encoding='utf8') as f:
        for raw in f:
            label, replacement = raw.split('-')
            dep_map[label.strip()] = replacement.strip()

    def _tag_docs():
        # Lazily yield one list of (token, mapped label) pairs per document.
        for doc in nlp.pipe(text, batch_size=10000):
            yield [(tok.text, dep_map[tok.dep_]) for tok in doc if len(tok.dep_) > 0]

    return accumulate_tags(_tag_docs(), text)

def extract_pos(text, nlp):
    """Tag every sentence with part-of-speech labels, grouped in a fixed order.

    Args:
        text: the sentences to process.
        nlp: object whose ``pipe`` method yields the parsed documents.

    Returns:
        The mapping produced by ``accumulate_tags`` for the
        ``(token text, POS tag)`` pairs of each sentence, with the tag groups
        emitted in ``pos_order``.
    """
    # The list used to contain 'ADP' twice and no 'ADJ': ADP groups were
    # emitted twice and adjectives were dropped from the output.
    # NOTE(review): 'CONJ' is kept as written; confirm whether the tagger in
    # use emits 'CCONJ' instead.
    pos_order = ['X', 'DET', 'INTJ', 'PUNCT', 'NUM', 'SYM', 'PART', 'SCONJ', 'CONJ',
                 'PRON', 'ADP', 'ADV', 'ADJ', 'AUX', 'VERB', 'NOUN', 'PROPN']

    def _ext_pos_tags(nlp, sentences):
        for doc in nlp.pipe(sentences, batch_size=10000):
            yield [(token.text, token.pos_) for token in doc]

    print(f"Num. Sentences {len(text)}")
    pos_tags = _ext_pos_tags(nlp, text)
    return accumulate_tags(pos_tags, text, pos_order)

def accumulate_tags(ext_tags, text, order=None):
    """Group the tokens of every sentence by tag.

    Args:
        ext_tags: iterable yielding, per sentence of ``text`` and in the same
            order, the ``(token, tag)`` pairs of that sentence.
        text: the sentences; used as the keys of the result.
        order: optional sequence of tags fixing which tag groups are emitted
            and in which order. When ``None``, every sentence uses its own
            tags in order of first appearance.

    Returns:
        Dict mapping each sentence to a list of ``[tag, token, ...]`` lists.
        Tokens inside a group are distinct and in no particular order.
    """
    all_tags = []
    for tags in tqdm(ext_tags, total=len(text), desc='Extracting Tags'):
        tag_dict = defaultdict(set)
        for tok, tag in tags:
            tag_dict[tag].add(tok)
        # Resolve the default order per sentence. Rebinding `order` here would
        # make every later sentence reuse the tag set of the first one.
        keys = tag_dict.keys() if order is None else order
        all_tags.append([[key] + list(tag_dict[key]) for key in keys if key in tag_dict])
    assert len(all_tags) == len(text)
    return {sent: sent_tags for sent, sent_tags in zip(text, all_tags)}


def gen_rand_positives(ex):
    """Generate positive training examples for one sentence.

    One randomly chosen variant is produced at each of three levels:

    * level 1 - nothing marked, one element type is the prediction head;
    * level 2 - one element type marked, another one is the head;
    * level 3 - two element types marked, the third one is the head.

    Additionally, 'arguments' examples are added for the triples that carry
    extra arguments; when no triple has any, such examples are added for all
    triples with probability 0.1.

    Args:
        ex: dict with at least ``triples`` (4-tuples of subject, predicate,
            object, extra arguments); it is modified by the marking helpers.

    Returns:
        List of example dicts.
    """
    data = []

    # Level 1: nothing is marked in the sentence. The relative weights
    # 2 : 2.5 : 2.5 are normalised because np.random.choice requires
    # probabilities summing to 1 (the raw values summed to 7/6 and raised).
    level1_heads = ('subject', 'object', 'predicate')
    level1_weights = np.array([2, 2.5, 2.5])
    choice_1 = np.random.choice(list(range(3)), p=level1_weights / level1_weights.sum())
    data.append(mark_level1(ex, ex['triples'], level1_heads[choice_1]))

    # Level 2: (marked element, head).
    level2_plans = (
        ('subject', 'object'),
        ('subject', 'predicate'),
        ('object', 'subject'),
        ('object', 'predicate'),
        ('predicate', 'subject'),
        ('predicate', 'object'),
    )
    choice_2 = np.random.choice(list(range(6)), p=[3/12, 1/12, 2/12, 1/12, 2/12, 3/12])
    data.extend(mark_level2(ex, ex['triples'], *level2_plans[choice_2]))

    # Level 3: (first marked element, second marked element, head).
    level3_plans = (
        ('subject', 'object', 'predicate'),
        ('subject', 'predicate', 'object'),
        ('object', 'predicate', 'subject'),
        ('object', 'subject', 'predicate'),
        ('predicate', 'subject', 'object'),
        ('predicate', 'object', 'subject'),
    )
    choice_3 = np.random.choice(list(range(6, 12)), p=[1/12, 3/12, 2/12, 1/12, 3/12, 2/12])
    data.extend(mark_level3(ex, *level3_plans[choice_3 - 6]))

    # Extra-argument examples.
    triples = [trp for trp in ex['triples'] if len(trp[-1]) > 0]
    if len(triples) > 0:
        data.extend(add_arguments(ex, triples))
    elif np.random.random() <= 0.1:
        data.extend(add_arguments(ex, ex['triples']))

    return data

def add_arguments(ex, triples):
    """Create one 'arguments' example per triple.

    Args:
        ex: template example; it is shallow-copied for every triple.
        triples: iterable of ``(subject, relation, object, args)`` 4-tuples.

    Returns:
        List of copies of ``ex`` in which subject, object and predicate are
        single-element lists, ``targets`` holds the triple's extra arguments,
        ``head`` is ``'arguments'`` and ``triples`` is ``None``.
    """
    data = []
    for subj, rel, obj, extra_args in triples:
        item = ex.copy()
        item.update(subject=[subj], object=[obj], predicate=[rel])
        item.update(targets=extra_args, head='arguments', triples=None)
        data.append(item)
    return data

def mark_all(ex, cls_token):
    """Create one fully marked classification example per triple.

    Args:
        ex: dict with ``sentence`` and ``triples`` (3- or 4-element tuples).
        cls_token: label stored under ``cls_token`` in every example.

    Returns:
        List of dicts with the sentence, single-element ``subject``,
        ``object`` and ``predicate`` lists, ``head`` set to ``'CLS'`` and an
        empty ``targets`` list.
    """
    marked = []
    for trp in ex['triples']:
        # A fourth element (extra arguments), when present, is ignored.
        if len(trp) > 3:
            subj, pred, obj, _ = trp
        else:
            subj, pred, obj = trp
        marked.append({
            'sentence': ex['sentence'],
            'subject': [subj],
            'object': [obj],
            'predicate': [pred],
            'head': 'CLS',
            'targets': [],
            'cls_token': cls_token,
        })
    return marked

def mark_level3(ex, mark1, mark2, target):
    """Create examples with two marked elements and ``target`` as head.

    For every triple, ``mark_level2`` is called on that single triple with
    ``mark2`` marked and ``target`` as head; the ``mark1`` element is then
    added to each returned example. Examples sharing the same
    ``(mark1 value, mark2 value)`` pair are merged by ``aggregate_examples``.

    Args:
        ex: example dict with ``triples``; its ``subject``, ``object`` and
            ``predicate`` entries are reset while processing.
        mark1: ``'subject'``, ``'object'`` or ``'predicate'``.
        mark2: ``'subject'``, ``'object'`` or ``'predicate'``.
        target: element name passed on to ``mark_level2`` as the head.

    Returns:
        The list produced by ``aggregate_examples``.

    Raises:
        RuntimeError: if ``mark1`` or ``mark2`` is not a supported name.
    """
    # Examples grouped by the pair of marked values.
    new_examples = defaultdict(list)
    for trp in ex['triples']:
        ex['subject'], ex['object'], ex['predicate'] = [], [], []
        # Triples may carry a fourth element (extra arguments), ignored here.
        if len(trp)<=3:
            subject, predicate, object = trp
        else:
            subject, predicate, object, _ = trp
        new_ex_arr = mark_level2(ex, [trp], mark2, target)
        for new_ex in new_ex_arr:
            if mark1 == 'subject':
                new_ex['subject'].append(subject)
                key1 = subject
            elif mark1 == 'object':
                # NOTE(review): the object is stored as a plain string here,
                # while subject and predicate are appended to lists - confirm
                # that consumers expect this difference.
                new_ex['object'] = object
                key1 = object
            elif mark1 == 'predicate':
                new_ex['predicate'].append(predicate)
                key1 = predicate
            else:
                raise RuntimeError(f'{mark1} is not allowed.')
            if mark2=='subject':
                key2 = subject
            elif mark2=='object':
                key2 = object
            elif mark2=='predicate':
                key2 = predicate
            else:
                raise RuntimeError(f'{mark2} is not allowed.')
            new_examples[(key1,key2)].append(new_ex)

    # Merge the targets of all examples that share the same marked pair.
    aggregated_examples = aggregate_examples(new_examples)
    return aggregated_examples


def mark_level2(ex, triples,mark, target):
    """Create examples with one marked element and ``target`` as head.

    For every triple the ``subject``/``object``/``predicate`` entries of
    ``ex`` are reset, the ``mark`` element of the triple is appended to its
    entry and ``mark_level1`` builds the example for that single triple.
    Repeated ``(marked value, target value)`` pairs are skipped; examples
    sharing the same marked value are merged by ``aggregate_examples``.

    Args:
        ex: example dict; modified while processing.
        triples: 3- or 4-element tuples (a fourth element is ignored).
        mark: ``'subject'``, ``'object'`` or ``'predicate'``.
        target: element name used as head by ``mark_level1``.

    Returns:
        The list produced by ``aggregate_examples``.

    Raises:
        RuntimeError: if ``mark`` is not a supported name.
    """
    emitted_pairs = set()
    grouped = defaultdict(list)
    for trp in triples:
        if len(trp) > 3:
            subj, pred, obj, _ = trp
        else:
            subj, pred, obj = trp
        ex['subject'], ex['object'], ex['predicate'] = [], [], []
        if mark not in ('subject', 'object', 'predicate'):
            raise RuntimeError()
        marked_value = {'subject': subj, 'object': obj, 'predicate': pred}[mark]
        ex[mark].append(marked_value)
        candidate = mark_level1(ex, [trp], target)
        pair = (candidate[mark][0], candidate['targets'][0])
        if pair not in emitted_pairs:
            grouped[marked_value].append(candidate)
            emitted_pairs.add(pair)
    # Merge the targets of all examples that share the same marked value.
    return aggregate_examples(grouped)


def mark_level1(ex,triples,head):
    """Create an example whose targets are the ``head`` element of every triple.

    Args:
        ex: template example; it is shallow-copied, not modified.
        triples: tuples ordered as (subject, predicate, object, ...).
        head: ``'subject'``, ``'predicate'`` or ``'object'``.

    Returns:
        Copy of ``ex`` with ``triples`` set to ``None``, ``targets`` set to
        the selected element of each triple and ``head`` set to ``head``.

    Raises:
        RuntimeError: if ``head`` is not a supported name.
    """
    slot_order = ('subject', 'predicate', 'object')
    if head not in slot_order:
        raise RuntimeError()
    pos = slot_order.index(head)
    marked = ex.copy()
    marked.update(triples=None, targets=[trp[pos] for trp in triples], head=head)
    #marked['cls_token'] = -1
    return marked

def aggregate_examples(new_examples):
    """Collapse every group of examples into its first member.

    Args:
        new_examples: mapping from a group key to a non-empty list of example
            dicts.

    Returns:
        List with the first example of each group. For groups with more than
        one member, that first example's ``targets`` is replaced (in place)
        by the concatenation of the targets of all members.
    """
    merged = []
    for group in new_examples.values():
        representative = group[0]
        if len(group) > 1:
            combined = []
            for member in group:
                combined.extend(member['targets'])
            representative['targets'] = combined
        merged.append(representative)
    return merged

def invert_triples(ex):
    """Swap subject and object of every triple of ``ex`` in place.

    Each triple is replaced by the 3-tuple ``(object, predicate, subject)``;
    any further elements of the original triple are dropped.

    NOTE: a function with the same name is defined again further down in
    this module.

    Returns:
        The same ``ex`` object.
    """
    # The earlier 4-tuple variant was a dead store, overwritten right away.
    ex['triples'] = [(trp[2], trp[1], trp[0]) for trp in ex['triples']]
    return ex

#def randomize_predicate(ex):
    #corrupted = [(trp[0],pick_predicate(ex['sentence']), trp[2],[]) for trp in ex['triples']]
    #corrupted = [(trp[0],pick_predicate(ex['sentence']), trp[2]) for trp in ex['triples']]
    #ex['triples'] = corrupted
    #return ex
def randomize_predicate(ex):
    """Replace the predicate of every triple by a random span of the sentence.

    The span is drawn with ``pick_predicate``; the triples of ``ex`` are
    replaced in place by 3-tuples and the same ``ex`` object is returned.

    NOTE: a function with the same name is defined again further down in
    this module.
    """
    replaced = []
    for trp in ex['triples']:
        replaced.append((trp[0], pick_predicate(ex['sentence']), trp[2]))
    ex['triples'] = replaced
    return ex

def pick_predicate(sent):
    """Return a random contiguous span of space-separated tokens of ``sent``.

    The span has between 1 and 5 tokens and is always shorter than the
    sentence, except for single-token sentences, which are returned as they
    are. Every start position that leaves room for the span can be drawn,
    including the one that ends on the last token.

    Args:
        sent: sentence whose tokens are separated by single spaces.

    Returns:
        The selected tokens joined by single spaces.
    """
    tokens = sent.split(' ')
    num_tokens = len(tokens)
    if num_tokens < 2:
        # randint(1, 1) would raise: there is nothing to choose from.
        return tokens[0]
    # 1 choose length
    pred_len = np.random.randint(1, min(6, num_tokens))
    # 2 choose start; the upper bound of randint is exclusive, hence the +1
    pred_start = np.random.randint(0, num_tokens - pred_len + 1)
    return ' '.join(tokens[pred_start:pred_start + pred_len])

def gen_corrupt_negatives(ex):
    """Generate one negative example by corrupting the triples of ``ex``.

    With more than one triple, six corruption schemes are available;
    otherwise only the first three are:

    * 0/1 - invert subject and object, then mark subject (head object) or
      object (head predicate);
    * 2 - randomize the predicate, mark it, head subject;
    * 3/4/5 - ``corrupt_triples`` on the (subject, object),
      (subject, predicate) or (predicate, object) positions.

    Returns:
        A single-element list with the negative example, whose ``targets``
        is ``['']``.
    """
    if len(ex['triples']) > 1:
        choice = np.random.choice(list(range(0, 6)), p=[1 / 12, 3 / 12, 2 / 12, 1 / 12, 3 / 12, 2 / 12])
    else:
        choice = np.random.choice(list(range(0, 3)), p=[1 / 6, 2 / 3, 1 / 6])
    choice = int(choice)

    # choice -> (corruption function, marked element, head)
    single_element = {
        0: (invert_triples, 'subject', 'object'),
        1: (invert_triples, 'object', 'predicate'),
        2: (randomize_predicate, 'predicate', 'subject'),
    }
    # choice -> (first position, second position, head)
    element_pairs = {
        3: (0, 2, 'predicate'),
        4: (0, 1, 'object'),
        5: (1, 2, 'subject'),
    }
    if choice in single_element:
        corrupt, mark, head = single_element[choice]
        ex = corrupt(ex)
        candidates = mark_level2(ex, ex['triples'], mark, head)
        np.random.shuffle(candidates)
        new_ex = candidates[0]
    elif choice in element_pairs:
        new_ex = corrupt_triples(ex, *element_pairs[choice])
    else:
        raise RuntimeError()
    if new_ex is None:
        return []
    # new_ex['cls_token'] = 0
    new_ex['targets'] = ['']
    return [new_ex]


def corrupt_triples(ex, arg0_ind, arg1_ind, target_head):
    """Build a negative example by pairing elements of two different triples.

    The triples of ``ex`` are shuffled in place. The ``arg0_ind`` element of
    the first triple is paired with the ``arg1_ind`` element of one of the
    remaining triples; when that pair also occurs together in some triple,
    the two values are swapped.

    NOTE: a function with the same name is defined again further down in
    this module.

    Args:
        ex: example dict with at least two ``triples``.
        arg0_ind: triple position (0 subject, 1 predicate, 2 object).
        arg1_ind: triple position (0 subject, 1 predicate, 2 object).
        target_head: value stored under ``head``.

    Returns:
        Shallow copy of ``ex`` carrying the two selected values, empty lists
        for the remaining element, ``targets`` ``['']`` and ``triples`` ``None``.
    """
    slot_names = {0: 'subject', 1: 'predicate', 2: 'object'}

    def as_list(value):
        return list(value) if isinstance(value, tuple) else [value]

    # sentence = ex['sentence']
    known_pairs = {(trp[arg0_ind], trp[arg1_ind]) for trp in ex['triples']}
    np.random.shuffle(ex['triples'])
    arg0 = ex['triples'][0][arg0_ind]
    others = ex['triples'][1:]
    if len(others) > 1:
        np.random.shuffle(others)
    arg1 = others[0][arg1_ind]
    if (arg0, arg1) in known_pairs:
        arg0, arg1 = arg1, arg0
    new_ex = ex.copy()
    new_ex['subject'], new_ex['object'], new_ex['predicate'] = [], [], []
    new_ex[slot_names.get(arg0_ind)] = as_list(arg0)
    new_ex[slot_names.get(arg1_ind)] = as_list(arg1)
    new_ex.update(targets=[''], head=target_head, triples=None)
    # new_ex['cls_token'] = 0
    return new_ex


def gen_rand_negatives(ex):
    """Generate negative examples by randomizing elements of the triples.

    The triples of ``ex`` are first cut to three elements. One of six schemes
    is drawn: randomize a single element and mark it (schemes 0-2), or
    randomize two elements and mark both (schemes 3-5). The marked examples
    are shuffled, their ``targets`` is set to the empty string and their
    elements are post-processed by ``remove_overlap``.

    Returns:
        List of negative example dicts (may be empty).
    """
    ex['triples'] = [trp[:3] for trp in ex['triples']]
    choice = np.random.choice(list(range(0, 6)), p=[1/12, 3/12, 2/12, 2/12, 2/12, 2/12])
    # choice -> (corruption functions applied in order, marking arguments)
    schemes = {
        0: ((randomize_subject,), ('subject', 'object')),
        1: ((randomize_object,), ('object', 'predicate')),
        2: ((randomize_predicate,), ('predicate', 'subject')),
        3: ((randomize_subject, randomize_object), ('subject', 'object', 'predicate')),
        4: ((randomize_subject, randomize_predicate), ('subject', 'predicate', 'object')),
        5: ((randomize_predicate, randomize_object), ('predicate', 'object', 'subject')),
    }
    if int(choice) not in schemes:
        raise RuntimeError()
    corruptors, roles = schemes[int(choice)]
    for corrupt in corruptors:
        ex = corrupt(ex)
    if len(roles) == 2:
        # One marked element plus head.
        new_ex = mark_level2(ex, ex['triples'], *roles)
    else:
        # Two marked elements plus head.
        new_ex = mark_level3(ex, *roles)
    np.random.shuffle(new_ex)
    if new_ex is None:
        return []
    #new_ex['cls_token'] = 0
    for x in new_ex:
        # NOTE(review): an empty string is stored here, while
        # gen_corrupt_negatives stores [''] - confirm which form consumers expect.
        x['targets'] = ''
        x['subject'], x['predicate'], x['object'] = remove_overlap(x)
    return new_ex

def remove_overlap(ex):
    """Remove tokens shared between the marked elements of an example.

    The tokens of the first subject are recorded as seen. The first predicate
    and then the first object are stripped of every token already seen
    (including repetitions inside the element itself). When an element ends
    up empty, random spans are drawn from the sentence with ``pick_random``
    until one is found that has not been seen yet.

    Args:
        ex: dict with ``sentence``, ``subject``, ``predicate`` and ``object``.

    Returns:
        Tuple ``(subjects, predicates, objects)``; the subjects are returned
        unchanged, non-empty predicates/objects as single-element lists.
    """
    seen = set()

    def strip_seen(elem):
        kept = []
        for tok in elem.split(" "):
            if tok not in seen:
                kept.append(tok)
            seen.add(tok)
        # Nothing left: fall back to a random, not yet seen span.
        while not kept:
            candidate = pick_random(ex['sentence'], 1, 0, len(ex['sentence'].split(' ')))
            if candidate not in seen:
                kept = [candidate]
                seen.add(candidate)
        return ' '.join(kept)

    subjects, predicates, objects = ex['subject'], ex['predicate'], ex['object']
    if len(subjects) > 0:
        seen.update(subjects[0].split(" "))
    if len(predicates) > 0:
        predicates = [strip_seen(predicates[0])]
    if len(objects) > 0:
        objects = [strip_seen(objects[0])]
    return subjects, predicates, objects

def randomize_predicate(ex):
    """Replace the predicate of every triple by a random span of the sentence.

    For each triple a span is drawn with ``pick_random``, using the token
    count of the original predicate and the window returned by ``get_range``
    for it. Corrupted triples that coincide with an original triple are
    discarded. ``ex['triples']`` is replaced in place.

    Returns:
        The same ``ex`` object.
    """
    sentence = ex['sentence']
    # Built once; it used to be rebuilt for every corrupted candidate.
    originals = set(ex['triples'])
    corrupted = []
    for trp in ex['triples']:
        span = pick_random(sentence, len(trp[1].split(' ')), *get_range(sentence, trp[1]))
        candidate = (trp[0], span, trp[2])
        if candidate not in originals:
            corrupted.append(candidate)
    ex['triples'] = corrupted
    return ex

def randomize_subject(ex):
    """Replace the subject of every triple by a random span of the sentence.

    For each triple a span is drawn with ``pick_random``, using the token
    count of the original subject and the window returned by ``get_range``
    for it. Corrupted triples that coincide with an original triple are
    discarded. ``ex['triples']`` is replaced in place.

    Returns:
        The same ``ex`` object.
    """
    sentence = ex['sentence']
    # Built once; it used to be rebuilt for every corrupted candidate.
    originals = set(ex['triples'])
    corrupted = []
    for trp in ex['triples']:
        span = pick_random(sentence, len(trp[0].split(' ')), *get_range(sentence, trp[0]))
        candidate = (span, trp[1], trp[2])
        if candidate not in originals:
            corrupted.append(candidate)
    ex['triples'] = corrupted
    return ex

def randomize_object(ex):
    """Replace the object of every triple by a random span of the sentence.

    For each triple a span is drawn with ``pick_random``, using the token
    count of the original object and the window returned by ``get_range``
    for it. Corrupted triples that coincide with an original triple are
    discarded. ``ex['triples']`` is replaced in place.

    Returns:
        The same ``ex`` object.
    """
    sentence = ex['sentence']
    # Built once; it used to be rebuilt for every corrupted candidate.
    originals = set(ex['triples'])
    corrupted = []
    for trp in ex['triples']:
        span = pick_random(sentence, len(trp[2].split(' ')), *get_range(sentence, trp[2]))
        candidate = (trp[0], trp[1], span)
        if candidate not in originals:
            corrupted.append(candidate)
    ex['triples'] = corrupted
    return ex

def pick_random(sent, pred_len, lower, upper):
    """Return a random span of space-separated tokens of ``sent``.

    Args:
        sent: sentence whose tokens are separated by single spaces.
        pred_len: reference length; the span length is drawn from
            ``[max(1, pred_len - 3), pred_len + 3)`` and capped at 5.
        lower: inclusive lower bound for the start token index.
        upper: exclusive upper bound for the start token index.

    Returns:
        The selected tokens joined by single spaces. The span is shorter
        than the drawn length when it would run past the end of the sentence.
    """
    tokens = sent.split(' ')
    # 1 choose length
    shortest = max(1, pred_len - 3)
    span_len = min(np.random.randint(shortest, pred_len + 3), 5)
    # 2 choose start
    start = np.random.randint(lower, upper)
    return ' '.join(tokens[start:start + span_len])

def get_range(sent, elem):
    """Return a token window around the first occurrence of ``elem`` in ``sent``.

    Both strings are split on single spaces. The window extends 7 tokens to
    each side of the first exact token-level match and is clipped to the
    sentence. When ``elem`` does not occur, the whole sentence is returned.

    Args:
        sent: the sentence.
        elem: the element to locate.

    Returns:
        Tuple ``(lower, upper)`` of token indices, ``upper`` exclusive.
    """
    tokens = sent.split(' ')
    elem_tokens = elem.split(' ')
    width = len(elem_tokens)
    # Only start positions that leave room for the whole element are tried,
    # so a partial match at the end of the sentence cannot index past it.
    for start in range(len(tokens) - width + 1):
        if tokens[start:start + width] == elem_tokens:
            return max(0, start - 7), min(start + width + 7, len(tokens))
    return 0, len(tokens)

def corrupt_triples(ex,arg0_ind, arg1_ind, target_head):
    """Build a negative example by pairing elements of two different triples.

    The triples of ``ex`` are shuffled in place. The ``arg0_ind`` element of
    the first triple is paired with the ``arg1_ind`` element of one of the
    remaining triples; when that pair also occurs together in some triple,
    the two values are swapped.

    NOTE: a function with the same name is also defined earlier in this
    module; this later definition is the one bound after import.

    Args:
        ex: example dict with at least two ``triples``.
        arg0_ind: triple position (0 subject, 1 predicate, 2 object).
        arg1_ind: triple position (0 subject, 1 predicate, 2 object).
        target_head: value stored under ``head``.

    Returns:
        Shallow copy of ``ex`` carrying the two selected values, empty lists
        for the remaining element, ``targets`` ``['']`` and ``triples`` ``None``.
    """
    position_names = {0: 'subject', 1: 'predicate', 2: 'object'}

    def wrap(value):
        # Tuples are expanded into a list, anything else becomes a singleton.
        if isinstance(value, tuple):
            return list(value)
        return [value]

    #sentence = ex['sentence']
    positive_pairs = {(trp[arg0_ind], trp[arg1_ind]) for trp in ex['triples']}
    np.random.shuffle(ex['triples'])
    arg0 = ex['triples'][0][arg0_ind]
    rest = ex['triples'][1:]
    if len(rest) > 1:
        np.random.shuffle(rest)
    arg1 = rest[0][arg1_ind]
    if (arg0, arg1) in positive_pairs:
        arg0, arg1 = arg1, arg0
    new_ex = ex.copy()
    for field in ('subject', 'object', 'predicate'):
        new_ex[field] = []
    new_ex[position_names.get(arg0_ind)] = wrap(arg0)
    new_ex[position_names.get(arg1_ind)] = wrap(arg1)
    new_ex['targets'] = ['']
    new_ex['head'] = target_head
    new_ex['triples'] = None
    #new_ex['cls_token'] = 0
    return new_ex


def extract_phrases(text):
    """Extract verb phrases and noun chunks for every sentence.

    Verb phrases are the spans matched by a fixed set of POS patterns
    (VERB or AUX alone, preceded by ADV, followed by ADV or ADP, and
    AUX followed by VERB); noun phrases are the documents' ``noun_chunks``.
    The spaCy pipeline ``es_core_news_sm`` is used.

    Args:
        text: the sentences to process.

    Returns:
        Tuple ``(vp_map, np_map)`` of dicts mapping every sentence to the list
        of its distinct verb phrases / noun phrases.
    """
    patterns = []
    for head in ('VERB', 'AUX'):
        patterns.extend([
            [{'POS': head}],
            [{'POS': 'ADV'}, {'POS': head}],
            [{'POS': head}, {'POS': 'ADV'}],
            [{'POS': head}, {'POS': 'ADP'}],
        ])
    patterns.append([{'POS': 'AUX'}, {'POS': 'VERB'}])

    def _ext_verb_phrases(nlp, matcher, sentences):
        for doc in nlp.pipe(sentences, batch_size=10000):
            matched = [doc[start:end].text for _, start, end in matcher(doc)]
            chunks = [chunk.text for chunk in doc.noun_chunks]
            yield matched, chunks

    nlp = spacy.load('es_core_news_sm', disable=["lemmatizer", 'ner'])
    matcher = Matcher(nlp.vocab)
    for idx, pat in enumerate(patterns):
        matcher.add(str(idx), [pat])
    print(f"Num. Sentences {len(text)}")
    phrase_stream = _ext_verb_phrases(nlp, matcher, text)
    verb_phrases, noun_phrases = [], []
    for vps, nps in tqdm(phrase_stream, desc='Extracting VPs and NPs', total=len(text)):
        # Deduplicate the phrases of each sentence.
        verb_phrases.append(list(set(vps)))
        noun_phrases.append(list(set(nps)))
    assert len(verb_phrases) == len(text)
    vp_map = {sent: verb_phrases[idx] for idx, sent in enumerate(text)}
    np_map = {sent: noun_phrases[idx] for idx, sent in enumerate(text)}
    return vp_map, np_map

def invert_triples(ex):
    """Swap subject and object of every triple of ``ex`` in place.

    Each triple is replaced by the 3-tuple ``(object, predicate, subject)``.

    Returns:
        The same ``ex`` object.
    """
    swapped = []
    for trp in ex['triples']:
        swapped.append((trp[2], trp[1], trp[0]))
    ex['triples'] = swapped
    return ex

def get_marker(index):
    """Return the marker for a triple position.

    Position 0 maps to ``'<A0>'``, position 1 to ``'<P>'`` and every other
    value to ``'<A1>'``.
    """
    return '<A0>' if index == 0 else ('<P>' if index == 1 else '<A1>')


def mark_sentence(sent, ent,marker):
    """Surround the first occurrence of ``ent`` in ``sent`` with ``marker``.

    The marker is inserted, padded with a space on each side, before and
    after the entity. When ``ent`` is a list or tuple, its entries are marked
    one after the other, each in the result of the previous step.

    Args:
        sent: the sentence (may be ``None``).
        ent: a string, or a list/tuple of strings.
        marker: the marker text, e.g. ``'<A0>'``.

    Returns:
        The marked sentence, or ``''`` as soon as the sentence is ``None`` or
        an entity is not contained in it. An empty list/tuple returns ``sent``
        unchanged.
    """
    def wrap_first(text, span):
        at = text.find(span)
        return text[:at] + f' {marker} ' + span + f' {marker} ' + text[at + len(span):]

    if isinstance(ent, (list, tuple)):
        for e in ent:
            if sent is None or e not in sent:
                return ''
            sent = wrap_first(sent, e)
        return sent
    if sent is None or ent not in sent:
        return ''
    return wrap_first(sent, ent)

def mark_sentence_regex(sent, ent, marker):
    """Surround the first occurrence of ``ent`` in ``sent`` with ``marker``.

    The entity is matched literally: characters that are special in regular
    expressions (``$``, ``(``, ``.``, ...) are escaped before searching.
    Previously the entity was used as a raw pattern, so such characters
    either changed the match or raised ``re.error``.

    Args:
        sent: the sentence (may be ``None``).
        ent: the entity text to mark.
        marker: the marker text, e.g. ``'<A1>'``.

    Returns:
        ``''`` when ``sent`` is ``None``, ``None`` when the entity does not
        occur, otherwise the sentence with the marker (padded by a space on
        each side) inserted before and after the first occurrence.
    """
    if sent is None:
        return ''
    match = re.search(re.escape(ent), sent)
    if match is None:
        return None
    start, end = match.span()
    return sent[:start] + f' {marker} ' + ent + f' {marker} ' + sent[end:]

def analyze_train_split():
    """Report duplicate statistics for the tuples of ``structured_data.json``.

    Per sentence, a tuple whose (subject, relation) pair was already seen
    prints ``Duplicate triples`` when its arguments were seen as well and is
    counted as a warning otherwise. The maximum number of arguments and the
    warning count are printed at the end. Examples with a missing or empty
    sentence are skipped.
    """
    train_file = 'C:\\data\\milie\\en\\structured_data.json'
    max_args = 0
    num_warning = 0
    with open(train_file, encoding='utf8') as f:
        data = json.load(f)
    for ex in tqdm(data):
        if ex['sentence'] is None or len(ex['sentence']) <= 0:
            continue
        seen_pairs, seen_args = set(), set()
        for trp in ex['tuples']:
            pair = (trp['arg0'], trp['relation'])
            args = trp['args']
            arg_key = tuple(args)
            max_args = max(max_args, len(args))
            if pair in seen_pairs:
                if arg_key in seen_args:
                    print("Duplicate triples")
                else:
                    num_warning += 1
            else:
                seen_pairs.add(pair)
            seen_args.add(arg_key)
    print(f"Maximum Args: {max_args}, Num Warnings {num_warning}")

def analyze_valid_split():
    """Check ``CaRB_dev.tsv`` for repeated (sentence, subject, predicate) keys.

    Every line is tab-separated: sentence, predicate, subject, followed by
    any number of further columns. ``warn`` is printed for lines with more
    than five columns and ``Warning`` whenever a
    (sentence, subject, predicate) combination occurs again.
    """
    valid_file = 'C:\\data\\milie\\en\\CaRB_dev.tsv'
    duplicates = dict()
    # Explicit encoding, like the other readers in this module, instead of
    # the platform default.
    with open(valid_file, 'r', encoding='utf8') as f:
        for line in f:
            parts = line.strip().split('\t')
            sent, p, s = parts[0], parts[1], parts[2]
            # Everything after the third column (possibly nothing).
            objects = parts[3:]
            if len(parts) > 5:
                print('warn')
            key = (sent, s, p)
            if key in duplicates:
                print("Warning")
            else:
                duplicates[key] = objects

def format_meta_training_data():
    """Convert ``train_meta.pkl`` into JSON train and dev files.

    Every example gets an integer ``label``: the argmax of its ``labels``
    scores, or 2 when there is at most one score; ``labels`` is then cleared.
    The examples are shuffled; the first 2000 are written to
    ``dev_meta.json`` and the remaining ones to ``train_meta.json``.
    """
    #data_dir = '/home/bkotnis/local/data/milie/en/'
    data_dir = 'C:\\data\\milie\\en\\'
    with open(os.path.join(data_dir, 'train_meta.pkl'), 'rb') as f:
        raw_data = pickle.load(f)
    data, label_set = [], set()
    for ex in tqdm(raw_data):
        scores = ex['labels']
        label = int(np.argmax(scores)) if len(scores) > 1 else 2
        label_set.add(label)
        ex['label'] = label
        ex['labels'] = None
        data.append(ex)
    np.random.shuffle(data)
    # Written in this order: train file first, dev file second.
    splits = (
        ('train_meta.json', data[2000:]),
        ('dev_meta.json', data[:2000]),
    )
    for out_name, subset in splits:
        with open(os.path.join(data_dir, out_name), 'w', encoding='utf8') as f:
            json.dump(subset, f, indent=None, separators=(', \n', ': '), ensure_ascii=False)


def create_translation_splits():
    """Split ``CaRB_test.tsv`` into numbered TSV samples for translation.

    The extractions are grouped by sentence (first column). Every output
    line holds the sentence number, the sentence and the remaining columns of
    one extraction; each file starts with a header line. A file is written
    whenever the sentence number is a positive multiple of 100, and the
    remaining buffer goes into a final file.
    """
    data_dir = 'C:\\data\\milie\\en\\'
    write_data_dir = 'C:\\data\\milie\\spanish_splits\\'
    file_name = 'CaRB_test.tsv'
    data = defaultdict(list)
    with open(os.path.join(data_dir, file_name), encoding='utf8') as f:
        for line in f:
            sent, *triple = line.strip().split('\t')
            data[sent].append(triple)
    header_line = "Num.\tSentence\tPredicate\tSubject\tObject\tArgument 1\tArgument 2\tArgument 3\tArgument 4"

    def flush(lines, index):
        out_path = os.path.join(write_data_dir, f'CaRB_test_sample_{index}.tsv')
        with open(out_path, 'w', encoding='utf8') as f:
            f.write('\n'.join(lines))

    write_buff, sample_num = [header_line], 0
    for count, (sent, triples) in enumerate(data.items()):
        for trp in triples:
            write_buff.append(f"{count}\t{sent}\t" + '\t'.join(trp))
        if count > 0 and count % 100 == 0:
            flush(write_buff, sample_num)
            write_buff = [header_line]
            sample_num += 1
    # Whatever is left (at least the header) goes into the last sample.
    flush(write_buff, sample_num)

if __name__ == '__main__':
    # Script entry point: only format_train() is enabled. The other steps are
    # toggled by hand; some of the commented-out names (inject_preds,
    # process_pt_rel_extraction, create_benchie_test,
    # format_benchie_extraction) are not defined in the visible part of this file.
    #create_translation_splits()
    #inject_predictions()
    #train_statistics()
    format_train()
    #format_dev()
    #readable_spanish()
    #format_multilingual()
    #analyze_valid_split()
    #inject_preds()
    #process_pt_rel_extraction()
    #create_benchie_test()
    #inject_preds()
    #format_benchie_extraction()
    #format_meta_training_data()