from google_drive_downloader import GoogleDriveDownloader as gdd
import os, json
import argparse
import json
import os

import numpy as np
import torch
from transformers import BertModel, BertTokenizer
from tqdm import tqdm
from response_selection.preprocessing import get_dd_corpus, get_syndd_corpus
from response_selection.model import BertSelect, BertSelectforMultitask
from response_selection.utils import load_model, set_random_seed, get_uttr_token
import scipy.stats


def download_Zhao2_dataset(
    daily_url: str = "https://zenodo.org/record/4120039/files/mturk_results.json?download=1",
    daily_output_fname: str = "./annotated/annotated_dd_v2.json",
):
    """Download the second annotated DailyDialog file (Zenodo record 4120039).

    Nothing is downloaded when ``daily_output_fname`` already exists.

    Args:
        daily_url (str): URL of the annotated JSON file (``mturk_results.json``).
        daily_output_fname (str): Local path to save it to. Missing parent
            directories are created.
    """
    if os.path.exists(daily_output_fname):
        return
    import wget

    out_dir = os.path.dirname(daily_output_fname)
    if out_dir:
        # Create the target directory up front instead of relying on wget,
        # which fails when moving its temp file into a missing directory.
        os.makedirs(out_dir, exist_ok=True)
    wget.download(daily_url, out=daily_output_fname)


def download_gupta_dataset(
    daily_url: str = "https://raw.githubusercontent.com/prakharguptaz/Adv_gen_dialogue/main/dataset/ret_human_scores.csv",
    daily_output_fname: str = "./annotated/annotated_dd_gupta.csv",
):
    """Download the human retrieval-score CSV from the Adv_gen_dialogue repo.

    Nothing is downloaded when ``daily_output_fname`` already exists.

    Args:
        daily_url (str): URL of ``ret_human_scores.csv``.
        daily_output_fname (str): Local path to save it to. Missing parent
            directories are created.
    """
    if os.path.exists(daily_output_fname):
        return
    import wget

    out_dir = os.path.dirname(daily_output_fname)
    if out_dir:
        # Create the target directory up front instead of relying on wget,
        # which fails when moving its temp file into a missing directory.
        os.makedirs(out_dir, exist_ok=True)
    wget.download(daily_url, out=daily_output_fname)


def download_Zhao_dataset(
    daily_id: str = "1tbSnH20B2SRBeqiZTiW7NKhE_EVL6ktw",
    daily_output_fname: str = "./annotated/annotated_dd.json",
):
    """Download the human-annotated DailyDialog scores (Zhao et al., ACL 2020).

    The file is fetched from Google Drive. Nothing is downloaded when
    ``daily_output_fname`` already exists.

    Args:
        daily_id (str): Google Drive file ID of the annotated DailyDialog JSON.
        daily_output_fname (str): Local path to save it to. Missing parent
            directories are created.
    """
    if os.path.exists(daily_output_fname):
        return
    out_dir = os.path.dirname(daily_output_fname)
    if out_dir:
        # Make sure the target directory exists, matching the wget-based
        # download helpers; harmless if the downloader also creates it.
        os.makedirs(out_dir, exist_ok=True)
    gdd.download_file_from_google_drive(
        daily_id, daily_output_fname, unzip=False
    )


def _format_context(turns):
    """Join the utterance texts of a dialogue context with ``[UTTR]``.

    Each turn must be a 2-element sequence; only its second element (the
    utterance text) is kept.
    """
    for turn in turns:
        assert len(turn) == 2
    return "[UTTR]".join(turn[1] for turn in turns)


def merge_annotated_files(
    output_fname="./annotated/merged_annotated_dd.json",
    zhao_fname="./annotated/annotated_dd.json",
    zhao2_fname="./annotated/annotated_dd_v2.json",
    gupta_fname="./annotated/annotated_dd_gupta.csv",
):
    """Merge the three human-annotated DailyDialog files into one JSON list.

    Each output record is a dict with keys ``"context"`` (utterances joined
    by ``[UTTR]``), ``"response"``, ``"score"`` and ``"reference"``.

    How the score is derived per source:
      * ``zhao_fname``: mean of the annotators' ``"overall"`` ratings, mapped
        by ``(score - 1) / 4`` (1 -> 0, 5 -> 1; presumably a 1-5 scale).
        Ground-truth and negative-sample responses are skipped.
      * ``zhao2_fname``: mean ``"sensible"`` rating over annotators not
        flagged as outliers, not rescaled. Ground-truth responses are skipped.
      * ``gupta_fname``: the CSV's third-to-last column, which must lie in
        [0, 1]. The first CSV row is treated as a header.

    Args:
        output_fname: Where the merged JSON list is written.
        zhao_fname: Path of the first Zhao JSON file.
        zhao2_fname: Path of the second Zhao JSON file.
        gupta_fname: Path of the Gupta CSV file.
        The input defaults match the download helpers in this module.

    Raises:
        AssertionError: On malformed input, or when a dialogue id occurs in
            both Zhao files.
        ZeroDivisionError: If every annotator of a ``zhao2_fname`` response
            is flagged as an outlier.
    """
    import csv

    with open(zhao_fname, "r", encoding="utf-8") as f:
        zhao = json.load(f)
    with open(zhao2_fname, "r", encoding="utf-8") as f:
        zhao2 = json.load(f)

    data = []
    seen_ids = set()

    for k, v in zhao.items():
        seen_ids.add(int(k))
        context = _format_context(v["context"])
        # presumably a [speaker, text] pair like the context turns — confirm
        reference = v["reference"][1]
        assert "ground-truth" in v["responses"]
        for speaker, el in v["responses"].items():
            if speaker in ("ground-truth", "negative-sample"):
                continue
            ratings = [s["overall"] for s in el["scores"].values()]
            score = (sum(ratings) / len(ratings) - 1) / 4
            data.append(
                {
                    "context": context,
                    "response": el["uttr"],
                    "score": score,
                    "reference": reference,
                }
            )

    for k, v in zhao2.items():
        # JSON keys are strings; compare as int so the overlap check between
        # the two Zhao files can actually trigger.
        assert int(k) not in seen_ids, f"dialogue id {k} occurs in both Zhao files"
        seen_ids.add(int(k))
        context = _format_context(v["context"])
        reference = v["reference"][1]
        assert "ground-truth" in v["responses"]
        for speaker, el in v["responses"].items():
            if speaker == "ground-truth":
                continue
            ratings = [
                a["sensible"]
                for a in el["scores"].values()
                if not a["sensible_is_outlier"]
            ]
            data.append(
                {
                    "context": context,
                    "response": el["text"],
                    "score": sum(ratings) / len(ratings),
                    "reference": reference,
                }
            )

    # newline="" is what the csv module requires for correct field parsing.
    with open(gupta_fname, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        header = next(reader, [])  # header row; an empty file yields no rows
        n_cols = len(header)
        for row in reader:
            assert len(row) == n_cols
            score = float(row[-3])
            assert 0 <= score <= 1
            data.append(
                {
                    "context": row[3].replace("<eot>", "[UTTR]"),
                    "response": row[4],
                    "score": score,
                    "reference": row[-2],
                }
            )

    with open(output_fname, "w", encoding="utf-8") as f:
        json.dump(data, f)


def read_annotated_json(
    daily_output_fname: str = "./annotated/merged_annotated_dd.json",
):
    """Load the merged annotation file and return its decoded JSON content.

    Args:
        daily_output_fname (str): Path of the merged JSON file.
    """
    with open(daily_output_fname, "r") as handle:
        return json.load(handle)


def run_evaluation(eval_dataset, args=None):
    """Score each (context, response) pair and correlate with human scores.

    Args:
        eval_dataset: Iterable of dicts with ``"context"``, ``"response"`` and
            ``"score"`` keys (e.g. the output of ``read_annotated_json``).
        args: Parsed CLI namespace providing ``pred_neg_type``, ``model_path``,
            ``random_seed``, ``t_epoch`` and ``output_fname``. When omitted,
            the module-level ``args`` created by the ``__main__`` block is used.

    Side effects:
        Prints the Pearson and Spearman results and writes them to
        ``args.output_fname`` as JSON ``{"pearson": [r, p], "spearman": [rho, p]}``.

    Raises:
        RuntimeError: If no ``args`` is passed and no module-level one exists.
    """
    if args is None:
        args = globals().get("args")
        if args is None:
            raise RuntimeError(
                "run_evaluation() needs the parsed CLI arguments; pass args=..."
            )
    set_random_seed(42)

    # Fall back to CPU so the evaluation also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # Register the utterance separator as a special token (presumably the
    # same "[UTTR]" marker the merged contexts use — confirm in utils).
    tokenizer.add_special_tokens(
        {"additional_special_tokens": [get_uttr_token()]}
    )

    bert = BertModel.from_pretrained("bert-base-uncased")
    bert.resize_token_embeddings(len(tokenizer))
    if args.pred_neg_type:
        print("load model for multi-task learning")
        model = BertSelectforMultitask(bert)
    else:
        model = BertSelect(bert)
    model = load_model(
        model,
        # No-op when __main__ has already filled every placeholder.
        args.model_path.format(args.random_seed),
        args.t_epoch,
        len(tokenizer),
    )
    model.eval()
    model.to(device)

    human_scores = []
    model_scores = []
    with torch.no_grad():
        for sample in tqdm(eval_dataset):
            encoded = tokenizer(
                sample["context"], sample["response"], return_tensors="pt"
            )
            input_ids = encoded["input_ids"].to(device)
            mask = encoded["attention_mask"].to(device)
            model_output = model(input_ids, mask)
            # The multi-task model presumably returns (score, auxiliary
            # output); keep only the score. Check the type before len().
            if isinstance(model_output, tuple) and len(model_output) == 2:
                model_output = model_output[0]
            model_scores.append(float(model_output.cpu().numpy()[0]))
            human_scores.append(sample["score"])

    # Each result is a (statistic, p-value) pair.
    pearson = scipy.stats.pearsonr(human_scores, model_scores)
    spearman = scipy.stats.spearmanr(human_scores, model_scores)
    print(f"Pearson: {pearson}")
    print(f"Spearman: {spearman}")

    with open(args.output_fname, "w") as f:
        # Convert explicitly so serialization does not depend on the concrete
        # result type a given SciPy version returns.
        json.dump(
            {
                "pearson": [float(pearson[0]), float(pearson[1])],
                "spearman": [float(spearman[0]), float(spearman[1])],
            },
            f,
        )


def main(args):
    """Fetch the annotated datasets, merge them and evaluate the model.

    ``args`` is accepted for the CLI entry point but not read here;
    ``run_evaluation`` reads the module-level ``args`` itself.
    """
    downloaders = (
        download_Zhao_dataset,
        download_Zhao2_dataset,
        download_gupta_dataset,
    )
    for download in downloaders:
        download()
    merge_annotated_files()
    run_evaluation(read_annotated_json())


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=(
            "Correlate a trained response-selection model's scores with "
            "human-annotated DailyDialog response scores."
        )
    )
    # In this script --corpus and --setname only shape the result file path.
    parser.add_argument("--corpus", default="dd", choices=["persona", "dd"])
    parser.add_argument(
        "--setname", default="test", choices=["valid", "test"]
    )
    parser.add_argument("--log_path", type=str, default="reseval_result")
    # --d_type, --approach, --train_num_candidates and --random_seed fill the
    # four placeholders of --model_path.
    parser.add_argument(
        "--d_type",
        type=str,
        default="gpt",
        choices=["random", "human", "gpt", "syn"],
    )
    parser.add_argument(
        "--approach",
        type=str,
        default="direct_w_ans",
        choices=[
            "none",
            "bm25",
            "maskandfill",
            "kwsim",
            "direct_w_ans",
            "direct_wo_ans",
            "direct_0_shot",
            "meta",
        ],
    )
    parser.add_argument("--t_epoch", type=int, default=0)
    parser.add_argument(
        "--model_path",
        type=str,
        default="./logs/{}_{}_batch32_candi{}_seed{}/model",
    )
    parser.add_argument(
        "--train_num_candidates",
        type=int,
        default=11,
        help="total number of candidates(pos + negs) used during training",
    )

    parser.add_argument(
        "--pred_neg_type",
        type=str,
        default="False",
        choices=["True", "False"],
    )
    parser.add_argument(
        "--pred_neg_type_alpha", type=float, default=1.0,
    )

    parser.add_argument(
        "--random_seed",
        type=int,
        default=42,
        help="random seed during training",
    )

    args = parser.parse_args()
    args.pred_neg_type = args.pred_neg_type == "True"

    # Multi-task runs append the loss weight to the seed part of the path
    # and to the experiment name.
    predneg_suffix = (
        f"_predneg{args.pred_neg_type_alpha}" if args.pred_neg_type else ""
    )
    args.model_path = args.model_path.format(
        args.d_type,
        args.approach,
        args.train_num_candidates,
        f"{args.random_seed}{predneg_suffix}",
    )
    # Use parser.error rather than assert so the checks survive `python -O`.
    if len(args.model_path.split("/")) != 4:
        parser.error(f"unexpected --model_path layout: {args.model_path}")

    args.exp_name = (
        f"{args.d_type}_{args.approach}_{args.setname}{predneg_suffix}"
    )

    args.log_path = os.path.join(args.log_path, args.corpus)

    os.makedirs(args.log_path, exist_ok=True)
    args.output_fname = os.path.join(args.log_path, args.exp_name) + ".json"
    if os.path.exists(args.output_fname):
        parser.error(
            f"{args.output_fname} already exists; refusing to overwrite it"
        )
    print("\n", args.output_fname, "\n")
    main(args)
