import ipdb
import json
import argparse
import random
import pickle
import os

def get_pos(tag):
    """Locate the token spans of the semantic roles in a tag sequence.

    A role's span is reported as ``[first_index, last_index]`` (both
    inclusive) only when every tag from its first to its last occurrence is
    that same role, i.e. the role forms one contiguous run.

    Args:
        tag: per-token role labels (e.g. ``'ARG1'``, ``'REL'``, ``'ARG2'``).

    Returns:
        ``[arg1_span, rel_span, extra_spans]``. ``arg1_span`` and ``rel_span``
        are ``[start, end]``, or ``[]`` when the role is absent or not
        contiguous. ``extra_spans`` lists the spans of ``ARG2``, ``LOC`` and
        ``TIME`` in that order, skipping any that are absent or not contiguous.
    """
    def contiguous_span(role):
        # [start, end] of `role` if it occupies one unbroken run, else None.
        if role not in tag:
            return None
        start = tag.index(role)
        stop = len(tag) - tag[::-1].index(role)  # one past the last occurrence
        if set(tag[start:stop]) != {role}:
            return None
        return [start, stop - 1]

    head = [contiguous_span(role) or [] for role in ('ARG1', 'REL')]
    extras = [
        span
        for span in map(contiguous_span, ('ARG2', 'LOC', 'TIME'))
        if span is not None
    ]
    return head + [extras]

def helper(sent, tag, x):
    """Return the space-joined tokens of ``sent`` whose tag equals ``x``.

    Tokens are stripped before joining, and tokens that are blank after
    stripping are dropped. Indexing follows ``tag``, so ``sent`` must be at
    least as long as every matching position in ``tag``.

    Args:
        sent: list of tokens.
        tag: list of per-token labels aligned with ``sent``.
        x: the label to select.

    Returns:
        The selected tokens joined by single spaces, or ``""`` if none match.
    """
    picked = (sent[i].strip() for i, label in enumerate(tag) if label == x)
    return " ".join(token for token in picked if token)
def get_tabbed(labels, sentence=None):
    """Convert tagged extraction lines into tab-separated gold tuples.

    Each entry of ``labels`` has the form ``"<tokens> ||| <tags>"``, where the
    two whitespace-separated sequences must have equal length. Every label
    that contains at least one ``REL`` token yields one line:
    ``sentence \\t rel \\t arg1 [\\t arg2] [\\t loc] [\\t time]``. Empty
    optional columns are omitted. The line is stripped, so an empty trailing
    ``arg1`` column disappears.

    Args:
        labels: tagged extraction strings for a single sentence.
        sentence: optional text returned when ``labels`` is empty.
            Defaults to ``None``, in which case ``""`` is returned for
            empty input.

    Returns:
        A list of tab-separated tuple strings. If no label yields a relation,
        a plain string is returned instead: the last label's tokens joined by
        spaces, or ``sentence`` (or ``""``) when ``labels`` is empty.
    """
    res = []
    sent = None
    for label in labels:
        sent, tag = label.split('|||')
        sent = sent.strip().split()
        tag = tag.strip().split()
        assert len(sent) == len(tag), ipdb.set_trace()
        rel = helper(sent, tag, 'REL')
        if rel == "":
            # Without a relation there is nothing to emit for this label.
            continue
        arg1 = helper(sent, tag, 'ARG1')
        # ARG2 / LOC / TIME are optional columns: only non-empty ones are kept.
        optional = [helper(sent, tag, role) for role in ('ARG2', 'LOC', 'TIME')]
        fields = [" ".join(sent).strip(), rel, arg1]
        fields.extend(value for value in optional if value != "")
        res.append("\t".join(fields).strip())
    if res:
        return res
    if sent is not None:
        # No label produced a relation: fall back to the last label's sentence.
        return " ".join(sent)
    # `labels` was empty (e.g. a count of 0). Previously this path raised
    # UnboundLocalError because `sent` was never bound.
    return sentence if sentence is not None else ""

def update_data(sentence, labels):
    """Build the structured record for one sentence from its tagged labels.

    Args:
        sentence: the raw sentence text. It is kept only when ``labels`` is
            empty; otherwise it is replaced by the last label's tokens joined
            by single spaces.
        labels: strings of the form ``"<tokens> ||| <tags>"`` whose token and
            tag sequences have equal length.

    Returns:
        ``{'sentence': str, 'tuples': [...]}``. There is one tuple dict per
        label, with keys ``arg0_pos`` (the ARG1 span), ``rel_pos`` (the REL
        span) and ``args_pos`` (the ARG2/LOC/TIME spans), as computed by
        ``get_pos``. A missing ARG1 or REL span is recorded as ``(-1, -1)``.
    """
    text = sentence
    tuples = []
    for label in labels:
        raw_words, raw_tags = label.split('|||')
        words = raw_words.strip().split()
        tags = raw_tags.strip().split()
        assert len(words) == len(tags), ipdb.set_trace()
        arg1_span, rel_span, extra_spans = get_pos(tags)
        text = " ".join(words)
        tuples.append({
            'arg0_pos': arg1_span if len(arg1_span) > 0 else (-1, -1),
            'rel_pos': rel_span if len(rel_span) > 0 else (-1, -1),
            'args_pos': extra_spans,
        })
    return {'sentence': text, 'tuples': tuples}

if __name__ == '__main__':
    # Build structured training data (JSON) plus a CaRB-style dev split
    # (carb_dev.tsv + dev_sentences.pkl) from tagged extractions.
    parser = argparse.ArgumentParser('create data')
    parser.add_argument('--fp1', type=str, help='sentences file (one sentence per line)')
    parser.add_argument('--fp2', type=str, help='tagged extractions file ("tokens ||| tags" per line)')
    parser.add_argument('--fp3', type=str, help='extraction counts file (count in the first tab-separated column)')
    parser.add_argument('--lang', type=str, required=True, help='language code, e.g. "en"')
    parser.add_argument('--out', type=str, help='output directory')
    parser.add_argument('--dev_size', type=int, default=1000,
                        help='number of shuffled sentences held out for the dev split')
    parser.add_argument('--seed', type=int, default=0, help='random seed for the train/dev shuffle')
    args = parser.parse_args()

    if args.lang == 'en':
        defaults = {
            'fp1': f"../data/{args.lang}/train.sentences",
            'fp2': "../data/openie6/train.sentences_labels",
            'fp3': "../data/openie6/train.count_extractions",
            'out': f"../data/{args.lang}/multi2oie/",
        }
    else:
        defaults = {
            'fp1': f"../data/{args.lang}/mbart/train.sentences",
            'fp2': f"../data/{args.lang}/clp/aligned.extractions",
            'fp3': "../data/openie6/train.count_extractions",
            'out': f"../data/{args.lang}/multi2oie/",
        }
    # Only fill in paths the user did not pass. Previously these were
    # overwritten unconditionally, silently ignoring --fp1/--fp2/--fp3/--out.
    for name, path in defaults.items():
        if getattr(args, name) is None:
            setattr(args, name, path)
    os.makedirs(args.out, exist_ok=True)

    # Explicit UTF-8 so non-English text does not depend on the platform's
    # locale encoding. NOTE(review): assumes the input files are UTF-8.
    with open(args.fp1, 'r', encoding='utf-8') as f1, \
            open(args.fp2, 'r', encoding='utf-8') as f2, \
            open(args.fp3, 'r', encoding='utf-8') as f3:
        data1 = f1.readlines()
        data2 = f2.readlines()
        data3 = f3.readlines()

    # Real errors instead of asserts: asserts vanish under `python -O`, which
    # would let misaligned files produce silently corrupted data.
    if len(data1) != len(data3):
        raise ValueError(
            f"{args.fp1} has {len(data1)} lines but {args.fp3} has {len(data3)} lines")
    data, gold_data = [], []
    index = 0
    for sentence_line, count_line in zip(data1, data3):
        # Each sentence owns the next `count` consecutive lines of data2.
        count = int(count_line.split('\t')[0].strip())
        labels = data2[index:index + count]
        data.append(update_data(sentence_line.strip(), labels))
        gold_data.append(get_tabbed(labels))
        index += count
    if index != len(data2):
        raise ValueError(
            f"counts in {args.fp3} sum to {index} but {args.fp2} has {len(data2)} lines")

    whole_data = list(zip(data, gold_data))
    random.seed(args.seed)
    random.shuffle(whole_data)
    # The first `dev_size` shuffled sentences form the dev split (gold tuples
    # only); the rest form the training split (structured records only).
    valid_data = [gold for _, gold in whole_data[:args.dev_size]]
    train_data = [structured for structured, _ in whole_data[args.dev_size:]]

    with open(os.path.join(args.out, f'structured_data_{args.lang}.json'), 'w',
              encoding='utf-8') as outfile:
        json.dump(train_data, outfile)

    dev_sentences = []
    with open(os.path.join(args.out, 'carb_dev.tsv'), 'w', encoding='utf-8') as dev_file:
        for gold in valid_data:
            if isinstance(gold, list):
                # Tuple lines: the sentence is the first tab-separated column.
                dev_sentences.append(gold[0].split('\t')[0].strip())
                for line in gold:
                    dev_file.write(line.strip() + '\n')
            else:
                # get_tabbed returned only the sentence (no usable relation).
                dev_sentences.append(gold.strip())
    with open(os.path.join(args.out, 'dev_sentences.pkl'), 'wb') as pkl_file:
        pickle.dump(dev_sentences, pkl_file)
                
    