import os
import json
from nltk import sent_tokenize, word_tokenize, pos_tag
from nltk.stem import WordNetLemmatizer

from nltk.text import TextCollection
from nltk.corpus import wordnet
from multiprocessing import Pool
from itertools import chain


def get_idf_sents():
    """Return the list of text documents used as the IDF corpus.

    The documents are description strings extracted from a JSON dataset whose
    top level is a list of scenes; each scene has an ``'entries'`` list of
    dicts carrying a ``'description'`` string.  The extracted list is cached
    as JSON next to the input so later calls only read the cache.

    Returns:
        list[str]: the documents produced by the currently selected helper
        (``get_idf_docs`` with its default arguments).
    """
    def get_idf_docs(mode='../data/train_peak_onecard', last=True):
        """Collect entry descriptions from ``<mode>.json`` (cached on disk).

        Args:
            mode: path prefix of the dataset, without the ``.json`` extension.
            last: when true, keep only the final entry's description of each
                scene; otherwise keep the description of every entry.

        Returns:
            list[str]: the collected descriptions, in dataset order.
        """
        suffix = '_doc_idf_only_last' if last else '_doc_idf'
        out_file_name = '{}{}.json'.format(mode, suffix)
        if os.path.exists(out_file_name):
            with open(out_file_name, encoding='utf-8') as f:
                return json.load(f)

        # Read the dataset and release the handle before writing the cache.
        with open('{}.json'.format(mode), encoding='utf-8') as f:
            scenes = json.load(f)

        docs = []
        for scene in scenes:
            entries = scene['entries']
            if not entries:
                # A scene without entries has no description to contribute;
                # indexing [-1] on it would raise IndexError.
                continue
            if last:
                docs.append(entries[-1]['description'])
            else:
                docs.extend(entry['description'] for entry in entries)

        with open(out_file_name, 'w', encoding='utf-8') as fi:
            json.dump(docs, fi, ensure_ascii=False)
        print('finish get_idf_sent')
        return docs

    def get_idf_split_sent():
        """Collect sentence-level documents from ``train.json`` (cached on disk).

        Every entry description and every card description of a scene's last
        entry is split into sentences with ``sent_tokenize``.  Entry sentences
        come first in the result, followed by card sentences.

        Returns:
            list[str]: the sentences, also cached in ``train_sent_idf.json``.
        """
        cache_name = 'train_sent_idf.json'
        if os.path.exists(cache_name):
            with open(cache_name, encoding='utf-8') as f:
                return json.load(f)

        with open('train.json', encoding='utf-8') as f:
            scenes = json.load(f)

        hyps = []
        refs = []
        for scene in scenes:
            entries = scene['entries']
            if not entries:
                # No entries means no descriptions and no last-entry cards.
                continue
            for entry in entries:
                hyps.extend(sent_tokenize(entry['description']))
            for card in entries[-1]['cards']:
                refs.extend(sent_tokenize(card['description']))

        sents = hyps + refs
        with open(cache_name, 'w', encoding='utf-8') as fi:
            json.dump(sents, fi, ensure_ascii=False)
        print('finish get_idf_sent')
        return sents

    # Document-level corpus is the active choice; switch to
    # get_idf_split_sent() for a sentence-level corpus.
    return get_idf_docs()

# Shared lemmatizer instance used by lemmatize().
wnl = WordNetLemmatizer()


def lemmatize(word, tag):
    """Return the lower-cased WordNet lemma of *word*.

    Args:
        word: the token to lemmatize.
        tag: its part-of-speech tag string (as supplied by ``pos_tag`` in
            ``tokenize``); only its leading letter is inspected.

    Returns:
        The lemma produced by ``wnl.lemmatize`` for the lower-cased word and
        the WordNet category selected from *tag*.
    """
    lowered = word.lower()

    # Tag prefixes checked in order; anything unmatched is treated as a noun.
    prefix_to_pos = (
        ("J", wordnet.ADJ),
        ("V", wordnet.VERB),
        ("N", wordnet.NOUN),
        ("R", wordnet.ADV),
    )
    wordnet_pos = wordnet.NOUN
    for prefix, pos in prefix_to_pos:
        if tag.startswith(prefix):
            wordnet_pos = pos
            break

    return wnl.lemmatize(lowered, wordnet_pos)

def tokenize(sent):
    """Turn the string *sent* into a list of lemmas.

    The text is split with ``word_tokenize``, tagged with ``pos_tag``, and
    each (token, tag) pair is passed through ``lemmatize``.
    """
    return [lemmatize(token, tag) for token, tag in pos_tag(word_tokenize(sent))]


# NOTE(review): everything below runs at import time and starts worker
# processes without an `if __name__ == '__main__':` guard; under the 'spawn'
# start method the workers re-import this module -- confirm the script is only
# run where 'fork' is the default.

# Raw documents of the IDF corpus.
idf_sents = get_idf_sents()
print('finish load idf sents')

# Tokenise and lemmatise every document in parallel, one worker per CPU.
with Pool(processes=os.cpu_count()) as pool:
    rel = pool.map(tokenize, idf_sents)
print('finish word tokenize')
print('doc num:', len(rel))

# Vocabulary: every distinct lemma seen in any document.
word_set = list(set(chain.from_iterable(rel)))
print('word num:', len(word_set))

# Corpus object that supplies the idf() computation, and the result mapping
# filled in once the IDF values are available.
text_collection = TextCollection(rel)
word_idfs = dict()

# Number of words handled so far by work() in the current process.
cnt = 0


def work(word):
    """Return ``text_collection.idf(word)``, printing progress every 100 calls.

    The progress counter is the module-level ``cnt``; it is incremented on
    every call and printed whenever it reaches a multiple of 100.
    """
    global cnt
    cnt = cnt + 1
    if not cnt % 100:
        print(cnt)
    idf_value = text_collection.idf(word)
    return idf_value



print('start pool compute idf')

# Compute the IDF of every vocabulary word in parallel; `counts` holds the
# values in the same order as `word_set`.
with Pool(processes=os.cpu_count()) as pool:
    counts = pool.map(work, word_set)

# Pair each word with its value, then order the pairs by ascending value.
word_idfs.update(zip(word_set, counts))
word_idfs = sorted(word_idfs.items(), key=lambda pair: pair[1])

# Persist the sorted (word, idf) pairs as a JSON array.
with open('../data/word_idf.json', 'w', encoding='utf-8') as f:
    json.dump(word_idfs, f, ensure_ascii=False)
