import os
import sys, getopt
import random
import json
from time import time
from tqdm import tqdm
from pprint import pprint
import datasets
from datasets import list_datasets, load_dataset
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from typing import List
import stanza
from stanza.models.common.doc import Document
from stanza_batch import batch
from utils import *

# Seed the module-level RNG so the random.shuffle in negative_sampling
# produces the same candidate-document order on every run.
random.seed(1234)

# Upper bound on negative candidates collected per event span.
max_neg_sam_num = 10

# Registries filled by bulid_event_dic, keyed by a string span id (str(count)):
#   all_sent_dic[span_id] -> event text (list2string of the span's tokens)
#   all_sent_len[span_id] -> number of tokens in the span
# Per-document inverted index (filled in the __main__ block):
#   data_dict[doc_id][pos_str][token] -> [span_id, ...]
all_sent_dic, all_sent_len, data_dict = {}, {}, {}

# Running counter used to mint span ids; advanced by bulid_event_dic.
count = 0


# Build one book's POS dictionary; span ids, texts and lengths accumulate globally across all books.
def bulid_event_dic(input_data):
    """Index every event span of one document by POS pattern and token.

    For each record in ``input_data``, ``record['token'][k]`` is the token
    list of sentence ``k`` and ``record['event'][k]`` its list of spans; the
    slice ``tokens[span[0]:span[1]]`` is one event. Each event is assigned a
    fresh id ``str(count)`` (advancing the global ``count``), and its text and
    token length are stored in the global ``all_sent_dic`` / ``all_sent_len``.

    Returns:
        dict mapping ``pos_str`` (from ``tokens2pos_seq``) to a dict mapping
        each token to the list of span ids containing it. An id is appended
        once per occurrence, so a token repeated inside an event lists the
        same id repeatedly.
    """
    global count
    book_dic = {}

    for record in input_data:
        tokens_by_sent = record['token']
        spans_by_sent = record['event']

        for sent_key in tokens_by_sent:
            sent_tokens = tokens_by_sent[sent_key]
            for span in spans_by_sent[sent_key]:
                event = sent_tokens[span[0]:span[1]]
                event_str = list2string(event)
                pos_str = tokens2pos_seq(event)

                span_id = str(count)
                all_sent_dic[span_id] = event_str
                all_sent_len[span_id] = len(event)

                token_index = book_dic.setdefault(pos_str, {})
                for token in event:
                    token_index.setdefault(token, []).append(span_id)

                count += 1

    return book_dic


def _candidate_scores(event, pos_str, doc_order):
    """Score indexed spans that share ``pos_str`` and a token with ``event``.

    Documents are visited in ``doc_order``; ids absent from ``data_dict`` are
    skipped. Before each indexed document, collection stops once
    ``max_neg_sam_num`` distinct span ids have been scored, so the last
    visited document may push the count past that limit. Every
    (query token, indexed occurrence) pair adds
    ``1 / (cand_len + |len(event) - cand_len|)``: a constant ``1/len(event)``
    for candidates no longer than the event, smaller for longer ones.

    Returns:
        dict span_id -> accumulated score, in first-seen order.
    """
    scores = {}
    event_len = len(event)
    for cand_doc_id in doc_order:
        if cand_doc_id not in data_dict:
            continue
        if len(scores) >= max_neg_sam_num:
            break
        doc_index = data_dict[cand_doc_id]
        if pos_str not in doc_index:
            continue
        token_index = doc_index[pos_str]
        for token in event:
            for span_id in token_index.get(token, ()):
                cand_len = all_sent_len[span_id]
                weight = 1 / (cand_len + abs(event_len - cand_len))
                scores[span_id] = scores.get(span_id, 0) + weight
    return scores


def _select_negatives(scores, event_str):
    """Return up to ``max_neg_sam_num`` distinct candidate texts, best first.

    Candidates whose text equals ``event_str`` (e.g. the event itself) are
    dropped; ties keep first-seen order because ``sorted`` is stable.
    """
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    negatives = []
    for span_id, _score in ranked:
        if len(negatives) >= max_neg_sam_num:
            break
        cand_sent = all_sent_dic[span_id]
        if cand_sent != event_str and cand_sent not in negatives:
            negatives.append(cand_sent)
    return negatives


def negative_sampling(input_data, doc_id, all_doc_num):
    """Collect negative-sample candidates for every event span of a document.

    Candidate documents are searched with the document itself first, then
    every other id in ``range(all_doc_num)`` in shuffled order. A ``doc_id``
    outside that range is still searched first instead of raising
    ``ValueError``.

    Args:
        input_data: records with ``'para'`` (indexed by ``int(k)``),
            ``'token'`` (sentence key -> token list) and ``'event'``
            (sentence key -> list of ``[start, end, ...]`` spans).
        doc_id: id of this document (int or str); compared as ``str``.
        all_doc_num: size of the candidate document-id pool.

    Returns:
        One dict per record, mapping sentence key ``k`` to
        ``[para[int(k)], [[event_str, [negative_text, ...]], ...]]``.
    """
    self_id = str(doc_id)
    # Equivalent to list.remove(self_id) when present, but never raises when
    # self_id lies outside range(all_doc_num).
    other_ids = [str(i) for i in range(all_doc_num) if str(i) != self_id]
    random.shuffle(other_ids)
    doc_order = [self_id] + other_ids

    doc_list = []
    for record in input_data:
        para = record['para']
        sent_dic = record['token']
        span_list_dic = record['event']

        sent_neg_dic = {}
        for k in sent_dic:
            sentence = para[int(k)]
            event_neg_list = []
            for span in span_list_dic[k]:
                event = sent_dic[k][span[0]:span[1]]
                event_str = list2string(event)
                pos_str = tokens2pos_seq(event)

                scores = _candidate_scores(event, pos_str, doc_order)
                event_neg_list.append([event_str, _select_negatives(scores, event_str)])

            sent_neg_dic[k] = [sentence, event_neg_list]

        doc_list.append(sent_neg_dic)

    return doc_list



if __name__ == "__main__":
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    # Sample ids probed as ./samples/<id>.json; absent ids are skipped.
    MAX_SAMPLE_ID = 100000
    # Size of the doc-id pool negative_sampling shuffles for candidates.
    # NOTE(review): sample ids >= ALL_DOC_NUM are never used as candidate
    # sources — confirm this matches the corpus size.
    ALL_DOC_NUM = 18000
    POS_DIC_PATH = 'pos_dic.json'
    NEG_SAMPLE_DIR = './negative_samples_pos'

    def _iter_samples():
        """Yield (sample_id, data) for every readable ./samples/<id>.json."""
        for sample_id in tqdm(range(MAX_SAMPLE_ID)):
            try:
                data = read_json('./samples/%s.json' % str(sample_id))
            except (OSError, ValueError):
                # Missing (OSError) or unparsable (ValueError) sample: skip.
                # Unlike a bare except, KeyboardInterrupt and real bugs
                # still propagate.
                # NOTE(review): assumes read_json surfaces open()/json errors
                # unchanged — confirm in utils.
                continue
            yield sample_id, data

    # Pass 1: build the global span registries and the per-document index.
    for i, input_data in _iter_samples():
        data_dict[str(i)] = bulid_event_dic(input_data)

    write_json(POS_DIC_PATH, [all_sent_dic, all_sent_len, data_dict])

    # Reload the index just written, timing the load.
    st = time()
    all_sent_dic, all_sent_len, data_dict = read_json(POS_DIC_PATH)
    et = time()
    print('Load dict:', et-st)

    # Pass 2: write one negative-sample file per readable sample.
    os.makedirs(NEG_SAMPLE_DIR, exist_ok=True)
    for i, input_data in _iter_samples():
        neg_sam_doc = negative_sampling(input_data, i, ALL_DOC_NUM)
        write_json('%s/%s.json' % (NEG_SAMPLE_DIR, i), neg_sam_doc)
        