import os
import sys
sys.path.append("/home/[USER]/workshop/wikihow")
os.chdir("/home/[USER]/workshop/wikihow")
import json
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from collections import defaultdict
from tqdm import tqdm
import re
import argparse
import logging
# Root logging setup: only records at WARNING or above are emitted
# (level order: DEBUG < INFO < WARNING < ERROR < CRITICAL), so the
# logger.info(...) progress messages in this script are hidden by default.
logging.basicConfig(
    level=logging.WARNING,
    format='%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# Hand-picked example step sentences; the functions and classes in this
# script use their own local `steps` variables rather than this list.
steps = [
    'Figure out what time you need to wake up in the morning.',
    'Repeat your mantra.',
    'Cover a medium-sized baking tray with foil.',
    'Do a range check... with the antenna on the transmitter collapsed, walk 50–100 feet (15.2–30.5 m) away and check that your control surfaces still respond without chattering or unwanted movements.',
    'Confront mean girls using the divide-and-conquer approach.',
    'Sign up to offer your music on iTunes.',
    'Brush the hair on your crown.',
    'Share news and updates about phishing attacks with your Facebook friends.',
    'Plan a "set" (a list of songs to be played) before setting foot inside a DJ booth.',
]


# Name of the Elasticsearch index that is built and queried.
INDEX = 'wikihow_en'
# Not referenced by the visible code in this script.
DEBUG = False
# When True, constructing SESearch drops, recreates and re-populates INDEX.
FROM_SCRATCH = False

def get_steps(step_list):
    """Convert raw step dicts into cleaned, non-empty step strings.

    Each item must provide ``'headline'`` and ``'description'``; they are
    joined with a single space and then scrubbed by three regexes, in order:

    1. drop any ``{...}`` span containing ``http``, ``png``, ``jpg`` or ``www``;
    2. drop from a ``{`` preceding ``smallUrl`` through the *last* ``<p>`` on
       the line (the middle ``.*`` is greedy);
    3. drop from ``<a`` through the end of the next tag after it, which for
       ``<a href=...>text</a>`` removes the link text as well.

    The text is stripped after every substitution. Items that end up empty
    are omitted, and input order is preserved.
    """
    cleaned = []
    for item in step_list:
        text = f"{item['headline']} {item['description']}"
        text = re.sub(r'\{.*?(http|png|jpg|www)+.*?\}', '', text).strip()
        text = re.sub(r'\{.*?(smallUrl).*(<p>)', '', text).strip()
        text = re.sub(r'<a.*?<.*?>', '', text).strip()  # a href
        text = text.strip()
        if text:
            cleaned.append(text)
    return cleaned


def get_page_content(d):
    """Flatten one WikiHow page dict into a single searchable string.

    The result starts with the task name: the title with a leading
    ``"How to "`` removed, passed through ``str.capitalize`` and suffixed with
    ``"."``. It is followed by the cleaned step texts from ``get_steps``.
    Steps come from every entry of ``d["methods"]`` when that is non-empty,
    otherwise from ``d["steps"]``. A missing key and a ``None`` value are
    both treated as "no steps".

    Args:
        d: Parsed page JSON. Reads ``"title"``, and optionally ``"methods"``
            (a list of dicts, each with a ``"steps"`` list) and ``"steps"``.

    Returns:
        The space-joined page text, or ``""`` when the page has neither
        methods nor steps.

    Raises:
        ValueError: if the page has both non-empty ``methods`` and non-empty
            ``steps``, which leaves the step source ambiguous.
    """
    methods = d.get("methods")
    top_level_steps = d.get("steps")
    if methods:
        # Explicit check instead of `assert`, so it still runs under `python -O`.
        if top_level_steps:
            raise ValueError(
                f"page {d.get('title')!r} has both 'methods' and 'steps'")
        step_lists = [method["steps"] for method in methods]
    elif top_level_steps:
        step_lists = [top_level_steps]
    else:
        return ""

    title = d["title"]
    prefix = "How to "
    task_name = title[len(prefix):] if title.startswith(prefix) else title
    steps = [task_name.capitalize() + "."]
    for step_list in step_lists:
        steps += get_steps(step_list=step_list)
    return " ".join(steps)

class SESearch:
    """Elasticsearch client wrapper that can (re)build the WikiHow index.

    When the module-level flag ``FROM_SCRATCH`` is True, construction drops
    ``INDEX`` (ignoring 400/404 responses), recreates it, and bulk-loads every
    page JSON found under ``dpath``.
    """

    def __init__(self, host='metis.lti.cs.cmu.edu', dpath="./data/wikihow/js_files_en/"):
        """Connect to Elasticsearch and optionally rebuild the index.

        Args:
            host: Elasticsearch host, passed to the client as ``host=``.
            dpath: Root directory that is walked recursively for page JSON
                files when the index is rebuilt.
        """
        self.es = Elasticsearch(timeout=60, host=host)
        self.dpath = dpath

        if FROM_SCRATCH:
            self.es.indices.delete(index=INDEX, ignore=[400, 404])
            logger.info("delete %s", INDEX)
            self.es.indices.create(index=INDEX)
            logger.info(self.es.indices.get_alias().keys())
            self.create_index()

        logger.info("Done init the index")

    def gendata(self):
        """Yield one bulk-index action per non-empty page file under ``self.dpath``.

        Every file found by ``os.walk`` is parsed as UTF-8 JSON. Pages whose
        ``get_page_content`` text is blank are skipped.

        Yields:
            dict with ``_index``, ``_type``, ``source_content`` (page text),
            ``full_path`` (file path) and ``title`` (the page title with a
            leading ``"How to "`` removed).
        """
        prefix = 'How to '
        tot = 0
        for root, _subdirs, files in os.walk(self.dpath):
            for file in files:
                full_path = os.path.join(root, file)
                with open(full_path, "r", encoding="utf-8") as f:
                    src = json.load(f)
                title = src['title']
                if title.startswith(prefix):
                    title = title[len(prefix):]
                # create the searchable body
                text = get_page_content(src)
                if not text.strip():
                    continue

                tot += 1
                if tot % 10000 == 0:
                    logger.info("index %d data now", tot)
                yield {
                    "_index": INDEX,
                    # NOTE(review): mapping types were removed in ES 8; presumably
                    # kept for the cluster version in use — verify before upgrading.
                    "_type": "_doc",
                    'source_content': text,
                    'full_path': full_path,
                    'title': title,
                }
        logger.info("generated %d documents from %s", tot, self.dpath)

    def create_index(self):
        """Bulk-load every document produced by :meth:`gendata` into ``INDEX``.

        The generator is handed to ``bulk`` directly, so documents are sent
        in chunks instead of all being held in memory first.
        """
        logger.info(bulk(self.es, self.gendata(), index=INDEX))


class CESearch:
    """Evaluate BM25 retrieval of gold WikiHow goals for a set of steps.

    Each line of ``args.step_file`` is ``<step>\\t<goal title>``. Each step is
    run as a ``match`` query against ``args.source_field`` of ``INDEX``. The
    rank of the gold title among the hits gives recall@k and MRR. These are
    printed and written, with the per-step candidate titles and scores, to
    ``args.save_file`` as JSON.
    """

    def __init__(self, args):
        self.searcher = SESearch()
        self.args = args
        self.index_name = INDEX
        self.step_file = self.args.step_file
        self.save_file = self.args.save_file

    def get_topk(self, query_str: str, field: str, topk: int = 5):
        """Return up to ``topk`` raw hit dicts for a ``match`` query on ``field``."""
        response = self.searcher.es.search(
            index=self.index_name,
            body={'query': {'match': {field: query_str}}, 'size': topk})
        return response['hits']['hits'][:topk]

    def load_steps(self):
        """Read ``self.step_file`` into a ``{step: gold_goal_title}`` dict.

        Each line is stripped and split on tabs, and the first two fields are
        used. Blank lines are ignored. A later duplicate step overwrites an
        earlier one.

        Raises:
            ValueError: for a non-blank line with fewer than two fields.
        """
        steps = {}
        with open(self.step_file, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing lines
                tks = line.split("\t")
                if len(tks) < 2:
                    raise ValueError(
                        f"{self.step_file}:{line_no}: expected '<step>\\t<goal>', got {line!r}")
                steps[tks[0]] = tks[1]
        print(f"{len(steps)} steps in total")
        return steps

    @staticmethod
    def _rank_of(titles, gold):
        """Return the 1-based position of the FIRST ``gold`` in ``titles``, or None."""
        for rank, title in enumerate(titles, start=1):
            if title == gold:
                return rank
        return None

    def search_wikihow(self, topk=30, cutoffs=(1, 10, 30)):
        """Query every step, compute recall@k and MRR, and save the results.

        Args:
            topk: Number of hits retrieved per step.
            cutoffs: The ``k`` values for which recall@k is reported.
        """
        steps = self.load_steps()
        n_steps = len(steps)
        d = defaultdict(dict)
        hits_at = {k: 0 for k in cutoffs}
        mrr = 0
        for step, linked_goal in steps.items():
            results = self.get_topk(step, field=self.args.source_field, topk=topk)
            titles = [result['_source']['title'] for result in results]
            for title, result in zip(titles, results):
                # Hits arrive best-first; keep the first score seen per title so
                # the stored score matches the rank used below.
                d[step].setdefault(title, {'bm25': str(result["_score"])})

            # The first occurrence counts. A later duplicate title must not
            # push the gold rank down.
            cur_rank = self._rank_of(titles, linked_goal)
            if cur_rank is not None:
                mrr += 1 / cur_rank
                for k in hits_at:
                    if cur_rank <= k:
                        hits_at[k] += 1

        # Guard against an empty step file (previously ZeroDivisionError).
        recall = {
            k: f"{v}/{n_steps}={v / n_steps if n_steps else 0.0}"
            for k, v in hits_at.items()
        }
        mrr = mrr / n_steps if n_steps else 0.0

        for k, v in recall.items():
            print(f"recall@{k}={v}")
        print(mrr)

        d['recall'] = recall
        d['mrr'] = mrr
        with open(self.save_file, "w", encoding="utf-8") as f:
            json.dump(d, f, indent=2)

        logger.info("Done search %d", n_steps)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="BM25 retrieval of gold WikiHow goals for steps (recall@k / MRR).")
    # --step_file is a dataset tag X; the script reads gold.para.X.dev.txt.
    # Omitting it used to produce a bogus ".../gold.para.None.dev.txt" path.
    parser.add_argument('--step_file', required=True)
    # When omitted, the output path is derived from --step_file and --source_field.
    parser.add_argument('--save_file', default=None)
    parser.add_argument('--source_field', choices=('source_content', 'title'), default='source_content')
    args = parser.parse_args()
    gt_dir = "./data/wikihow/arxiv/wikihow_ground_truth"
    # save_file must be derived from the raw tag BEFORE step_file is expanded
    # into a full path on the next line.
    if args.save_file is None:
        args.save_file = f"{gt_dir}/gold.para.{args.step_file}.{args.source_field}.dev.t30.ir"
    args.step_file = f"{gt_dir}/gold.para.{args.step_file}.dev.txt"
    print(vars(args))
    search = CESearch(args)
    search.search_wikihow()