import os
import pickle
import argparse
import pandas as pd
from gensim.models import KeyedVectors

from utils.zeroshot_utils import TextPreprocessor, Evaluator
from models.lexical_models import TFIDFRetriever, BM25Retriever, SWSNRetriever
from models.zeroshot_dense_models import Word2vecRetriever, FasttextRetriever, BERTRetriever



def display_scores(all_ground_truths, all_results, precision_range=(), recall_range=(20, 50, 100)):
    """Compute retrieval metrics over all queries and print them as a two-row table.

    MAP and MRR are always reported, followed by Recall@k for every k in
    ``recall_range`` and then Precision@k for every k in ``precision_range``.
    Each score is printed as a percentage with two decimals.

    Args:
        all_ground_truths: One entry per query holding its relevant article ids
            (``main`` passes a list of int ids per question).
        all_results: One entry per query holding the retrieved results, in
            whatever format ``Evaluator`` expects (presumably ranked article
            ids as returned by ``search_all`` -- confirm against Evaluator).
        precision_range: Cut-offs k for which Precision@k is reported.
            Defaults to none.
        recall_range: Cut-offs k for which Recall@k is reported.
    """
    # Immutable tuple defaults avoid the shared-mutable-default pitfall; any
    # iterable of ints (e.g. a list) is still accepted by callers.
    evaluator = Evaluator()

    def mean_score(metric_fn, **kwargs):
        # Average a per-query metric over every query in the evaluation set.
        return evaluator.compute_mean_score(metric_fn, all_ground_truths, all_results, **kwargs)

    # MAP and MRR do not depend on a cut-off.
    metrics = ['MAP', 'MRR']
    scores = [
        mean_score(evaluator.average_precision),
        mean_score(evaluator.reciprocal_rank),
    ]

    # Recall@k.
    for k in recall_range:
        metrics.append(f'R@{k}')
        scores.append(mean_score(evaluator.recall, at=k))

    # Precision@k.
    for k in precision_range:
        metrics.append(f'P@{k}')
        scores.append(mean_score(evaluator.precision, at=k))

    # Print the metric names and their scores in aligned 10-character columns.
    print((len(metrics) * '{:<10}').format(*metrics))
    print((len(scores) * '{:<10.2%}').format(*scores))


def _build_retriever(model_name, articles):
    """Instantiate the retriever named ``model_name`` over the ``articles`` corpus.

    Raises:
        ValueError: If ``model_name`` is not one of the supported retrievers.
    """
    if model_name == 'tfidf':
        return TFIDFRetriever(retrieval_corpus=articles)
    if model_name == 'bm25':
        return BM25Retriever(retrieval_corpus=articles, k1=1.2, b=0.75)
    if model_name == 'word2vec':
        # NOTE(review): this embedding file lives under 'lemmatized/'; confirm
        # whether word2vec runs are expected to be launched with --lem.
        return Word2vecRetriever(model_path_or_name='embeddings/word2vec/lemmatized/word2vec_frWac_lem_skipgram_d500.bin', pooling_strategy='mean', retrieval_corpus=articles)
    if model_name == 'fasttext':
        return FasttextRetriever(model_path_or_name='embeddings/fasttext/fasttext_frCc_cbow_d300.bin', pooling_strategy='mean', retrieval_corpus=articles)
    if model_name == 'bert':
        return BERTRetriever(model_path_or_name='camembert-base', pooling_strategy='mean', retrieval_corpus=articles)
    raise ValueError("Unknown retriever model: {!r}".format(model_name))


def _parse_article_ids(raw_ids):
    """Turn a comma-separated cell such as ``"12,57"`` into a list of ints.

    ``str()`` is applied first because pandas infers an integer dtype for the
    column when every question has a single id, and ints have no ``split``.
    """
    return [int(article_id) for article_id in str(raw_ids).split(',')]


def main(args):
    """Run zero-shot retrieval on the test questions, print metrics and pickle results.

    Args:
        args: Parsed CLI namespace providing ``articles_path``,
            ``test_questions_path``, ``lem``, ``retriever_model`` and
            ``output_dir``.

    Raises:
        ValueError: If ``args.retriever_model`` is not a supported retriever.
    """
    print("Loading questions and articles...")
    dfA = pd.read_csv(args.articles_path)
    dfQ_test = pd.read_csv(args.test_questions_path)

    if args.retriever_model != 'bert':
        # Lexical and static-embedding models work on cleaned (optionally
        # lemmatized) text; BERT receives the raw strings.
        print("Preprocessing articles and questions (lemmatizing={})...".format(args.lem))
        cleaner = TextPreprocessor(spacy_model="fr_core_news_md")
        articles = cleaner.preprocess(dfA['article'], lemmatize=args.lem)
        questions = cleaner.preprocess(dfQ_test['question'], lemmatize=args.lem)
    else:
        articles = dfA['article'].tolist()
        questions = dfQ_test['question'].tolist()

    print("Initializing the {} retriever model...".format(args.retriever_model))
    retriever = _build_retriever(args.retriever_model, articles)

    print("Running model on test questions...")
    if args.retriever_model in ('tfidf', 'bm25'):
        results = retriever.search_all(questions, top_k=100)
    else:
        # Dense retrievers rank by embedding similarity.
        results = retriever.search_all(questions, top_k=100, dist_metric='cosine')

    print("Displaying the results...")
    ground_truths = [_parse_article_ids(raw_ids) for raw_ids in dfQ_test['article_ids']]
    display_scores(ground_truths, results)

    print("Saving the results to {} ...".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, f"{args.retriever_model}.pkl")
    with open(output_path, 'wb') as f:
        pickle.dump(results, f)


if __name__ == '__main__':
    # Command-line entry point: collect the data paths, model choice and
    # output location, then hand them to main().
    parser = argparse.ArgumentParser()

    # Input data.
    parser.add_argument(
        "--articles_path",
        type=str,
        default="../../data/final/articles_fr.csv",
        help="Path of the data file containing the law articles.",
    )
    parser.add_argument(
        "--test_questions_path",
        type=str,
        default="../../data/final/questions_fr_test.csv",
        help="Path of the data file containing the test questions.",
    )

    # Preprocessing and model selection ('store_true' flags default to False).
    parser.add_argument(
        "--lem",
        action='store_true',
        help="Lemmatize the questions and articles for retrieval.",
    )
    parser.add_argument(
        "--retriever_model",
        type=str,
        required=True,
        choices=["tfidf", "bm25", "word2vec", "fasttext", "bert"],
        help="The type of model to use for retrieval",
    )

    # Output.
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./output/zeroshot/test-run/",
        help="Path of the output directory.",
    )

    # parse_known_args tolerates (and discards) unrecognized command-line flags.
    args = parser.parse_known_args()[0]
    main(args)
