from .Trie import HatTrie
from typing import List, Dict, Any, Optional
from tqdm import tqdm
import math
import json


def _trie_field(trie, ngram: str, field: int, default: float = 0.0) -> float:
    """Return ``float(trie[ngram][field])``, or ``default`` if ``ngram`` is not in ``trie``."""
    if ngram in trie:
        return float(trie[ngram][field])
    return default


def compute_statistics_info(text: str, trie: HatTrie, N: int = 4, alpha: float = 1, get_t_test: bool = False) -> List[List[float]]:
    """Build one statistical feature vector per character of ``text``.

    ``text`` is stripped of surrounding whitespace first, so the result has
    ``len(text.strip())`` rows.  For the character at position ``i`` the row
    is the concatenation, in this order, of:

    * **aggregation** -- for every n-gram length ``2 .. N`` (shortest first),
      field 3 of the trie entry, multiplied by ``alpha``, for each window of
      that length that covers position ``i`` (leftmost window first).  This
      contributes ``2 + 3 + ... + N`` values.
    * **left entropy** -- field 1 of the trie entry for the n-grams of length
      ``1 .. N`` that *start* at ``i`` (``N`` values).
    * **right entropy** -- field 2 of the trie entry for the n-grams of length
      ``1 .. N`` that *end* at ``i`` (``N`` values).
    * **t-test** -- a single value, only when ``get_t_test`` is truthy (see the
      inline comment for the formula).

    Any n-gram that would run past either end of the text, or that is not in
    the trie, contributes ``0.0``.

    Args:
        text: The sentence to featurise.
        trie: Mapping from n-gram to an indexable record; only ``in`` and
            ``[]`` are used.  Fields 1, 2 and 3 are read as the left entropy,
            right entropy and aggregation score; field 0 is read by the t-test
            (presumably the n-gram frequency -- confirm against the trie
            builder).
        N: Longest n-gram length considered; must be positive.
        alpha: Multiplier applied to every aggregation value found in the trie.
        get_t_test: Append the t-test feature to every row.

    Returns:
        A list with one list of floats per character of the stripped text.

    Raises:
        ValueError: If ``N`` is not greater than zero.
    """
    text = text.strip()

    # A real exception rather than ``assert``: assertions vanish under ``python -O``.
    if N <= 0:
        raise ValueError('N must greater than zero.')

    length = len(text)
    statistics_info = []

    for i in range(length):
        char = text[i]
        t_test = []
        aggregation = []
        left_entropy = [_trie_field(trie, char, 1)]
        right_entropy = [_trie_field(trie, char, 2)]

        for j in range(1, N):
            # Slide a window of j + 1 characters over every offset that still covers i.
            for shift in range(j + 1):
                start = i - j + shift
                end = i + 1 + shift
                window = text[start:end] if (start >= 0 and end <= length) else None
                if window is not None and window in trie:
                    aggregation.append(float(trie[window][3]) * alpha)
                else:
                    aggregation.append(0.0)

            # n-gram of length j + 1 ending at i; guard the start so a negative
            # index never wraps around to the end of the string.
            if i - j >= 0:
                right_entropy.append(_trie_field(trie, text[i - j: i + 1], 2))
            else:
                right_entropy.append(0.0)

            # n-gram of length j + 1 starting at i.
            if i + j < length:
                left_entropy.append(_trie_field(trie, text[i: i + j + 1], 1))
            else:
                left_entropy.append(0.0)

        if get_t_test:
            has_prev = i >= 1
            has_next = i + 1 < length
            # x = previous char, y = current char, z = next char.  Missing
            # bigram values default to 0, missing single-char values to 1 so
            # the divisions below stay defined.
            xy = _trie_field(trie, text[i - 1: i + 1], 0) if has_prev else 0.0
            yz = _trie_field(trie, text[i: i + 2], 0) if has_next else 0.0
            x = _trie_field(trie, text[i - 1], 0, 1.0) if has_prev else 1.0
            y = _trie_field(trie, char, 0, 1.0)
            # (yz/y - xy/x) scaled by sqrt(yz/y^2 + xy/x^2 + 1); the +1 keeps
            # the denominator non-zero when both bigrams are absent.
            res = ((yz / y) - (xy / x)) / math.sqrt((yz / y / y) + (xy / x / x) + 1.0)
            t_test.append(res)

        statistics_info.append(aggregation + left_entropy + right_entropy + t_test)

    return statistics_info


def convert_corpus2statistics(corpus_path: str, out_path: str, trie: HatTrie) -> None:
    """Turn a whitespace-segmented corpus into per-character JSON-lines records.

    Each line of ``corpus_path`` is split on whitespace into words.  Every
    character of every word yields one JSON object on its own line in
    ``out_path`` with the keys:

    * ``feature``   -- that character's row from ``compute_statistics_info``,
      computed on the words joined without separators;
    * ``tag``       -- ``1`` for the first character of a word, ``0`` otherwise;
    * ``meta_data`` -- the character itself.

    An empty line is written after the records of each input line (including
    input lines that produced no records).

    Args:
        corpus_path: UTF-8 text file, one whitespace-segmented sentence per line.
        out_path: Destination file; opened in write mode with UTF-8 encoding.
        trie: N-gram statistics store passed through to ``compute_statistics_info``.

    Raises:
        AssertionError: If the number of feature rows differs from the number
            of characters in the sentence.
    """
    with open(corpus_path, 'r', encoding='utf-8') as corpus, open(out_path, 'w', encoding='utf-8') as out:
        for line in tqdm(corpus):
            words = line.strip().split()

            # Flatten the words into characters, tagging each word's first character with 1.
            chars = [ch for word in words for ch in word]
            tags = [1 if pos == 0 else 0 for word in words for pos in range(len(word))]

            features = compute_statistics_info(''.join(words), trie)
            assert len(features) == len(tags), ValueError(f'please check sentence: {line}')

            for feature, tag, c in zip(features, tags, chars):
                record = {'feature': feature, 'tag': tag, 'meta_data': c}
                out.write(json.dumps(record, ensure_ascii=False) + '\n')
            # Blank line separates consecutive sentences in the output.
            out.write('\n')


def _side_score(feature, pmi_weights, entropies, entropy_pmi_index, entropy_means):
    """Combine one feature row into a single score for one side of a character.

    Starts from the weighted sum of the first nine (aggregation) features, then
    for each entropy value subtracts its deviation from the matching reference
    value -- unscaled when the paired index is negative, otherwise scaled by
    one tenth of the aggregation feature at that index.
    """
    score = 0.0
    for weight, pmi in zip(pmi_weights, feature[:9]):
        score += weight * float(pmi)

    for pmi_idx, entropy, mean in zip(entropy_pmi_index, entropies, entropy_means):
        deviation = float(entropy) - mean
        if pmi_idx < 0:
            score -= deviation
        else:
            score -= float(feature[pmi_idx]) / 10 * deviation
    return score


def compute_prob_from_statistics(text: str, trie: HatTrie):
    """Score every character of ``text`` with hand-tuned left/right weights.

    The features come from ``compute_statistics_info`` with ``N=4`` and no
    t-test, so each row has 17 values: indices 0-8 are aggregation scores,
    9-12 the left entropies and 13-16 the right entropies.  The constants
    below are tied to that layout, which is why ``N`` is pinned explicitly.

    No normalisation is applied: despite the function name, the returned
    values are raw scores rather than probabilities in ``[0, 1]``.

    Args:
        text: Sentence to score (surrounding whitespace is stripped by
            ``compute_statistics_info``).
        trie: N-gram statistics store passed through to
            ``compute_statistics_info``.

    Returns:
        A pair ``(probs, statistics_info)`` where ``probs`` holds one
        ``(left_prob, right_prob)`` tuple per character and
        ``statistics_info`` is the feature matrix the scores were derived from.
    """
    # Reference value subtracted from each of the 8 entropy features
    # (presumably corpus-level means -- the source does not say how they were obtained).
    mean_entropy = [0.5844165179499213, 0.3356532034636478, 0.1799745334967306, 0.054247150612504874, 0.5718778603979453, 0.3135597018123848, 0.12084713658312636, 0.1370869802059968]
    # Weights applied to the 9 aggregation features for the left / right score.
    left_PMI_weights = [0.5, -0.5 / 2, 0.33, 0.17, -0.33 / 2, 0.25, -0.125 / 2, 0.125, -0.25 / 2]
    right_PMI_weights = [-0.5 / 2, 0.5, -0.33 / 2, 0.17, 0.33, -0.25 / 2, 0.125, -0.125 / 2, 0.25]
    # For each entropy feature, the index of the aggregation feature computed
    # on the same n-gram; -1 for the single-character entropies, which have none.
    entropy_2_pmi = [-1, 1, 4, 8, -1, 0, 2, 5]

    statistics_info = compute_statistics_info(text, trie, N=4)
    probs = []

    for statistics_feature in statistics_info:
        left_prob = _side_score(
            statistics_feature,
            left_PMI_weights,
            statistics_feature[9:13],
            entropy_2_pmi[:4],
            mean_entropy[:4],
        )
        right_prob = _side_score(
            statistics_feature,
            right_PMI_weights,
            statistics_feature[13:],
            entropy_2_pmi[4:],
            mean_entropy[4:],
        )
        probs.append((left_prob, right_prob))

    return probs, statistics_info
