import json
import random
from tqdm import tqdm

import sys
sys.path.append("..")


def _dedupe_arguments(records):
    """Return *records* with duplicates removed, keeping first-occurrence order.

    Each record is serialized with sorted keys so that equal records share one
    canonical string. ``dict.fromkeys`` (unlike a ``set``) preserves insertion
    order, so the output order does not depend on per-process string hash
    randomization. The JSON round-trip leaves every dict with sorted keys.
    """
    canonical = dict.fromkeys(json.dumps(record, sort_keys=True) for record in records)
    return [json.loads(text) for text in canonical]


def _train_event_arguments(event, entity_mentions):
    """Build the argument records for one event in the train/valid layout.

    Args:
        event: event dict whose ``'argument'`` maps a role name to a list of
            mentions; each mention either carries ``content``/``offset``
            itself or refers to an entity through ``entity_id``.
        entity_mentions: maps an entity id to the list of dicts (with
            ``mention`` and ``offset`` keys) stored under that entity's
            ``"mention"`` field.

    Returns:
        De-duplicated ``{"type": role, "mentions": [...]}`` records. Roles
        that end up with no mentions are omitted.
    """
    records = []
    for role, role_mentions in event['argument'].items():
        spans = []
        for mention in role_mentions:
            if "content" in mention:
                # Inline argument: the text span is stored on the mention itself.
                spans.append({"mention": mention["content"], "offset": mention["offset"]})
            else:
                # Entity reference: expand to every mention of the referenced entity.
                spans.extend(
                    {"mention": ent["mention"], "offset": ent["offset"]}
                    for ent in entity_mentions[mention['entity_id']]
                )
        if spans:
            records.append({"type": role, "mentions": spans})
    return _dedupe_arguments(records)


def _test_event_arguments(event):
    """Build the argument records for one event in the test-file layout.

    Args:
        event: event dict with ``'triggers'``; each trigger has
            ``'arguments'`` whose ``role`` is a dot-separated string and whose
            ``mentions`` carry ``mention``/``position`` fields.

    Returns:
        De-duplicated ``{"type": ..., "mentions": [...]}`` records, where the
        type is the second dot-separated component of the role.
    """
    records = []
    for trigger in event['triggers']:
        for argument in trigger["arguments"]:
            records.append({
                # NOTE(review): raises IndexError for a role without a "." — confirm all roles are dotted.
                "type": argument["role"].split(".")[1],
                # Unlike the train/valid layout, arguments with no mentions are kept here.
                "mentions": [
                    {"mention": m["mention"], "offset": m["position"]}
                    for m in argument["mentions"]
                ],
            })
    return _dedupe_arguments(records)


def extract_data(input_paths, output_path=None):
    """Collect per-event argument annotations from train/valid and test files.

    Args:
        input_paths: ``[train, valid, test]`` JSON Lines files. The first two
            use the ``document``/``entities``/``events[].argument`` layout; the
            third uses ``text``/``events[].triggers[].arguments``.
        output_path: optional path; when truthy, the result is written there
            as a single JSON object (non-ASCII characters kept as-is).

    Returns:
        ``{doc_id: {"document": <text>, <event_id>: [argument records], ...}}``.

    Raises:
        ValueError: if fewer than three input paths are given.
    """
    if len(input_paths) < 3:
        raise ValueError("input_paths must list the train, valid and test files")

    train_docs = []
    for path in input_paths[:2]:
        with open(path, 'r', encoding='utf8') as f:
            # Blank lines carry no record; skip them instead of crashing json.loads.
            train_docs.extend(json.loads(line) for line in f if line.strip())
    with open(input_paths[2], 'r', encoding='utf8') as f:
        test_docs = [json.loads(line) for line in f if line.strip()]

    arguments = {}
    for doc in tqdm(train_docs):
        entity_mentions = {entity['id']: entity["mention"] for entity in doc['entities']}
        doc_arguments = {"document": doc["document"]}
        for event in doc['events']:
            doc_arguments[event['id']] = _train_event_arguments(event, entity_mentions)
        arguments[doc['id']] = doc_arguments
    for doc in tqdm(test_docs):
        doc_arguments = {"document": doc["text"]}
        for event in doc['events']:
            doc_arguments[event['id']] = _test_event_arguments(event)
        arguments[doc['id']] = doc_arguments

    if output_path:
        with open(output_path, 'w', encoding='utf8') as f:
            json.dump(arguments, f, ensure_ascii=False)
    return arguments
    
    
def get_sample_data(input_path, output_path):
    """Pretty-print the first record of a JSON Lines file.

    Only the first line of ``input_path`` is read and parsed as JSON; the
    result is written to ``output_path`` with 4-space indentation and
    non-ASCII characters left unescaped.
    """
    with open(input_path, 'r', encoding='utf8') as src:
        first_line = src.readline()
    record = json.loads(first_line)

    pretty = json.dumps(record, ensure_ascii=False, indent=4)
    with open(output_path, 'w', encoding='utf8') as dst:
        dst.write(pretty)


def _attach_event_arguments(docs, arguments):
    """Attach stored event arguments to every document in *docs*, in place.

    Every document gets a ``has_arguments`` flag. A document whose id is in
    *arguments* takes its ``document`` field from there, and each of its
    events gets the stored argument list (``[]`` when the event id is absent).
    For documents not in *arguments*, every event gets ``[]``.
    """
    for doc in tqdm(docs):
        found = doc['id'] in arguments
        doc['has_arguments'] = found
        if found:
            doc_arguments = arguments[doc['id']]
            doc['document'] = doc_arguments['document']
        else:
            doc_arguments = {}
        for event in doc['events']:
            # A fresh [] is created per event, so no list is shared between events.
            event['arguments'] = doc_arguments.get(event['id'], [])


def merge_data(input_paths, output_paths):
    """Merge extracted event arguments into the train/valid/test splits.

    Args:
        input_paths: ``[train, valid, test, arguments]`` — three JSON Lines
            split files followed by a JSON file mapping document id to
            ``{"document": ..., <event_id>: [arguments], ...}``.
        output_paths: ``[train, valid, test]`` JSON Lines output files.

    Raises:
        ValueError: if fewer than four input paths or three output paths are
            given.
    """
    if len(input_paths) < 4 or len(output_paths) < 3:
        raise ValueError(
            "expected [train, valid, test, arguments] inputs and [train, valid, test] outputs"
        )

    split_docs = []
    for path in input_paths[:3]:
        with open(path, 'r', encoding='utf8') as f:
            # Blank lines carry no record; skip them instead of crashing json.loads.
            split_docs.append([json.loads(line) for line in f if line.strip()])
    with open(input_paths[3], 'r', encoding='utf8') as f:
        arguments = json.load(f)

    for docs, out_path in zip(split_docs, output_paths):
        _attach_event_arguments(docs, arguments)
        with open(out_path, 'w', encoding='utf8') as f:
            f.writelines(json.dumps(doc, ensure_ascii=False) + '\n' for doc in docs)

def _keep_evidence_mention(mention):
    """Return True if *mention* belongs in the evidence dataset.

    A mention is dropped when its factuality is "CT+" or "Uu", or when it has
    no evidence word. Otherwise it is kept only if the first element of every
    evidence offset equals the mention's ``sent_id``.
    """
    if mention['factuality'] in ("CT+", "Uu"):
        return False
    if mention["evidence_word"] is None:
        return False
    # all() over an empty offset list is True, matching the original flag loop.
    return all(offset[0] == mention['sent_id'] for offset in mention["evidence_offset"])


def _filter_evidence_docs(docs):
    """Return the documents of *docs* that keep at least one evidence mention.

    Each event's ``mention`` list is narrowed with ``_keep_evidence_mention``.
    Events left with no mentions are dropped, and documents left with no
    events are dropped. Surviving documents and events are mutated in place.
    """
    kept_docs = []
    for doc in tqdm(docs):
        kept_events = []
        for event in doc['events']:
            kept_mentions = [m for m in event['mention'] if _keep_evidence_mention(m)]
            if kept_mentions:
                event['mention'] = kept_mentions
                kept_events.append(event)
        if kept_events:
            doc['events'] = kept_events
            kept_docs.append(doc)
    return kept_docs


def get_evidence_data(input_paths, output_paths):
    """Write evidence-only versions of the train/valid/test splits.

    Args:
        input_paths: ``[train, valid, test]`` JSON Lines files.
        output_paths: ``[train, valid, test]`` JSON Lines output files, one
            per input, holding the documents kept by ``_filter_evidence_docs``.

    Raises:
        ValueError: if fewer than three input or output paths are given.
    """
    if len(input_paths) < 3 or len(output_paths) < 3:
        raise ValueError("expected [train, valid, test] input and output paths")

    # Read every split before writing anything, as the original did.
    split_docs = []
    for path in input_paths[:3]:
        with open(path, 'r', encoding='utf8') as f:
            # Blank lines carry no record; skip them instead of crashing json.loads.
            split_docs.append([json.loads(line) for line in f if line.strip()])

    for docs, out_path in zip(split_docs, output_paths):
        kept_docs = _filter_evidence_docs(docs)
        with open(out_path, 'w', encoding='utf8') as f:
            f.writelines(json.dumps(doc, ensure_ascii=False) + '\n' for doc in kept_docs)


                 

if __name__ == '__main__':
    # Build the evidence-only splits from the splits in the current working directory.
    source_splits = ["./train.jsonl", "./valid.jsonl", "./test.jsonl"]
    evidence_splits = ["./train_evidence.jsonl", "./valid_evidence.jsonl", "./test_evidence.jsonl"]
    get_evidence_data(source_splits, evidence_splits)