from experiment_clef import prepare_experiment as prepare_clef_experiment
from experiment_clef import run_experiment as run_clef_experiment
from embeddings import Embeddings
import constants as c
from text2vec import text2vec_idf_sum
from collection_extractors import *
import os
import sys
import argparse


# Supported languages as (short code, long name) pairs.
# NOTE(review): "nl" has document collections configured in _all but is not
# listed here, so it cannot be resolved through short2pair / long2pair.
languages = [
  ("de", "german"),
  ("en", "english"),
  ("ru", "russian"),
  ("fi", "finnish"),
  ("it", "italian"),
]

# Resolve either spelling of a language to its full (code, name) pair.
short2pair = {code: (code, name) for code, name in languages}
long2pair = {name: (code, name) for code, name in languages}

# CLEF document collections, keyed by campaign year. Each collection is a
# (directory, extractor) tuple: the directory lives under
# c.PATH_BASE_DOCUMENTS and the extractor comes from collection_extractors
# (presumably it parses the files in that directory).

# Dutch: one collection, used for every year.
nl_all = (c.PATH_BASE_DOCUMENTS + "dutch/all/", extract_dutch)
dutch = {campaign: [nl_all] for campaign in ("2001", "2002", "2003")}

# Italian: la_stampa + sda_italian every year; agz95 is added in 2003.
it_lastampa = (c.PATH_BASE_DOCUMENTS + "italian/la_stampa/", extract_italian_lastampa)
it_sda94 = (c.PATH_BASE_DOCUMENTS + "italian/sda_italian/", extract_italian_sda9495)
it_sda95 = (c.PATH_BASE_DOCUMENTS + "italian/agz95/", extract_italian_sda9495)
italian = {campaign: [it_lastampa, it_sda94] for campaign in ("2001", "2002")}
italian["2003"] = [it_lastampa, it_sda94, it_sda95]

# Finnish: no collection is configured for 2001.
aamu9495 = c.PATH_BASE_DOCUMENTS + "finnish/aamu/"
fi_ammulethi9495 = (aamu9495, extract_finish_aamuleth9495)
finnish = {
  "2001": None,
  "2002": [fi_ammulethi9495],
  "2003": [fi_ammulethi9495],
}

# English: GH95 and LA Times for every configured year.
gh95 = (c.PATH_BASE_DOCUMENTS + "english/GH95/", extract_english_gh)
latimes = (c.PATH_BASE_DOCUMENTS + "english/latimes/", extract_english_latimes)
english = {campaign: [gh95, latimes] for campaign in ("2001", "2002", "2003")}

# German: only 2003 is configured.
der_spiegel = (c.PATH_BASE_DOCUMENTS + "german/der_spiegel/", extract_german_derspiegel)
fr_rundschau = (c.PATH_BASE_DOCUMENTS + "german/fr_rundschau/", extract_german_frrundschau)
de_sda94 = (c.PATH_BASE_DOCUMENTS + "german/sda94/", extract_german_sda)
de_sda95 = (c.PATH_BASE_DOCUMENTS + "german/sda95/", extract_german_sda)
german = {
  "2003": [der_spiegel, fr_rundschau, de_sda94, de_sda95],
}

# Russian: only 2003 is configured.
xml = (c.PATH_BASE_DOCUMENTS + "russian/xml/", extract_russian)
russian = {
  "2003": [xml],
}


# Short language code -> {year: list of collections}; run() indexes this
# with the document language's code and the module-level year.
_all = dict(nl=dutch, it=italian, fi=finnish, en=english, de=german, ru=russian)

# Run-wide experiment settings, read by run().
# Both limits are passed straight to prepare_clef_experiment; set them (e.g.
# 10 queries / 100 documents) to shrink a run for testing or debugging.
query_limit, doc_limit = None, None
year = "2003"  # CLEF campaign: picks the query file, the qrels file and the collection in _all
processes = 20  # forwarded to run_clef_experiment
embeddings = Embeddings()  # one shared instance; run() may load serialized embeddings into it


def run_compact(params):
  """Call :func:`run` with its keyword arguments packed into one mapping.

  Having a single positional parameter presumably makes this usable with
  map-style APIs that pass exactly one argument per task.

  Args:
    params: Mapping of ``run`` parameter names to values.

  Returns:
    Whatever ``run`` returns.
  """
  keyword_args = params
  return run(**keyword_args)


def _lang_pair(lang):
  """Return the (short code, long name) pair for a language.

  A two-character ``lang`` is looked up in ``short2pair``; anything else is
  looked up in ``long2pair``. Unknown languages raise ``KeyError``.
  """
  return short2pair[lang] if len(lang) == 2 else long2pair[lang]


def run(query_lang,  doc_lang,
        path_query_embeddings="", path_query_vocab="",
        path_doc_embeddings="", path_doc_vocab=""):
  """Run one CLEF cross-lingual retrieval experiment and print/return its evaluation.

  Queries in ``query_lang`` are retrieved against the module-level ``year``
  collection of ``doc_lang`` with ``text2vec_idf_sum`` and the shared
  module-level ``embeddings``. Everything printed while the experiment runs
  is discarded; only the final evaluation result is printed.

  Args:
    query_lang: Query language, as a short code (e.g. "en") or a long name
      (e.g. "english") listed in ``languages``.
    doc_lang: Document language, in the same format as ``query_lang``.
    path_query_embeddings: Serialized embeddings for the query language.
    path_query_vocab: Vocabulary for the query language. If empty, no
      embeddings are loaded for either language.
    path_doc_embeddings: Serialized embeddings for the document language.
    path_doc_vocab: Vocabulary for the document language.

  Returns:
    The value returned by ``run_clef_experiment``.

  Raises:
    KeyError: If a language is unknown, or if ``doc_lang`` has no collection
      configured for ``year``.
  """
  src_lang = _lang_pair(query_lang)
  tgt_lang = _lang_pair(doc_lang)

  # Discard anything printed while loading and running the experiment.
  # try/finally restores sys.stdout even if a step raises, and the with-block
  # closes the devnull handle. Previously an exception left stdout pointing
  # at devnull, and the handle was never closed.
  saved_stdout = sys.stdout
  with open(os.devnull, "w") as devnull:
    sys.stdout = devnull
    try:
      if path_query_vocab:
        # NOTE(review): document-side embeddings are only loaded when a query
        # vocabulary is given; confirm this coupling is intended.
        embeddings.load_serialized_embeddings(path_query_vocab, path_query_embeddings, src_lang[1])
        embeddings.load_serialized_embeddings(path_doc_vocab, path_doc_embeddings, tgt_lang[1])

      current_path_queries = c.PATH_BASE_QUERIES + year + "/Top-" + src_lang[0] + year[-2:] + ".txt"
      current_path_documents = _all[tgt_lang[0]][year]
      current_assessment_file = c.PATH_BASE_EVAL + year + "/qrels_" + tgt_lang[1]
      # NOTE(review): the raw ``query_lang`` argument (short or long form, as
      # given by the caller) is passed here, whereas run_clef_experiment gets
      # the resolved pair; confirm which form prepare_clef_experiment expects.
      current_experiment_data = prepare_clef_experiment(current_path_documents, doc_limit, current_path_queries,
                                                        query_limit, query_lang, current_assessment_file)
      evaluation_result = run_clef_experiment(text2vec_idf_sum, query_lang=src_lang, doc_lang=tgt_lang,
                                              experiment_data=current_experiment_data,
                                              processes=processes, initialized_embeddings=embeddings)
    finally:
      sys.stdout = saved_stdout

  print(evaluation_result)
  return evaluation_result


if __name__ == "__main__":
  # Command-line entry point: parse the options and forward them to run().
  parser = argparse.ArgumentParser()
  lang_choices = short2pair.keys()

  # Query side: language plus optional serialized vocabulary/embeddings.
  parser.add_argument("--query_lang", type=str, required=True, choices=lang_choices)
  parser.add_argument("--path_query_vocab", type=str, default="")
  parser.add_argument("--path_query_embeddings", type=str, default="")

  # Document side: the same options for the collection language.
  parser.add_argument("--doc_lang", type=str, required=True, choices=lang_choices)
  parser.add_argument("--path_doc_vocab", type=str, default="")
  parser.add_argument("--path_doc_embeddings", type=str, default="")

  args = parser.parse_args()

  # Each parsed option name matches a keyword parameter of run() exactly.
  run(**vars(args))
