import os
import sys, getopt
import random
import json
from time import time
from tqdm import tqdm
from pprint import pprint
import datasets
from datasets import list_datasets, load_dataset
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from typing import List
import stanza
from stanza.models.common.doc import Document
from stanza_batch import batch
from utils import *

# Seed Python's global `random` generator with a fixed value.
# NOTE(review): no other use of `random` is visible in this file; the seed
# only matters if code imported elsewhere draws from the global generator.
random.seed(1234)

# Upper bound on the number of negative candidates kept for a single event
# span; negative_sampling stops collecting once this many have been gathered.
max_neg_sam_num = 10


# Build the lookup dictionaries per book
def bulid_event_dic(input_data):
    """Index every event span found in ``input_data``.

    Each element of ``input_data`` is read through two entries:
    ``'token'`` (sentence key -> token list) and ``'event'`` (same sentence
    key -> list of spans). The first two items of a span are used as slice
    bounds into the sentence's token list.

    Spans are numbered in encounter order; the id is the running counter
    rendered as a string ("0", "1", ...).

    Returns:
        A 3-tuple of dicts:
        * id -> event text, as produced by ``list2string`` on the sliced tokens
        * id -> number of tokens in the event
        * token -> list of ids of the events containing it. One entry is
          added per token occurrence, so a token repeated inside one event
          lists that id more than once.
    """
    event_text_by_id = {}
    event_len_by_id = {}
    ids_by_token = {}
    next_id = 0

    for doc in input_data:
        tokens_by_sent = doc['token']
        spans_by_sent = doc['event']

        for sent_key, sent_tokens in tokens_by_sent.items():
            for span in spans_by_sent[sent_key]:
                event_tokens = sent_tokens[span[0]:span[1]]
                event_id = str(next_id)
                next_id += 1

                event_text_by_id[event_id] = list2string(event_tokens)
                event_len_by_id[event_id] = len(event_tokens)

                # Inverted index: remember which events each token occurs in.
                for token in event_tokens:
                    ids_by_token.setdefault(token, []).append(event_id)

    return event_text_by_id, event_len_by_id, ids_by_token


def negative_sampling(input_data, all_sent_dic, all_sent_len, token2sent_dic):
    """Pick token-overlap negatives for every event span in ``input_data``.

    For each event, every indexed event sharing at least one token with it
    is scored. Each shared token occurrence contributes
    ``1 / (L + abs(len(event) - L))``, where ``L`` is the candidate's token
    count. Candidates are then visited from highest to lowest score (ties
    keep the order in which they were first scored), and their text is kept
    unless it equals the event's own text or was already kept. Collection
    stops at ``max_neg_sam_num`` negatives.

    Args:
        input_data: iterable of documents, each read through its ``'para'``
            (indexable by ``int(sentence key)``), ``'token'`` and ``'event'``
            entries.
        all_sent_dic: event id -> event text.
        all_sent_len: event id -> event token count.
        token2sent_dic: token -> list of event ids containing that token.

    Returns:
        One dict per document, mapping each sentence key to
        ``[sentence, [[event_text, [negative_text, ...]], ...]]``.
    """
    sampled_docs = []

    for doc in input_data:
        paragraph = doc['para']
        tokens_by_sent = doc['token']
        spans_by_sent = doc['event']

        per_sentence = {}
        for sent_key, sent_tokens in tokens_by_sent.items():
            sentence_text = paragraph[int(sent_key)]
            events_with_negatives = []

            for span in spans_by_sent[sent_key]:
                event_tokens = sent_tokens[span[0]:span[1]]
                event_text = list2string(event_tokens)
                event_len = len(event_tokens)

                # Accumulate an overlap score for every event that shares a
                # token with the current one.
                scores = {}
                for token in event_tokens:
                    for cand_id in token2sent_dic[token]:
                        cand_len = all_sent_len[cand_id]
                        weight = 1 / (cand_len + abs(event_len - cand_len))
                        if cand_id in scores:
                            scores[cand_id] += weight
                        else:
                            scores[cand_id] = weight

                # Candidate ids, best score first (stable for equal scores).
                ranked_ids = sorted(scores, key=scores.get, reverse=True)

                negatives = []
                for cand_id in ranked_ids:
                    if len(negatives) >= max_neg_sam_num:
                        break
                    cand_text = all_sent_dic[cand_id]
                    # Skip the event's own text and texts already collected.
                    if cand_text == event_text or cand_text in negatives:
                        continue
                    negatives.append(cand_text)

                events_with_negatives.append([event_text, negatives])

            per_sentence[sent_key] = [sentence_text, events_with_negatives]

        sampled_docs.append(per_sentence)

    return sampled_docs



if __name__ == "__main__":
    # Script entry point: for every sample file that can be read, build the
    # event index and write the negative samples for that file.

    # Expose only CUDA device 0 to libraries that honour this variable.
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    for i in tqdm(range(100000)):
        try:
            input_data = read_json('./samples/%s.json' % str(i))
        except Exception:
            # Best effort: sample ids are sparse, so a file that is missing
            # or cannot be parsed is skipped. Catching Exception rather than
            # using a bare except lets KeyboardInterrupt and SystemExit
            # propagate, so the run can still be stopped.
            continue

        # The index is rebuilt per file, so negatives are drawn only from
        # events inside the same sample file.
        all_sent_dic, all_sent_len, token2sent_dic = bulid_event_dic(input_data)
        neg_sam_doc = negative_sampling(input_data, all_sent_dic, all_sent_len, token2sent_dic)
        output_file_name = './negative_samples_word/%s.json' % str(i)
        write_json(output_file_name, neg_sam_doc)




        