import os
import logging
import random
import numpy as np

from collections import Counter
from .ioFn import readJsonl

logger = logging.getLogger()

def createHashMap(datas: list, key: str):
    """
    Index the records in ``datas`` by the normalized text stored under ``key``.

    Each record's ``key`` value is stripped, lower-cased and passed through
    ``normalize`` to form the lookup key. When several records normalize to
    the same key, the one appearing last in ``datas`` is kept.
    """
    return {
        normalize(record[key].strip().lower()): record
        for record in datas
    }

def findNN(query: str, keys: list):
    """
    Return the entry of ``keys`` whose whitespace tokens best overlap ``query``.

    Similarity is a smoothed token-set F1 (1e-3 is added to each denominator
    so empty token sets never divide by zero). The first key reaching the
    highest score wins; if no key scores above 0.0 the empty string is
    returned.
    """
    eps = 1e-3
    query_tokens = set(query.split())

    def score(candidate):
        # Smoothed F1 between the query token set and the candidate's.
        cand_tokens = set(candidate.split())
        shared = len(query_tokens & cand_tokens)
        precision = shared / (len(cand_tokens) + eps)
        recall = shared / (len(query_tokens) + eps)
        return 2 * precision * recall / (precision + recall + eps)

    best_key, best_score = "", 0.0
    for candidate in keys:
        current = score(candidate)
        # Strict comparison keeps the earliest key on ties.
        if current > best_score:
            best_key, best_score = candidate, current
    return best_key

def normalize(s: str):
    """
    Canonicalize a sentence for lookup.

    Removes every backslash, double quote and period, lower-cases the text,
    and collapses each run of whitespace into a single space (leading and
    trailing whitespace is dropped).
    """
    # One pass that deletes the three unwanted characters.
    drop_table = str.maketrans("", "", "\\\".")
    lowered = s.translate(drop_table).lower()
    # split() with no argument discards leading/trailing whitespace too.
    return " ".join(lowered.split())

def matchEmb(data_infos, input_dir, embkey, balance=False):
    """
    Return the list of sentence embedding infos matched to document sentences.

    Embedding records are read with ``readJsonl`` from
    ``os.path.join(input_dir, embkey)`` and indexed by their normalized
    ``"sent"`` text. Every non-empty (after normalization) sentence of every
    document is looked up in that index. Each hit produces a record tagged
    with ``id`` (``"<doc id>_<sentence index>"``) and ``is_oracle`` (whether
    the sentence index is in the document's ``label``).

    Args:
        data_infos: iterable of dicts providing ``'document'`` (sequence of
            sentence strings), ``'id'`` and ``'label'`` (container of oracle
            sentence indices).
        input_dir: directory containing the embedding file.
        embkey: name of the embedding file inside ``input_dir``.
        balance: when True, keep only the first ``k`` positive and first
            ``k`` negative records, where ``k`` is the size of the smaller
            class, and return positives followed by negatives.

    Returns:
        A list of dicts, one per matched sentence occurrence. Each dict is an
        independent shallow copy of the embedding record, so a sentence that
        occurs several times keeps a distinct ``id`` / ``is_oracle`` per
        occurrence.
    """
    emb_infos = readJsonl(os.path.join(input_dir, embkey))
    hashed_emb_info = createHashMap(emb_infos, "sent")

    missed_sents = []
    matched_infos = []
    for dinfo in data_infos:
        sents = dinfo['document']
        doc_id = dinfo['id']
        oracles = dinfo['label']
        for (i, sent) in enumerate(sents):
            sent = normalize(sent.strip().lower())
            if sent == "":
                continue
            emb_info = hashed_emb_info.get(sent)
            if emb_info is None:
                missed_sents.append([sent, i])
                continue
            # Copy before tagging: the same embedding record can match
            # several sentence occurrences, and mutating the shared dict
            # would overwrite the id/is_oracle of earlier matches.
            matched = dict(emb_info)
            matched['id'] = "{}_{}".format(doc_id, i)
            matched['is_oracle'] = (i in oracles)
            matched_infos.append(matched)

    # Log a small sample of misses with their closest known sentence to help
    # diagnose normalization mismatches.
    candidate_keys = list(hashed_emb_info)
    for missed_sent, _ in missed_sents[:10]:
        nn = findNN(missed_sent, candidate_keys)
        logger.warning("missed_sent: {}".format(missed_sent))
        logger.warning("nearest neighbour: {}\n".format(nn))

    # The sentence index is the last '_' field of the id; rsplit keeps
    # document ids that themselves contain underscores intact.
    cnt = Counter(item['id'].rsplit('_', 1)[0] for item in matched_infos)

    logger.warning("number of documents: {}".format(len(cnt)))
    logger.warning("number of missed sent: {}".format(len(missed_sents)))
    logger.warning("number of matched sent: {}".format(len(matched_infos)))

    positives, negatives = [], []
    for item in matched_infos:
        if item['is_oracle']:
            positives.append(item)
        else:
            negatives.append(item)

    if balance:
        size = min(len(positives), len(negatives))
        positives = positives[:size]
        negatives = negatives[:size]
        matched_infos = positives + negatives

    # Report the class sizes of what is actually returned (after balancing).
    logger.warning("label 1: {}, label 0: {}".format(
        len(positives),
        len(negatives)
    ))
    return matched_infos

def centeringFn(data_infos):
    """
    Center the embeddings by subtracting their centroid.

    The centroid is the float32 mean, along axis 0, of every record's
    ``'embedding'``. Each record's ``'embedding'`` is replaced in place by
    its value minus that centroid, and the same ``data_infos`` object is
    returned.
    """
    stacked = np.array(
        [entry['embedding'] for entry in data_infos],
        dtype=np.float32,
    )
    centroid = stacked.mean(axis=0)
    logger.warning("centroid.shape: {}".format(centroid.shape))

    for position in range(len(data_infos)):
        entry = data_infos[position]
        # Subtract from the record's own value (not the float32 stack) so the
        # result keeps whatever dtype promotion the stored embedding implies.
        entry['embedding'] = entry['embedding'] - centroid
    return data_infos