import os
import sys, getopt
import random
import json
from time import time
from tqdm import tqdm
from pprint import pprint
import datasets
from datasets import list_datasets, load_dataset
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from typing import List
import stanza
from stanza.models.common.doc import Document
from stanza_batch import batch
from rank_bm25 import BM25Okapi
from utils import *

# Upper bound on the number of negative candidates kept for each event.
max_neg_sam_num = 10

# Shared context pool: init_sample appends one list of event token spans per
# document, and negative_sampling reads neighbouring entries from it.
sent_pool: List[list] = []

# Fixed seed for reproducible runs.
# NOTE(review): no call into ``random`` is visible in this file -- confirm the
# seed is still needed (e.g. by helpers pulled in from ``utils``).
random.seed(1234)


def init_sample(input_data, pool=None):
    """Append each document's event token spans to the context pool.

    For every document in ``input_data`` exactly one entry is appended: a
    flat list holding the token slice ``token[k][start:end]`` for every event
    span of every sentence key ``k`` in that document. Documents without any
    events still get an (empty) entry, so pool positions stay aligned with
    document positions; ``negative_sampling`` relies on that alignment.

    Args:
        input_data: list of document dicts. Keys read here are ``'token'``
            (sentence key -> token list) and ``'event'`` (sentence key -> list
            of ``[start, end]`` spans); schema inferred from the indexing below.
        pool: list to append to. Defaults to the module-level ``sent_pool``.

    Raises:
        KeyError: if a document lacks ``'token'``/``'event'`` or a sentence
            key in ``'token'`` has no entry in ``'event'``.
    """
    if pool is None:
        pool = sent_pool
    for doc in input_data:
        sent_dic = doc['token']
        span_list_dic = doc['event']
        pool.append([
            sent_dic[k][span[0]:span[1]]
            for k in sent_dic
            for span in span_list_dic[k]
        ])

def negative_sampling(input_data, pool=None):
    """Retrieve BM25 negative candidates for every event span in ``input_data``.

    For document ``ids`` a corpus is built from the pooled events of a window
    of roughly 10 neighbouring documents (``ids-4 .. ids+5``, clamped to the
    first/last 10 at the edges). Each event's tokens are scored against that
    corpus with BM25Okapi. The top 12 hits are converted with ``list2string``.
    Those equal to the event's own string are dropped, and at most
    ``max_neg_sam_num`` are kept.

    ``init_sample`` must already have been called with the same
    ``input_data``, so that the last ``len(input_data)`` pool entries belong
    to these documents in order.

    Args:
        input_data: list of document dicts with keys ``'para'`` (indexed by
            ``int(k)``), ``'token'`` and ``'event'`` (as in ``init_sample``).
        pool: context pool to read from. Defaults to the module-level
            ``sent_pool``.

    Returns:
        One dict per document mapping each sentence key ``k`` to
        ``[para[int(k)], [[event_str, [candidate_str, ...]], ...]]``.
    """
    if pool is None:
        pool = sent_pool
    # The pool may still hold entries from previously processed files. This
    # file's entries are the trailing len(input_data) items, so slice them out
    # to keep window positions aligned with ``ids``.
    pool = pool[-len(input_data):] if input_data else []
    n_docs = len(input_data)

    doc_list = []
    for ids, doc in enumerate(input_data):
        para = doc['para']
        sent_dic = doc['token']
        span_list_dic = doc['event']

        if ids < 5:
            window = pool[:10]
        elif ids >= n_docs - 5:
            window = pool[-10:]
        else:
            window = pool[ids - 4:ids + 6]
        corpus = [event for doc_events in window for event in doc_events]

        # The corpus is fixed for the whole document, so index it once instead
        # of once per span. BM25Okapi divides by the corpus size while
        # indexing, so skip it when the window holds no events at all.
        bm25 = BM25Okapi(corpus) if corpus else None

        sent_neg_dic = {}
        for k in sent_dic:
            sentence = para[int(k)]
            event_neg_list = []
            for span in span_list_dic[k]:
                event = sent_dic[k][span[0]:span[1]]
                event_str = list2string(event)
                candidates = []
                if bm25 is not None:
                    # n=12 presumably leaves slack for hits that are the event
                    # itself and get filtered out below.
                    for hit in bm25.get_top_n(event, corpus, n=12):
                        if len(candidates) >= max_neg_sam_num:
                            break
                        hit_str = list2string(hit)
                        if hit_str != event_str:
                            candidates.append(hit_str)
                event_neg_list.append([event_str, candidates])
            sent_neg_dic[k] = [sentence, event_neg_list]

        doc_list.append(sent_neg_dic)

    return doc_list



if __name__ == "__main__":
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    # Make sure the output directory exists before the first write.
    os.makedirs('./negative_samples_context', exist_ok=True)

    # Sample files are numbered 0..99999; indices without a file are skipped.
    for i in tqdm(range(100000)):
        sample_path = './samples/%s.json' % str(i)
        try:
            input_data = read_json(sample_path)
        except FileNotFoundError:
            continue
        except Exception as err:
            # Still best-effort, but report unreadable files instead of
            # silently dropping them (and let Ctrl-C through).
            tqdm.write('skipping %s: %r' % (sample_path, err), file=sys.stderr)
            continue

        # negative_sampling indexes the shared pool by document position in
        # the current file, so drop entries left over from the previous file.
        sent_pool.clear()
        init_sample(input_data)
        neg_sam_doc = negative_sampling(input_data)
        output_file_name = './negative_samples_context/%s.json' % str(i)
        write_json(output_file_name, neg_sam_doc)
        