import json
import os
import sys, getopt
from tqdm import tqdm
from time import time
import random
import torch
from torch.utils.data import TensorDataset, SequentialSampler, DataLoader, RandomSampler
from pprint import pprint
from time import time
from utils import *
import sys
sys.path.append("..")
from transformers import BartTokenizerFast


# ---- Dataset-generation settings and shared module-level state ----

# Inclusive bounds on a paragraph's tokenized length (number of `input_ids`
# produced by `tokenizer`); enforced in control_paragraph_length().
# NOTE(review): "legnth" is a misspelling of "length", but the name is
# referenced by other functions in this file, so it is left as-is here.
min_para_length = 21
max_para_legnth = 128
# Inclusive bounds on an event span's length, measured as end index minus
# start index; enforced in control_event_length().
min_event_length = 3
max_event_legnth = 15
# Tokenizer used in this file only to measure paragraph length. It is loaded
# at import time, not lazily.
tokenizer = BartTokenizerFast.from_pretrained('facebook/bart-large')
# Per-conjunction counters for the train and eval splits. priority_conj()
# treats each value as a remaining quota and decrements it on every pick.
# NOTE(review): sample_distribution() / sample_number() are not defined in
# this file (presumably they come from `utils`) — confirm what they return.
train_dic, eval_dic = sample_distribution()
train_dic, _ = sample_number(train_dic)
# Totals computed by get_conj_priority(): the sum of the per-conjunction
# counts of each split, halved.
train_end_sample_number = 0
eval_end_sample_number = 0
# conjunction -> priority value (the conjunction's count as an int), filled
# in by get_conj_priority() and read by priority_conj().
train_conj_priority = {}
eval_conj_priority = {}

def get_conj_priority():
    """Fill the conjunction priority tables and the halved sample totals.

    For the train split and then the eval split, every ``(conjunction,
    count)`` pair returned by ``sort_dict(..., reverse=False)`` is stored in
    the matching ``*_conj_priority`` dict as an int, and the count is added
    to the matching ``*_end_sample_number`` global. After both splits have
    been processed, each total is replaced by ``int(total / 2)``.

    The totals are accumulated on top of whatever the globals already hold,
    so the results are only meaningful when the globals start at 0.
    """
    global train_end_sample_number, eval_end_sample_number

    # Train split: record each conjunction's count and add it to the total.
    for conj, count in sort_dict(train_dic, reverse=False):
        train_conj_priority[conj] = int(count)
        train_end_sample_number += int(count)

    # Eval split: same bookkeeping against the eval tables.
    for conj, count in sort_dict(eval_dic, reverse=False):
        eval_conj_priority[conj] = int(count)
        eval_end_sample_number += int(count)

    # Halve both totals only after both splits were processed.
    train_end_sample_number = int(train_end_sample_number / 2)
    eval_end_sample_number = int(eval_end_sample_number / 2)

def control_paragraph_length(para_list):
    """Return True when the paragraph's tokenized length is within bounds.

    ``para_list`` is joined into one string with ``list2string`` and run
    through the module-level ``tokenizer``; the number of resulting
    ``input_ids`` must lie in the inclusive range
    ``[min_para_length, max_para_legnth]``.
    """
    text = list2string(para_list)
    token_count = len(tokenizer(text)['input_ids'])
    return min_para_length <= token_count <= max_para_legnth

def control_event_length(event_dict, token_dict, ng1_list, ng2_list, ng3_list):
    """Keep only the events that have negatives and an in-range span length.

    Args:
        event_dict: maps a key to a list of event spans; ``span[0]`` and
            ``span[1]`` are the start and end indices (converted with
            ``int``).
        token_dict: maps the same keys to sized token collections; only
            their lengths are used.
        ng1_list, ng2_list, ng3_list: three lookup structures indexed as
            ``ngN[key][1][event_index][1]``; the caller in this file passes
            its three negative-sample files here. An event is dropped as
            soon as one of them is empty for that event.

    Returns:
        A 4-tuple ``(events, event_number, end_number, end_list)``:

        * ``events`` – dict mapping each key to its kept entries
          ``[span, [ng1, ng2, ng3]]``, or ``False`` when nothing was kept;
        * ``event_number`` – total number of kept events;
        * ``end_number`` – how many kept events have an end index equal to
          the last index of ``token_dict[key]``;
        * ``end_list`` – ``[key, position]`` for each such event, where
          ``position`` is the event's index within that key's kept list.
    """
    negative_sources = (ng1_list, ng2_list, ng3_list)
    kept_by_key = {}
    kept_total = 0
    end_hits = []

    for key, spans in event_dict.items():
        kept = []
        for idx, span in enumerate(spans):
            # Sources are probed lazily and in order, stopping at the first
            # one that has no negatives for this event.
            if any(len(src[key][1][idx][1]) < 1 for src in negative_sources):
                continue

            end = int(span[1])
            start = int(span[0])
            if not (min_event_length <= end - start <= max_event_legnth):
                continue

            # The span's end index is the last index of the token list.
            if len(token_dict[key]) == end + 1:
                end_hits.append([key, len(kept)])

            kept.append([span, [src[key][1][idx][1] for src in negative_sources]])
            kept_total += 1

        if kept:
            kept_by_key[key] = kept

    # One entry is appended to end_hits per counted end event, so its length
    # is the end-event count.
    return (kept_by_key or False), kept_total, len(end_hits), end_hits

def _select_conj(conj_list, priority, quota, limit=100001):
    """Return the best conjunction in ``conj_list``, or None if none qualifies.

    A conjunction (``item[2]`` of each entry) qualifies when its value in
    ``priority`` is strictly below the best value seen so far (starting at
    ``limit``) and its value in ``quota`` is still positive. Ties keep the
    earliest entry because the comparison is strict.
    """
    chosen = None
    best = limit
    for item in conj_list:
        conj = item[2]
        if priority[conj] < best and quota[conj] > 0:
            best = priority[conj]
            chosen = conj
    return chosen


def priority_conj(conj_list):
    """Pick the conjunction that labels a sample and the split it goes to.

    The train tables are tried first: among the conjunctions in
    ``conj_list`` (``item[2]`` of each entry) the one with the smallest
    value in ``train_conj_priority`` that still has quota left in
    ``train_dic`` is chosen, and its quota is decremented. If none
    qualifies, the same selection is made against ``eval_conj_priority`` /
    ``eval_dic``.

    Args:
        conj_list: sequence of entries whose third element is a conjunction
            key present in the priority and quota tables.

    Returns:
        ``(conjunction, 1)`` for a train pick, ``(conjunction, 0)`` for an
        eval pick, or ``(False, 1)`` when ``conj_list`` is empty or no
        conjunction has quota left.

    Raises:
        KeyError: if a conjunction is missing from the tables being probed.
    """
    # The original guard was `len(conj_list) < 0`, which can never be true.
    if not conj_list:
        return False, 1

    # train
    chosen = _select_conj(conj_list, train_conj_priority, train_dic)
    if chosen is not None:
        train_dic[chosen] -= 1
        return chosen, 1

    # eval
    chosen = _select_conj(conj_list, eval_conj_priority, eval_dic)
    if chosen is not None:
        eval_dic[chosen] -= 1
        return chosen, 0

    return False, 1

def generate_data():
    """Build the train and eval datasets and write them to JSON files.

    Steps:
      1. Seed ``random`` with 1234 and fill the priority tables via
         ``get_conj_priority()``.
      2. Visit the sample ids 0..17999 in shuffled order. For each id, load
         ``./samples/<id>.json`` and the three matching negative-sample
         files; an id whose files cannot be loaded is skipped.
      3. For every sample in a file, drop it unless the paragraph passes
         ``control_paragraph_length`` and at least one event survives
         ``control_event_length``. Then ask ``priority_conj`` for the
         conjunction label and the target split; samples without a label
         are dropped.
      4. Print the size of both datasets and write them to
         ``./train_data.json`` and ``./eval_data.json``.

    Each kept sample dict is updated in place with the filtered ``event``
    dict plus ``event_num``, ``end_number``, ``end_list`` and ``sign``.
    """
    random.seed(1234)
    get_conj_priority()

    train_dataset = []
    eval_dataset = []

    sample_id_list = list(range(18000))
    random.shuffle(sample_id_list)

    for i in tqdm(sample_id_list):
        try:
            samples = read_json('./samples/%s.json' % str(i))
            event_word_ng = read_json('./negative_samples_word/%s.json' % str(i))
            event_pos_ng = read_json('./negative_samples_pos/%s.json' % str(i))
            event_context_ng = read_json('./negative_samples_context/%s.json' % str(i))
        except Exception:
            # Best-effort: skip ids whose files are missing or unreadable.
            # `Exception` (not a bare `except`) lets KeyboardInterrupt and
            # SystemExit propagate, so the run can still be aborted.
            continue

        for ids, data_dict in enumerate(samples):
            if not control_paragraph_length(data_dict['para']):
                continue

            # Negatives are passed in the order pos, word, context; that is
            # the order they are stored in for each kept event.
            event_dic, event_number, end_number, end_list = control_event_length(
                data_dict['event'],
                data_dict['token'],
                event_pos_ng[ids],
                event_word_ng[ids],
                event_context_ng[ids],
            )
            if not event_dic:
                continue

            data_dict['event'] = event_dic
            data_dict['event_num'] = event_number
            data_dict['end_number'] = end_number
            data_dict['end_list'] = end_list

            prio_conj, is_train = priority_conj(data_dict['conj'])
            if not prio_conj:
                continue

            data_dict['sign'] = prio_conj

            if is_train:
                train_dataset.append(data_dict)
            else:
                eval_dataset.append(data_dict)

    print(len(train_dataset))
    print(len(eval_dataset))
    write_json('./train_data.json', train_dataset)
    write_json('./eval_data.json', eval_dataset)


# Script entry point: generate_data() runs only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    generate_data()
