from typing import List, Dict, Any, Optional
from collections import defaultdict
import hat_trie
import threading
import time
import math
import argparse
import json
import os
from tqdm import tqdm


def is_all_chinese(strs):
    """Tell whether every character of ``strs`` is a CJK ideograph.

    A character qualifies when it lies in the inclusive range U+4E00..U+9FA5.
    An empty input yields ``True`` because there is no character to reject.
    """
    return all('\u4e00' <= ch <= '\u9fa5' for ch in strs)


class Trie(object):
    """Character trie in which every node is itself a ``Trie``.

    Each node keeps the word that ends at it (``form``) and how many times
    that word was added (``freq``).
    """

    def __init__(self) -> None:
        # Children keyed by the next character; ``add`` relies on the
        # defaultdict to create missing children on first access.
        self.root = defaultdict(Trie)
        # Complete word ending at this node, '' when no word ends here.
        self.form: str = ''
        # Number of times ``form`` has been added.
        self.freq: int = 0

    def add(self, word: str) -> None:
        """Insert ``word`` (creating nodes as needed) and bump its count by one."""
        node = self
        for c in word:
            node = node.root[c]

        node.form = word
        node.freq += 1

    def find(self, word: str) -> int:
        """Return how many times ``word`` was added, 0 if it never was.

        The lookup is read-only: children are fetched with ``dict.get`` so that
        querying an unknown word does not insert empty nodes into the trie
        (plain indexing of the defaultdict would).
        """
        node = self
        for c in word:
            node = node.root.get(c)
            if node is None:
                return 0

        return node.freq

    def convert2dict(self) -> Dict[str, int]:
        """Return ``{word: count}`` for every word stored below this node.

        Nodes are visited in pre-order (a node before its children, children in
        insertion order).  The walk uses an explicit stack so that very long
        words cannot hit the interpreter's recursion limit.
        """
        result: Dict[str, int] = {}
        pending = [self]
        while pending:
            node = pending.pop()
            if node.form != '':
                result[node.form] = node.freq
            # Reverse so the first child is popped (visited) first.
            pending.extend(reversed(list(node.root.values())))

        return result

    def save(self, save_path: str) -> None:
        """Write the ``convert2dict`` mapping to ``save_path`` as one JSON object."""
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(self.convert2dict(), f)

    @staticmethod
    def load(self, load_path: str) -> None:
        """Placeholder: loading is not implemented; does nothing and returns None.

        NOTE(review): this is declared ``@staticmethod`` yet still takes a
        ``self`` parameter, so callers must pass two positional arguments.
        The signature is left untouched for compatibility.
        """
        return None


class HatTrie(hat_trie.Trie):
    """``hat_trie.Trie`` subclass that stores n-gram statistics keyed by n-gram.

    The code in this file stores two value layouts under a key:

    * a bare count, as written by ``add``;
    * a list whose first element is the count.  ``load(only_freq=True)``
      creates ``[freq, {}, {}]`` and ``_convert2entropy`` rewrites
      ``[freq, lefts, rights]`` into
      ``[freq, left_entropy, right_entropy, lefts, rights]``.

    ``save``, ``convert2dict`` and ``load`` accept both layouts.  The
    ``_count_*``, ``_convert2entropy`` and ``_normalization_entropy`` helpers
    need the list layout described in their own docstrings.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    @staticmethod
    def _freq_of(value):
        """Return the count held by ``value`` in either layout (bare or ``[freq, ...]``)."""
        if isinstance(value, (list, tuple)):
            return value[0]
        return value

    @staticmethod
    def _rescale(value, low, high):
        """Min-max rescale ``value`` into [0, 1]; 0.0 when ``low == high``.

        The zero-span guard avoids a ZeroDivisionError when every key of a
        given length has the same entropy (e.g. only one key of that length).
        """
        span = high - low
        if span == 0:
            return 0.0
        return (value - low) / span

    @staticmethod
    def _entropy(counts, total):
        """Return ``-sum(p * log2(p))`` with ``p = count / total`` over ``counts`` values."""
        entropy = 0.0
        for v in counts.values():
            p = 1.0 * v / total
            entropy -= p * math.log2(p)
        return entropy

    def _count_left_right(self):
        """Fill the neighbour tables of every key from its one-character extensions.

        For each key of length >= 2 with frequency ``freq``:

        * ``key[1:]`` records ``key[0]`` (the character on its left) in the
          table at index 1;
        * ``key[:-1]`` records ``key[-1]`` (the character on its right) in the
          table at index 2.

        Requires the ``[freq, lefts, rights]`` layout, and both shorter
        sub-keys must already be stored.
        """
        for key in self.iterkeys():
            freq = self[key][0]
            if len(key) == 1:
                continue
            self[key[1:]][1][key[0]] = freq
            self[key[:-1]][2][key[-1]] = freq

    def _count_aggregation(self, statistics: Dict) -> None:
        """Append an aggregation score to every entry.

        With ``P(x) = freq(x) / statistics[len(x)]``, each split point ``i`` of
        a key scores ``log2(P(key) / (P(key[:i]) * P(key[i:]))) / -log2(P(key))``.
        The appended value is the minimum over all split points, capped at
        1.0; single-character keys therefore get 1.0.

        Args:
            statistics: total frequency per key length, as returned by ``load``.

        NOTE(review): a key whose frequency equals ``statistics[len(key)]``
        makes ``-log2(P(key))`` zero and raises ZeroDivisionError; both parts
        of every split must also be present in the trie.  Left unchanged
        because the intended fallback value is not evident from the code.
        """
        for key in self.iterkeys():
            size = len(key)
            aggregation = 1.0 #float(1e10)
            for i in range(1, size):
                p_key = self[key][0] / statistics[size]
                p_head = self[key[:i]][0] / statistics[i]
                p_tail = self[key[i:]][0] / statistics[size - i]
                tmp_agg = math.log2(p_key / (p_head * p_tail))
                tmp_agg /= (-math.log2(p_key))
                aggregation = min(aggregation, tmp_agg)

            self[key].append(aggregation)

    def _normalization_entropy(self) -> None:
        """Min-max normalise left/right entropies separately for each key length.

        Expects entries shaped ``[freq, left_entropy, right_entropy, lefts,
        rights]`` with an optional trailing item.  Keys longer than 4
        characters contribute to the per-length extrema but are not rewritten.
        A side whose neighbour table is empty is forced to 0.0.  Rewritten
        entries are ``[freq, left, right]`` when the entry had exactly five
        items, otherwise ``[freq, left, right, <last item>]``.
        """
        max_entropy = {}
        min_entropy = {}

        for key in self.iterkeys():
            items = self[key]
            size = len(key)

            if size not in max_entropy:
                max_entropy[size] = (items[1], items[2])
                min_entropy[size] = (items[1], items[2])
            else:
                max_entropy[size] = (max(items[1], max_entropy[size][0]), max(items[2], max_entropy[size][1]))
                min_entropy[size] = (min(items[1], min_entropy[size][0]), min(items[2], min_entropy[size][1]))

        for key in self.iterkeys():
            if len(key) > 4:
                continue
            items = self[key]
            low_left, low_right = min_entropy[len(key)]
            high_left, high_right = max_entropy[len(key)]
            _left = self._rescale(items[1], low_left, high_left)
            _right = self._rescale(items[2], low_right, high_right)

            if len(items[3]) == 0:
                _left = 0.0
            if len(items[4]) == 0:
                _right = 0.0

            if len(items) == 5:
                self[key] = [items[0], _left, _right]
            else:
                self[key] = [items[0], _left, _right, items[-1]]

    def _convert2entropy(self) -> None:
        """Replace ``[freq, lefts, rights]`` by ``[freq, H_left, H_right, lefts, rights]``.

        Each entropy is computed over the corresponding neighbour table with
        probabilities ``count / freq``.
        """
        for key in self.iterkeys():
            (freq, lefts, rights) = self[key]
            left_entropy = self._entropy(lefts, freq)
            right_entropy = self._entropy(rights, freq)

            self[key] = [int(freq), left_entropy, right_entropy, lefts, rights]

    def convert2dict(self) -> Dict[str, Any]:
        """Return ``{key: value}`` for every key whose count is at least 10.

        The count is read through ``_freq_of`` so both value layouts work
        (comparing a list entry with ``10`` directly raises TypeError).
        """
        dic = {}
        for key in self.keys():
            value = self[key]
            if self._freq_of(value) < 10:
                continue
            dic[key] = value

        return dic

    def add(self, word: str) -> None:
        """Increment the bare count stored for ``word`` (starting at 1).

        Only valid while ``word`` is absent or holds a bare number.
        """
        if word in self:
            self[word] += 1
        else:
            self[word] = 1

    def save(self, save_path: str, thresold: int = 10, max_len: int = 4) -> None:
        """Write one JSON object ``{key: value}`` per line to ``save_path``.

        Args:
            save_path: destination file, overwritten.
            thresold: only keys whose count is strictly greater are written.
            max_len: only keys of at most this many characters are written
                (default 4, the previously hard-coded limit).

        The count is read through ``_freq_of`` so tries filled with bare
        counts (``add``) can be saved as well as list-valued ones.
        Neighbour/entropy post-processing is not run here.
        """
        with open(save_path, 'w', encoding='utf-8') as f:
            for key in self.iterkeys():
                value = self[key]
                if self._freq_of(value) > thresold and len(key) <= max_len:
                    f.write(json.dumps({key: value}, ensure_ascii=False) + '\n')

    @staticmethod
    def load(load_path: str, only_freq: bool = False):
        """Read a file written by ``save`` and rebuild the trie.

        Args:
            load_path: file with one JSON object ``{key: value}`` per line;
                blank lines are skipped.
            only_freq: when exactly ``True`` every value is taken to be a bare
                count and stored as ``[count, {}, {}]``; otherwise the value is
                stored unchanged.

        Returns:
            ``(tree, statistics)`` where ``statistics`` maps key length to the
            summed count (as a float) of all keys of that length.
        """
        statistics = {}
        tree = HatTrie()
        with open(load_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                for k, v in json.loads(line).items():
                    if only_freq is True:
                        freq = v
                        tree[k] = [v, {}, {}]
                    else:
                        freq = HatTrie._freq_of(v)
                        tree[k] = v
                    statistics[len(k)] = statistics.get(len(k), 0.0) + freq

        print(f'load successful, statistics info: {statistics}')
        return tree, statistics


def _merge_trie_values(current, incoming):
    """Combine the values two dumps hold for the same key.

    * Two list entries: the counts at index 0 are added; the tables at
      indices 1 and 2 are summed per character when both sides hold dicts.
      Every other position is taken from ``current`` (derived scores cannot
      be combined without recomputation).
    * Anything else (e.g. bare counts): plain ``current + incoming``.
    """
    if not (isinstance(current, list) and isinstance(incoming, list)):
        return current + incoming

    merged = list(current)
    if merged and incoming:
        merged[0] = current[0] + incoming[0]
    for pos in (1, 2):
        if pos < len(merged) and pos < len(incoming):
            mine, theirs = merged[pos], incoming[pos]
            if isinstance(mine, dict) and isinstance(theirs, dict):
                combined = dict(mine)
                for ch, count in theirs.items():
                    combined[ch] = combined.get(ch, 0) + count
                merged[pos] = combined
    return merged


def merge_tries(input_path: str, output_file: str) -> None:
    """Merge every trie dump in directory ``input_path`` and save the result.

    Each regular file in ``input_path`` is read with ``HatTrie.load``; values
    of keys present in several dumps are combined with ``_merge_trie_values``
    (previously ``+=`` concatenated list entries, so their counts were never
    summed).  The merged trie is saved to ``output_file`` with
    ``thresold=50``.

    Args:
        input_path: directory holding the dumps; joined with ``os.path.join``
            so a trailing separator is no longer required.
        output_file: destination file.  If it already sits inside
            ``input_path`` it is not read back as an input.
    """
    output_abs = os.path.abspath(output_file)

    tree = HatTrie()
    for file in os.listdir(input_path):
        path = os.path.join(input_path, file)
        if not os.path.isfile(path) or os.path.abspath(path) == output_abs:
            continue
        tmp_trie, _ = HatTrie.load(path)
        for key in tmp_trie.keys():
            if key in tree:
                tree[key] = _merge_trie_values(tree[key], tmp_trie[key])
            else:
                tree[key] = tmp_trie[key]

    tree.save(output_file, thresold=50)


def count_ngram_from_file(input_path: str, save_path: str, n: int = 4) -> None:
    """Count all 1..``n``-character grams of a text file into a ``HatTrie`` and save it.

    Every line is stripped, has backslashes replaced by spaces, is split on
    whitespace and re-joined with '，'.  The result is cut into segments at
    the delimiter pattern, and each substring of length 1..``n`` of each
    segment has its count in the trie incremented by one.  The trie is
    finally written with ``tree.save(save_path)``.

    NOTE(review): the delimiter characters are interpolated into a regex
    character class without escaping, so regex metacharacters among the ASCII
    punctuation are interpreted by the regex engine rather than taken
    literally — confirm the resulting split set is the intended one.
    """
    import re
    import string
    from zhon.hanzi import punctuation as chinese_punctuation

    tree = HatTrie()
    delimiters = '|'.join(chinese_punctuation) + '|'.join(string.punctuation)
    delimiters = delimiters[:-6] + delimiters[-4:] + '|．|︰|-|𤞤'
    pattern = f"[{delimiters}]"

    with open(input_path, 'r', encoding='utf-8') as f:
        for raw_line in tqdm(f):
            pieces = raw_line.strip().replace('\\', ' ').split()
            if not pieces:
                continue
            for text in re.split(pattern, '，'.join(pieces)):
                size = len(text)
                for start in range(size):
                    # Windows of 1..n characters that still fit in the segment.
                    for stop in range(start + 1, min(start + n, size) + 1):
                        gram = text[start:stop]
                        if gram in tree:
                            tree[gram] += 1
                        else:
                            tree[gram] = 1

        tree.save(save_path)


def count_certain_gram_from_file(input_path: str, save_path: str, n: int = 4) -> None:
    """Count every 5-character window of a text file into a ``HatTrie`` and save it.

    Lines are pre-processed as follows: strip, replace backslashes by spaces,
    split on whitespace, re-join with '，', then cut at the delimiter
    pattern.  Each 5-character substring of each segment has its count
    incremented by one, and the trie is written with ``tree.save(save_path)``.

    NOTE(review): ``n`` is accepted but never read; the window width is fixed
    at 5.  The delimiter characters are also placed in a regex character class
    without escaping — confirm both are intended.
    """
    import re
    import string
    from zhon.hanzi import punctuation as chinese_punctuation

    tree = HatTrie()
    delimiters = '|'.join(chinese_punctuation) + '|'.join(string.punctuation)
    delimiters = delimiters[:-6] + delimiters[-4:] + '|．|︰|-|𤞤'
    pattern = f"[{delimiters}]"
    width = 5

    with open(input_path, 'r', encoding='utf-8') as f:
        for raw_line in tqdm(f):
            pieces = raw_line.strip().replace('\\', ' ').split()
            if not pieces:
                continue
            for text in re.split(pattern, '，'.join(pieces)):
                for start in range(len(text) - width + 1):
                    gram = text[start:start + width]
                    if gram in tree:
                        tree[gram] += 1
                    else:
                        tree[gram] = 1

        tree.save(save_path)


# class TrieThread(threading.Thread):
#     def __init__(self, file_handle) -> None:
#         threading.Thread.__init__(self)
#         # self.file_handle = file_handle

#     def run(self):
#         while True:
#             Lock.acquire()
#             line = file_handle.readline()
#             if line is None:
#                 break
#             Lock.release()

#             line = line.strip().replace('\\', ' ').split()
#             if len(line) == 0:
#                 continue
#             line = '，'.join(line)
#             texts = re.split(r'' + f"[{punc}]", line)
#             for text in texts:
#                 for i in range(len(text) - 4):
#                     if text[i: i + 5] in tree:
#                         tree[text[i: i + 5]] += 1
#                     else:
#                         tree[text[i: i + 5]] = 1


#     def 


def read_one_line(text: str, tree, lock, n: int = 4):
    """Add every 1..``n``-character substring of ``text`` to ``tree``.

    Args:
        text: one input line; surrounding whitespace is stripped first.
        tree: object exposing ``add(word)``.
        lock: context-manager lock guarding ``tree``; it is taken around each
            individual ``add`` call.
        n: longest substring length to add (default 4, the value that used to
            be hard-coded).

    The lock is held through ``with`` so that it is released even when
    ``tree.add`` raises; a bare acquire/release pair would leave it locked and
    block every other worker.
    """
    text = text.strip()
    size = len(text)
    for start in range(size):
        # Windows of 1..n characters that still fit in the text.
        for stop in range(start + 1, min(start + n, size) + 1):
            with lock:
                tree.add(text[start:stop])


def parallel_count_ngram_from_file(input_path: str, save_path: str) -> None:
    """Count n-grams of ``input_path`` with a process pool and save the trie.

    All lines are read into memory, then ``read_one_line`` is mapped over them
    by a ``multiprocessing.Pool``.  Workers share one ``HatTrie`` served by a
    ``BaseManager`` process and serialise access through a lock served by a
    second manager.  Once every line is processed the shared trie is written
    with its ``save`` method.

    The pool and both manager processes are shut down in ``finally`` /
    ``with`` blocks, so they no longer outlive the call when a worker or the
    save step raises.
    """
    from multiprocessing import managers, Pool, Manager
    from functools import partial

    with open(input_path, 'r', encoding='utf-8') as f:
        datas = f.readlines()

    tree_manager = managers.BaseManager()
    tree_manager.register('HatTrie', HatTrie)
    tree_manager.start()
    try:
        # Manager() is already started; leaving the block shuts its server down.
        with Manager() as lock_manager:
            share_tree = tree_manager.HatTrie()
            share_lock = lock_manager.Lock()
            sub_process_func = partial(read_one_line, tree=share_tree, lock=share_lock)

            pool = Pool()
            try:
                pool.map(sub_process_func, tqdm(datas))
            finally:
                pool.close()
                pool.join()

            share_tree.save(save_path)
    finally:
        tree_manager.shutdown()


def semi_vocab(input_file: str, output_file: str, vocab_file: str):
    """Overlay a vocabulary list on a saved trie and write the result.

    The trie is read from ``input_file`` with ``HatTrie.load``.  For every
    word of ``vocab_file`` (one per line, surrounding whitespace stripped):

    * a word already in the trie gets positions 1 and 2 of its entry set to 1;
    * a missing word is inserted as ``[100, 1, 1, 0.5]``.

    Blank lines are skipped so that no empty-string key is inserted.  The trie
    is then written to ``output_file`` with ``tree.save``.

    NOTE(review): positions 1 and 2 presumably hold the left/right scores and
    the trailing 0.5 an aggregation score — confirm against the producer of
    ``input_file``.
    """
    tree, _ = HatTrie.load(input_file)

    with open(vocab_file, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            word = line.strip()
            if not word:
                continue
            if word in tree:
                tree[word][1] = 1
                tree[word][2] = 1
            else:
                tree[word] = [100, 1, 1, 0.5]

    tree.save(output_file)
