import os
import getopt
import random
import json
import time
from tqdm import tqdm
from pprint import pprint
import datasets
from datasets import list_datasets, load_dataset
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from typing import List
import stanza
from stanza.models.common.doc import Document
from stanza_batch import batch
from utils import *


# Populated by the __main__ block (stanza pipeline and expanded connective
# dictionary); both remain None when this module is only imported.
nlp = conj_exp_dict = None

# Paragraph filters applied in filter_sentence():
noise_start = list('_-*#')  # drop paragraphs whose first character is one of these
noise_token = '"'           # drop paragraphs containing this character anywhere
sequence_num = 2            # minimum number of sentences per paragraph
pre_token_num = 7           # minimum tokens per sentence, on average, per paragraph

# Fixed seed; note that `random` is not used by any code visible in this file.
random.seed(1234)


def parsing_filter(paragraph_list, conj_candidate_list):
    """Parse paragraphs with stanza, keeping per-word rows only for connective sentences.

    Args:
        paragraph_list: paragraph strings that passed the surface filters of
            ``filter_sentence``.
        conj_candidate_list: for each paragraph (same order), the ``conj_exp_dict``
            keys found in it.

    Returns:
        ``(parsing_result, parsing_sent, conj_cand)``, each with one entry per
        parsed paragraph:

        - parsing_result: one list per sentence. A sentence that contains one of
          the paragraph's candidate keys gets ``[text, lemma, pos, head_index,
          misc, deprel]`` per word. ``head_index`` is ``word.head - 1``, i.e. -1
          for the root. ``misc`` is ``'start_char=..|end_char=..'`` relative to
          the sentence's first word. Other sentences get ``[]``, so indices stay
          aligned with ``parsing_sent``.
        - parsing_sent: the sentence texts.
        - conj_cand: ``conj_token[conj_exp_dict[k]]`` for each candidate key ``k``.
    """
    # NOTE(review): assumes stanza_batch.batch yields exactly one Document per
    # input paragraph, in input order; conj_candidate_list is indexed by position.
    stanza_documents: List[Document] = list(batch(paragraph_list, nlp, batch_size=256))  # Default batch size is 32

    parsing_result = []
    parsing_sent = []
    conj_cand = []
    for ids, para in enumerate(stanza_documents):
        para_conj_cand = conj_candidate_list[ids]
        para_results = []
        para_sent = []

        for sentence in para.sentences:
            para_sent.append(sentence.text)

            # Pad with spaces so that keys written with surrounding spaces
            # also match at the sentence edges.
            sent_low = (' ' + sentence.text + ' ').lower()
            if not any(k in sent_low for k in para_conj_cand):
                para_results.append([])
                continue

            # getstartandend presumably returns (start_char, end_char) parsed
            # from word.misc; offsets are re-based to the sentence start.
            start_id = int(getstartandend(sentence.words[0].misc)[0])
            rows = []
            for word in sentence.words:
                start_char, end_char = getstartandend(word.misc)[:2]
                misc = 'start_char=' + str(int(start_char) - start_id) + '|end_char=' + str(int(end_char) - start_id)
                rows.append([word.text, word.lemma, word.pos, word.head - 1, misc, word.deprel])
            para_results.append(rows)

        conj_cand.append([conj_token[conj_exp_dict[k]] for k in para_conj_cand])
        parsing_result.append(para_results)
        parsing_sent.append(para_sent)

    return parsing_result, parsing_sent, conj_cand

def _has_verb_head(sent_result, span):
    """Return True if any word in ``range(span[0], span[1])`` has a VERB as its head.

    ``sent_result`` rows are ``[text, lemma, pos, head_index, misc, deprel]``.
    The root word's head_index is -1 (it has no governor). It is skipped
    explicitly; otherwise ``sent_result[-1]`` would silently wrap around to the
    last word of the sentence.
    """
    for j in range(span[0], span[1]):
        head = sent_result[j][3]
        if head >= 0 and sent_result[head][2] == 'VERB':
            return True
    return False


def parsing_filter_further(para_results_list, para_sent_list, conj_cand):
    """Keep connective occurrences attached to a verb and extract their event spans.

    Args:
        para_results_list: per paragraph, the per-sentence parse rows produced by
            ``parsing_filter`` (``[]`` for sentences that were not parsed).
        para_sent_list: per paragraph, the list of sentence texts.
        conj_cand: per paragraph, the connective sequences matched against the
            sentence lemmas with ``find_token_id``.

    Returns:
        ``(doc_list, doc_conj_list, doc_conj_sent_token_list, doc_event_list)``.
        Each list has one entry per paragraph that kept at least one connective:

        - doc_list: the paragraph's sentence texts.
        - doc_conj_list: ``[sent_id, span, lemma_string]`` per kept connective.
          ``span`` is whatever ``find_token_id`` returns; it is used here as a
          half-open ``[span[0], span[1])`` word range.
        - doc_conj_sent_token_list: ``{str(sent_id): [word_text, ...]}``.
        - doc_event_list: ``{str(sent_id): event_span_extraction(...)}``.
    """
    doc_list = []
    doc_conj_list = []
    doc_conj_sent_token_list = []
    doc_event_list = []

    for para_id, para in enumerate(para_sent_list):
        para_conj_list = []
        para_conj_sent_token_dic = {}
        para_event_dic = {}

        para_conj_cand_list = conj_cand[para_id]
        para_results = para_results_list[para_id]

        for sent_id, sent_result in enumerate(para_results):
            if not sent_result:
                continue  # sentence was not parsed (no candidate connective)

            sent_lemma_list = [row[1] for row in sent_result]

            spans = []
            for cand in para_conj_cand_list:
                span = find_token_id(sent_lemma_list, cand)
                if len(span) > 0:
                    spans.append(span)
            spans = del_overlap(spans)
            if len(spans) == 0:
                continue

            # A match counts as a real connective only if it attaches to a verb.
            real_conj_list = []
            for span in spans:
                if not _has_verb_head(sent_result, span):
                    continue
                real_conj_list.append(span)
                tmp_lemma = ' '.join(sent_result[j][1] for j in range(span[0], span[1]))
                para_conj_list.append([sent_id, span, tmp_lemma])

            if real_conj_list:
                para_conj_sent_token_dic[str(sent_id)] = [row[0] for row in sent_result]
                para_event_dic[str(sent_id)] = event_span_extraction(sent_result, real_conj_list)

        if para_conj_list:
            doc_list.append(para)
            doc_conj_list.append(para_conj_list)
            doc_conj_sent_token_list.append(para_conj_sent_token_dic)
            doc_event_list.append(para_event_dic)

    return doc_list, doc_conj_list, doc_conj_sent_token_list, doc_event_list

def filter_sentence(book):
    """Select clean multi-sentence paragraphs that mention a connective, then parse them.

    Args:
        book: the book's lines (one paragraph per element).

    Returns:
        ``(doc_list, doc_conj_list, doc_conj_sent_token_list, doc_event_list)``
        as produced by ``parsing_filter_further``. Returns four empty lists when
        no paragraph passes the surface filters; stanza is not invoked then.
    """
    paragraph_list = []
    conj_candidate_list = []

    for para in book:
        # filter empty lines and lines starting with a noise character
        if not para or para[0] in noise_start:
            continue
        if noise_token in para:
            continue

        # filter by sentence number
        para_sent_list = sent_tokenize(para)
        if len(para_sent_list) < sequence_num:
            continue

        # filter by token number: at least pre_token_num tokens per sentence on average
        tmp = list2string(para_sent_list)
        if len(word_tokenize(tmp)) < len(para_sent_list) * pre_token_num:
            continue

        # remove _ (via utils.remove_)
        tmp = remove_(tmp)

        # conj_exp_dict filter: lower-case the space-padded paragraph once,
        # rather than once per dictionary key.
        padded_low = (' ' + tmp + ' ').lower()
        temp = [k for k in conj_exp_dict if k in padded_low]
        if not temp:
            continue

        paragraph_list.append(tmp)
        conj_candidate_list.append(temp)

    # Nothing to parse: skip the stanza batching and parsing stages entirely.
    if not paragraph_list:
        return [], [], [], []

    # conj filter
    parsing_result, parsing_sent, conj_cand = parsing_filter(paragraph_list, conj_candidate_list)
    doc_list, doc_conj_list, doc_conj_sent_token_list, doc_event_list = parsing_filter_further(parsing_result, parsing_sent, conj_cand)

    return doc_list, doc_conj_list, doc_conj_sent_token_list, doc_event_list



if __name__ == "__main__":
    # Must be set before CUDA is first initialised (torch initialises it lazily).
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    stanza.download('en')
    # Module-level assignments: these rebind the globals read by the filters above.
    nlp = stanza.Pipeline(lang="en", processors='tokenize,pos,lemma,depparse')
    conj_exp_dict = expand_conj_dict(conj_dict)

    # Without `split`, load_dataset returns a DatasetDict keyed by split name.
    # Iterating it would yield the string 'train', not book records.
    dataset = load_dataset('bookcorpusopen', split='train')

    # write_json (from utils) is not known to create missing directories.
    os.makedirs('./samples', exist_ok=True)

    # Wrapping the iterable keeps the progress bar accurate, including books
    # that are skipped below.
    for ids, corpus in enumerate(tqdm(dataset, total=len(dataset))):
        book = corpus["text"].split('\n')
        doc_list, doc_conj_list, doc_conj_sent_token_list, doc_event_list = filter_sentence(book)

        if not doc_list:
            continue

        data_list = [
            {"para": para, "conj": conj, "token": token, "event": event}
            for para, conj, token, event in zip(doc_list, doc_conj_list, doc_conj_sent_token_list, doc_event_list)
        ]

        output_file_name = './samples/%s.json' % str(ids)
        write_json(output_file_name, data_list)












