import os
import os.path
import numpy as np
from itertools import compress
from collections import Counter
from faiss_index import create_index

import constants as c
from embeddings import Embeddings
from timer import Timer as PPrintTimer
from load_data import load_queries
from load_data import load_clef_documents
from load_data import load_relevance_assessments
from text2vec import text2vec_idf_sum
from text2vec import create_text_representations

# Module-level timer; its pprint_lap() string is appended to the progress messages printed by
# prepare_experiment and run_experiment.
timer = PPrintTimer()


def _word_overlap(document, query):
    """
    Take items from document and keep if they are in the query, i.e. return intersection of query and document words
    and take the frequency from the document.
    :param document:
    :param query:
    :return:
    """
    document = dict(document)
    query = dict(query)
    return {k: v for k, v in document.items() if k in query.keys()}


def _count_words(document):
    tokens = document
    # {k: v/n_d for k, v in Counter(tokens).items()}
    return dict(Counter(tokens))


def _save_ranking(config_str, all_rankings, base_path):
    """
    Stores ranking in rankings_year_langpair_embspace_aggrMethod.txt file, which is later reused for computing
    ensembled rankings of different aggregations methods.
    :param config_str: csv record string
    :param all_rankings: ranking to be stored
    :param base_path: directory where file should be saved
    :return:
    """
    _, campaign_year, language_pair, embedding_space, aggregation_method, _, _ = config_str.split(";")
    subdir = "rankings_%s_%s_%s_%s.txt" % (campaign_year, language_pair, embedding_space, aggregation_method)
    path = base_path + subdir
    file_content = []
    for query, ranking in all_rankings:
        one_line = str(query) + '; ' + ' '.join(ranking) + "\n"
        file_content.append(one_line)
    file_content = ''.join(file_content)
    with open(path, mode="w") as ranking_file:
        ranking_file.write(file_content)
    pass


def prepare_experiment(doc_dirs, limit_documents, query_file, limit_queries,
                       query_language, relevance_assessment_file):
    """
    Loads documents, evaluation data and queries needed to run different experiments on CLEF data.

    Only the files located directly inside each document directory are read; files ending in ".dtd" are
    skipped. Loading stops as soon as the document limit is reached or exceeded.
    :param doc_dirs: iterable of (directory, extractor) pairs containing the corpora for a specific CLEF
                     campaign; the extractor is passed through to load_clef_documents
    :param limit_documents: for debugging purposes -> limit number of docs loaded (None = no limit)
    :param query_file: CLEF Topics (i.e., query) file
    :param limit_queries: for debugging purposes -> limit number of queries loaded
    :param query_language: language of queries
    :param relevance_assessment_file: relevance assessment file
    :return: tuple (doc_ids, documents, query_ids, queries, relass)
    """
    if limit_documents is not None:
        # NOTE(review): the limit is lowered by one before it is passed to load_clef_documents and compared
        # below; presumably the loader treats its limit as an inclusive index -- confirm against load_data.
        limit_documents -= 1
    documents = []
    doc_ids = []
    limit_reached = False
    for doc_dir, extractor in doc_dirs:
        if limit_reached:
            break
        # First os.walk entry -> names of the files directly inside doc_dir (no recursion).
        for file_name in next(os.walk(doc_dir))[2]:
            if file_name.endswith(".dtd"):
                continue
            # os.path.join also works when doc_dir has no trailing separator.
            tmp_doc_ids, tmp_documents = load_clef_documents(os.path.join(doc_dir, file_name),
                                                             extractor, limit_documents)
            documents.extend(tmp_documents)
            doc_ids.extend(tmp_doc_ids)
            # ">=" instead of "==": the running total can jump past the limit when a file contributes several
            # documents, and an exact-match test would then never fire. Checked only after an actual load so
            # that a limit of 0 cannot stop the loop before any document was read.
            if limit_documents is not None and len(documents) >= limit_documents:
                limit_reached = True
                break
    print("Documents loaded %s" % (timer.pprint_lap()))
    relass = load_relevance_assessments(relevance_assessment_file)
    print("Evaluation data loaded %s" % (timer.pprint_lap()))
    query_ids, queries = load_queries(query_file, language_tag=query_language, limit=limit_queries)
    print("Queries loaded %s" % (timer.pprint_lap()))
    return doc_ids, documents, query_ids, queries, relass


def prepare_word_embeddings(query_lang_emb, qlang_long,
                            doc_lang_emb, dlang_long,
                            limit_emb, normalize=False, processes=40):
    """
    Builds one Embeddings helper object that holds the embeddings of the query language and of the
    document language (loaded in that order).
    :param query_lang_emb: embeddings of the query language, first argument of Embeddings.load_embeddings
    :param qlang_long: identifier of the query language, passed as ``language``
    :param doc_lang_emb: embeddings of the document language, first argument of Embeddings.load_embeddings
    :param dlang_long: identifier of the document language, passed as ``language``
    :param limit_emb: load only first n embeddings
    :param normalize: transform to unit vectors
    :param processes: number of parallel workers
    :return: the populated Embeddings object
    """
    embeddings = Embeddings()
    sources = ((query_lang_emb, qlang_long), (doc_lang_emb, dlang_long))
    for embedding_source, language in sources:
        embeddings.load_embeddings(embedding_source, processes=processes, language=language,
                                   limit=limit_emb, normalize=normalize)
    return embeddings


def evaluate_clef(query_ids, doc_ids, relass, all_rankings):
    """
    Evaluates results for queries in terms of Mean Average Precision (MAP). Evaluation gold standard is
    loaded from the relevance assessments.

    Queries without an entry in the relevance assessments are skipped. For every other query the average
    precision is the mean of precision@rank over the ranks at which a relevant document was retrieved; a query
    for which no relevant document appears in its ranking contributes 0.0 (instead of NaN, which would turn
    the whole MAP into NaN).
    :param query_ids: query ids; the j'th id belongs to the j'th row of all_rankings
    :param doc_ids: document ids; ranking entries are positions in this sequence
    :param relass: gold standard, maps query id -> collection of relevant document ids
    :param all_rankings: (actual) rankings retrieved, one row of document positions per query
    :return: tuple (list of (query_id, ranked document ids), mean average precision); the MAP is NaN if no
             query could be evaluated
    """
    average_precision_values = []
    rankings_with_doc_ids = []
    for j, query_id in enumerate(query_ids):
        if query_id not in relass:
            continue
        relevant_docs = relass[query_id]
        ranking = np.asarray(all_rankings[j]).tolist()  # get ranking for j'th query

        # Negative positions are dropped: FAISS pads result rows with -1 when it finds fewer than k
        # neighbours, and doc_ids[-1] would silently resolve to the last document.
        ranking_with_doc_ids = [doc_ids[i] for i in ranking if i >= 0]
        rankings_with_doc_ids.append((query_id, ranking_with_doc_ids))

        precisions = []
        for rank, ranked_doc in enumerate(ranking_with_doc_ids, 1):  # one-based rank
            if ranked_doc in relevant_docs:
                # number of relevant documents seen so far (including this one) / current rank
                precisions.append((len(precisions) + 1) / rank)
        ap = np.mean(precisions) if precisions else 0.0
        average_precision_values.append(ap)
    if average_precision_values:
        mean_average_precision = np.mean(np.array(average_precision_values))
    else:
        mean_average_precision = np.nan  # nothing to average; avoids numpy's empty-mean warning
    return rankings_with_doc_ids, mean_average_precision


def run_experiment(aggregation_method, query_lang, doc_lang, experiment_data, initialized_embeddings, processes=40):
    """
    Constructs text representations for queries and documents according to the specified aggregation method. From the
    text representations it retrieves for each query the documents and computes the evaluation metric.
    :param aggregation_method: passed as ``method`` to create_text_representations; documents are idf-weighted
                               only when it is text2vec_idf_sum
    :param query_lang: pair (short name, long name) of the query language; only the long name is used
    :param doc_lang: pair (short name, long name) of the document language; only the long name is used
    :param experiment_data: tuple (doc_ids, documents, query_ids, queries, relass)
    :param initialized_embeddings: embeddings object handed to create_text_representations as ``emb``
    :param processes: number of parallel workers for building the text representations
    :return: mean average precision as computed by evaluate_clef
    """
    # unpacking values
    _, qlang_long = query_lang
    _, dlang_long = doc_lang
    doc_ids, documents, query_ids, queries, relass = experiment_data
    embeddings = initialized_embeddings

    doc_arry = create_text_representations(language=dlang_long, id_text=zip(doc_ids, documents),
                                           emb=embeddings, processes=processes, method=aggregation_method,
                                           idf_weighing=aggregation_method == text2vec_idf_sum)
    query_arry = create_text_representations(language=qlang_long, id_text=zip(query_ids, queries),
                                             emb=embeddings, processes=processes, method=aggregation_method,
                                             idf_weighing=False)  # Queries are not idf-scaled
    print("Query- and Document-Embeddings created %s" % (timer.pprint_lap()))

    # keep only documents for which we have a non-zero text embedding, i.e. for which at least one
    # word embedding exists (filters out empty documents). np.any is required here: np.all would also
    # discard every document whose embedding merely contains a single zero component.
    doc_non_zero = np.any(doc_arry != 0, axis=1)
    doc_arry = doc_arry[doc_non_zero]
    doc_ids = list(compress(doc_ids, doc_non_zero))

    # NOTE(review): quantizer is not used below; it is presumably kept bound so that it stays alive
    # while the index is in use -- confirm against faiss_index.create_index.
    index, quantizer = create_index(doc_arry)
    # rank all remaining documents for every query; the distances are not needed
    _, neighbour_positions = index.search(query_arry, len(doc_arry))
    print("Retrieval done %s" % (timer.pprint_lap()))

    _, evaluation_result = evaluate_clef(query_ids=query_ids, doc_ids=doc_ids, relass=relass,
                                         all_rankings=neighbour_positions)
    return evaluation_result


def run(experiment, name, vspace, experiment_count, offset, results, csv_prefix):
    """
    Executes configured experiments and records result in csv record

    The experiment is only executed if experiment_count is greater than offset; otherwise it is skipped and
    just counted.
    :param experiment: configured experiment; a callable without arguments that returns either
                       (rankings, evaluation score) or only the evaluation score
    :param name: name of the aggregation method to be used
    :param vspace: name of the method used for inducing shared embedding space
    :param experiment_count: used for resuming experiments
    :param offset: value of experiment_count from where it should be resumed
    :param results: containing caching result csv records; the new record is appended in place
    :param csv_prefix: string that is prefixed to each csv record
    :return: tuple (experiment_count + 1, results)
    """
    if experiment_count > offset:
        tmp_timer = PPrintTimer().start()  # Experiment timer
        outcome = experiment()
        elapsed = tmp_timer.pprint_stop(suffix=False)
        try:
            rankings, eval_score = outcome
        except TypeError:
            # The experiment returned a bare score (a non-iterable such as a float) instead of a
            # (rankings, score) pair, so there are no rankings to store.
            rankings, eval_score = None, outcome
        result = "%s;%s;%s;%s;%s\n" % (str(experiment_count), csv_prefix + vspace, name, elapsed, str(eval_score))
        if rankings is not None:
            _save_ranking(result, rankings, c.RESULTS_DIR)
        results.append(result)
        print("\n" + result)
    return experiment_count + 1, results


