"""
Usage:
python3 qa_browser.py --share
"""

import argparse
import datetime
import json
import os
import random
import re
import time
from pathlib import Path
from typing import Dict, List, Tuple

import gradio as gr
import pandas as pd
from pydantic import BaseModel

from fastchat.llm_judge.common import (
    load_model_answers,
    load_questions,
)
from fastchat.utils import (
    build_logger,
)

# Custom CSS handed to the Gradio Blocks app (see the `css=` argument in build_demo).
# The top-level rules use dark backgrounds; the `prefers-color-scheme: light`
# media query below overrides them with light backgrounds.
block_css = """
.user {
    background-color: #06611e;
}
.system {
    border-left: 2px solid #0066ff;
    padding-left: 5px;
}
.turn_hint {
    color: #0066ff;
    font-size: 2.5em;
}
#reference {
    background-color: #4d493d;
}

@media (prefers-color-scheme: light) {
    .user {
    background-color: #DEEBF7;
    }
    .turn_hint {
        color: #0066ff;
        font-size: 2.5em;
    }
    #reference {
       background-color: #FFF2CC;
    }

}
"""


class Vote(BaseModel):
    """One rating record; `vote_last_response` serialises it as a JSON line."""

    # Identifier of the rated question (the question's "question_id").
    question_id: int | str
    # Names of the models whose answers are shown as Assistant A / Assistant B.
    model_a: str
    model_b: str
    # Turn the rating refers to. The pairwise view stores a 0-based index here,
    # while the single-answer view stores a 1-based number.
    turn: int
    # Language code resolved through AppTexts.get_lang_code.
    language: str
    # Category copied from the question entry.
    category: str
    judge: str = "expert"
    # The three fields below are filled in by vote_last_response at vote time.
    winner: str = ""
    timestamp: float = 0
    user_ip: str = ""


class MultilingualText(BaseModel):
    """A piece of text provided once per UI language, keyed by language code."""

    # English is currently switched off; re-enable by restoring the field below.
    # EN: str
    DE: str
    FR: str
    IT: str
    ES: str


class AppTexts(BaseModel):
    """Bundle of the UI texts plus the mapping from language label to language code."""

    # Instruction text shown above the conversation, one entry per language code.
    user_instruction: MultilingualText
    # Human-readable language labels, one entry per language code.
    lang_label: MultilingualText
    # Reverse lookup: language label -> language code.
    lang_label_to_lang_code: Dict[str, str]

    def get_lang_code(self, lang_label: str) -> str:
        """Return the language code registered for ``lang_label``.

        Raises:
            KeyError: if the label is not present in ``lang_label_to_lang_code``.
        """
        return self.lang_label_to_lang_code[lang_label]


class State(BaseModel):
    """Application state that is passed through every Gradio event handler."""

    # Ratings created so far; the last entry is the one currently on screen.
    ratings: List[Vote] = []
    # Language label -> list of question dicts.
    questions: Dict[str, List[Dict]]
    # Language label -> model name -> question id -> answer dict.
    model_answers: Dict[str, Dict[str, Dict[str, Dict]]]
    # Language label -> list of model names.
    model_names: Dict[str, List[str]]
    categories: List[str]
    texts: AppTexts
    is_single_eval: bool
    # Question ids already handed out by get_random_or_current_question;
    # they receive weight 0 when the next question is drawn.
    blacklist: set[str] = set()


# JSONL file that receives one record per submitted vote and is read back to
# weight the question sampling. (An earlier path was
# "human_eval/logs/crowd_logs.jsonl".)
# The functions in this module refer to this path as LOGFILE, so the name must exist.
LOGFILE = "human_eval/logs/24EU_bactrianx_pair_wise.jsonl"
# Backward-compatible alias: the path used to be published only under this name.
LOGDIR = LOGFILE

logger = build_logger("gradio_web_server", "gradio_web_server.log")


def get_numerical_vote_fun(vote: int):
    """Create a click handler that records ``vote`` as a numeric rating.

    Args:
        vote: The score attached to the button the handler is created for.

    Returns:
        A function ``(state, request) -> state`` that logs the click, stores
        the score (as a string) as the winner of the pending rating and hands
        the state back.
    """
    winner_label = str(vote)

    def numerical_vote(state: State, request: gr.Request):
        logger.info(f"Voted {vote} ip: {request.client.host}")
        vote_last_response(state=state, vote=winner_label, request=request)
        return state

    return numerical_vote


def a_is_better(state: State, request: gr.Request):
    """Vote handler for the "A is better" button.

    Logs the click together with the client host, stores "A" as the winner of
    the pending rating and returns the state.
    """
    logger.info(f"a_is_better. ip: {request.client.host}")
    vote_last_response(state=state, vote="A", request=request)
    return state


def b_is_better(state: State, request: gr.Request):
    """Vote handler for the "B is better" button.

    Logs the click together with the client host, stores "B" as the winner of
    the pending rating and returns the state.
    """
    logger.info(f"b_is_better. ip: {request.client.host}")
    vote_last_response(state=state, vote="B", request=request)
    return state


def tie(state: State, request: gr.Request):
    """Vote handler for the "Tie" button.

    Logs the click together with the client host, stores "Tie" as the winner
    of the pending rating and returns the state.
    """
    logger.info(f"tie. ip: {request.client.host}")
    vote_last_response(state=state, vote="Tie", request=request)
    return state


def both_bad(state: State, request: gr.Request):
    """Vote handler for the "Both bad" button.

    Logs the click together with the client host, stores "Tie (both bad)" as
    the winner of the pending rating and returns the state.
    """
    logger.info(f"tie (both bad). ip: {request.client.host}")
    vote_last_response(state=state, vote="Tie (both bad)", request=request)
    return state


def vote_last_response(state: State, vote: str, request: gr.Request):
    """Finalise the most recent rating and append it to the vote log.

    The last entry of ``state.ratings`` is completed with the winner, the
    current time (seconds since the epoch, rounded to 4 decimals) and the
    client host, then written to ``LOGFILE`` as one JSON line.

    Args:
        state: Application state; ``state.ratings`` must not be empty.
        vote: Winner label to store (e.g. "A", "B", "Tie" or a numeric score).
        request: Gradio request; only ``request.client.host`` is read.

    Raises:
        IndexError: if no rating has been prepared yet.
        OSError: if the log file cannot be created or written.
    """
    rating: Vote = state.ratings[-1]
    rating.winner = vote
    rating.timestamp = round(time.time(), 4)
    rating.user_ip = request.client.host

    # Append mode creates a missing file but not a missing directory, so make
    # sure the parent directory exists before the first vote is written.
    log_dir = os.path.dirname(LOGFILE)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    with open(LOGFILE, "a", encoding="utf-8") as fout:
        fout.write(json.dumps(rating.model_dump()) + "\n")


def get_conv_log_filename(is_single_eval: bool):
    """Placeholder for a conversation-log filename lookup.

    The argument is currently ignored and the function always returns ``None``.
    """
    return None


def get_random_or_current_question(state, lang, current_question_preview, new_turn_idx):
    """Pick the question to display next.

    Args:
        state: Application state; ``state.questions[lang]`` and
            ``state.blacklist`` are used.
        lang: Key into ``state.questions`` (the language label).
        current_question_preview: ``"preview"`` text of the question that is
            currently shown.
        new_turn_idx: Turn that is about to be displayed. ``0`` starts a new
            conversation, any other value continues the current one.

    Returns:
        For ``new_turn_idx == 0`` a question drawn at random using the weights
        from ``get_weights_questions`` (its id is added to the blacklist);
        otherwise the question whose preview equals
        ``current_question_preview``.

    Raises:
        IndexError: if the current question has to be returned but no question
            has the given preview.
    """
    lang_questions = state.questions[lang]

    if new_turn_idx != 0:
        # The current question is only looked up when it is actually needed, so
        # a stale preview cannot break the drawing of a fresh question.
        return [q for q in lang_questions if q["preview"] == current_question_preview][0]

    q_ids, weights = get_weights_questions(lang_questions, blacklisted_ids=state.blacklist)
    random_question_id = random.choices(q_ids, weights=weights, k=1)[0]
    # Remember the drawn id so it gets weight 0 in later draws.
    state.blacklist.add(random_question_id)
    return [q for q in lang_questions if q["question_id"] == random_question_id][0]


def display_pairwise_answer(state: State, model_selector1, model_selector2, question_selector, language_selector):
    """Advance the pairwise view to the next turn / question and prepare its rating.

    The turn alternates between 0 and 1 based on the last stored rating; on
    turn 0 a new question is drawn. The two models are assigned to side A / B
    at random, and a new, not yet voted ``Vote`` is appended to
    ``state.ratings``.

    Args:
        state: Application state (mutated: a rating is appended).
        model_selector1: Incoming value of the first model dropdown (replaced).
        model_selector2: Incoming value of the second model dropdown (replaced).
        question_selector: Preview text of the question currently shown.
        language_selector: Selected language label; passed through unchanged.

    Returns:
        ``[state, model dropdown A, model dropdown B, question dropdown,
        language_selector]`` followed by the Markdown strings produced by
        ``pairwise_to_gradio_chat_mds``.
    """
    # With no rating yet, pretend turn 1 was just shown so the first call starts at turn 0.
    previous_turn_idx = state.ratings[-1].turn if state.ratings else 1
    new_turn_idx = (previous_turn_idx + 1) % 2

    # The language never changes here.
    lang = language_selector

    # Either a freshly drawn question (turn 0) or the current one (turn 1).
    question = get_random_or_current_question(state, lang, question_selector, new_turn_idx)
    question_dropdown = get_question_dropdown(state, lang, value=question["preview"], use_update=True)

    # Randomise which model appears on which side.
    side_of_first_model = random.randint(0, 1)
    dropdown_a = get_model_dropdown(state, lang, selected_idx=side_of_first_model, use_update=True)
    dropdown_b = get_model_dropdown(state, lang, selected_idx=1 - side_of_first_model, use_update=True)

    # Read the chosen model names back from the freshly built dropdowns.
    model_name1 = dropdown_a.constructor_args["value"]
    model_name2 = dropdown_b.constructor_args["value"]
    q_id = question["question_id"]

    answers_for_lang = state.model_answers[lang]
    ans1 = answers_for_lang[model_name1][q_id]
    ans2 = answers_for_lang[model_name2][q_id]

    # Prepare the rating that a later vote will complete.
    pending_vote = Vote(
        question_id=q_id,
        model_a=model_name1,
        model_b=model_name2,
        turn=new_turn_idx,
        language=state.texts.get_lang_code(lang),
        category=question["category"],
    )
    state.ratings.append(pending_vote)

    chat_mds = pairwise_to_gradio_chat_mds(state, question, ans1, ans2, language_selector, turn_idx=new_turn_idx)
    return [state, dropdown_a, dropdown_b, question_dropdown, language_selector, *chat_mds]


def get_turn_hint_text(turn_number: int):
    """Return the HTML hint telling the rater which turn to rate.

    For turns after the first, the hint additionally asks to rate the turn
    independently from the previous ones.
    """
    if turn_number > 1:
        independence_note = " independent from previous turns"
    else:
        independence_note = ""
    opening_tag = "<span style='color:blue;font-weight:700;font-size:1.5em'>"
    return opening_tag + " > Rate only turn " + f"{turn_number}" + independence_note + "</span>"


def pairwise_to_gradio_chat_mds(state, question, ans_a, ans_b, language_selector, turn_idx=None):
    """Render a question and two answers as Markdown strings for the pairwise view.

    The result always has the same length and layout, so it can be bound to a
    fixed list of output components:
    ``[instruction, turn hint, (user, answer A, answer B) x 2 turn slots,
    reference, turn hint]``.

    Args:
        state: Object exposing ``texts`` (instruction texts and label-to-code lookup).
        question: Question dict with ``"turns"`` and optionally ``"reference"``.
        ans_a: Answer dict of side A; text is read from ``["choices"][0]["turns"][i]``.
        ans_b: Answer dict of side B, same structure.
        language_selector: Language label used to select the instruction text.
        turn_idx: 0-based index of the last turn to show; ``None`` shows all turns.

    Returns:
        List of Markdown / HTML strings. Slots of turns that are not shown and a
        missing reference are empty strings.
    """
    end = len(question["turns"]) if turn_idx is None else turn_idx + 1

    num_sides = 2
    num_turns = 2
    slots_per_turn = 1 + num_sides  # one user message plus one answer per side
    mds = [""] * (num_turns * slots_per_turn)
    for i in range(end):
        base = i * slots_per_turn
        if i == 0:
            mds[base + 0] = "##### User\n" + question["turns"][i]
        else:
            mds[base + 0] = "##### User's follow-up question \n" + question["turns"][i]
        mds[base + 1] = f"##### Assistant A - Turn {i+1}\n" + post_process_answer(
            ans_a["choices"][0]["turns"][i].strip()
        )
        mds[base + 2] = f"##### Assistant B - Turn {i+1}\n " + post_process_answer(
            ans_b["choices"][0]["turns"][i].strip()
        )

    # Exactly one reference slot is always appended (possibly empty); leaving it
    # out would change the length of the returned list.
    ref = question.get("reference") or ["", ""]
    if turn_idx is None:
        # Pad so that a reference list with fewer than two entries is tolerated.
        ref_q1, ref_q2 = (list(ref) + ["", ""])[:2]
        if ref_q1 != "" or ref_q2 != "":
            reference_md = f"##### Reference Solution\nQ1. {ref_q1}\nQ2. {ref_q2}"
        else:
            reference_md = ""
    else:
        turn_ref = ref[turn_idx] if turn_idx < len(ref) else ""
        if turn_ref:
            reference_md = f"##### Reference Solution - Turn {turn_idx+1}\n{turn_ref}"
        else:
            reference_md = ""
    mds.append(reference_md)

    # The same turn hint is shown above and below the conversation.
    mds = [get_turn_hint_text(end)] + mds + [get_turn_hint_text(end)]
    mds = [state.texts.user_instruction.model_dump()[state.texts.get_lang_code(language_selector)]] + mds
    return mds


def display_single_answer(state: State, model_selector1, question_selector, language_selector):
    """Show a randomly drawn question with one model's answer and prepare its rating.

    A random turn (0 or 1) and a random question are chosen, the answer of the
    incoming model is looked up, and a new, not yet voted ``Vote`` is appended
    to ``state.ratings``.

    Args:
        state: Application state (mutated: a rating is appended).
        model_selector1: Name of the model whose answer is shown.
        question_selector: Preview text of the question currently shown.
        language_selector: Selected language label; passed through unchanged.

    Returns:
        ``[state, model dropdown, question dropdown, language_selector]``
        followed by the Markdown strings from ``single_to_gradio_chat_mds``.
    """
    random_turn_idx = random.randint(0, 1)
    lang = language_selector

    # Turn index 0 makes the helper draw a new random question on every call.
    random_question = get_random_or_current_question(state, lang, question_selector, 0)
    q_id = random_question["question_id"]
    model_name1 = model_selector1
    ans1 = state.model_answers[lang][model_name1][q_id]

    question_selector = get_question_dropdown(state, lang, value=random_question["preview"], use_update=True)
    # selected_idx must be passed by keyword: the third positional parameter of
    # get_model_dropdown is use_update.
    model_selector1 = get_model_dropdown(state, lang, selected_idx=0, use_update=True)

    # Prepare the rating that a later vote will complete. There is no second
    # model in single-answer mode, so model_b is left empty.
    lang_code = state.texts.get_lang_code(lang)
    state.ratings.append(
        Vote(
            question_id=q_id,
            model_a=model_name1,
            model_b="",
            turn=random_turn_idx + 1,
            language=lang_code,
            category=random_question["category"],
        )
    )
    return [
        state,
        model_selector1,
        question_selector,
        language_selector,
    ] + single_to_gradio_chat_mds(
        state, random_question, ans1, model_name1, language_selector, turn_idx=random_turn_idx
    )


def single_to_gradio_chat_mds(state, question, ans_a, model_name1, language_selector, turn_idx=None):
    """Render a question and one answer as Markdown strings for the single-answer view.

    The result always has the same length and layout:
    ``[instruction, turn hint, (user, answer) x 2 turn slots, reference,
    turn hint]``. The reference slot matches the reference component that the
    UI creates unconditionally after the turn components.

    Args:
        state: Object exposing ``texts`` (instruction texts and label-to-code lookup).
        question: Question dict with ``"turns"`` and optionally ``"reference"``.
        ans_a: Answer dict; text is read from ``["choices"][0]["turns"][i]``.
        model_name1: Unused; kept so existing callers keep working.
        language_selector: Language label used to select the instruction text.
        turn_idx: 0-based index of the last turn to show; ``None`` shows all turns.

    Returns:
        List of Markdown / HTML strings. Slots of turns that are not shown and a
        missing reference are empty strings.
    """
    end = len(question["turns"]) if turn_idx is None else turn_idx + 1

    num_sides = 1
    num_turns = 2
    slots_per_turn = 1 + num_sides  # one user message plus the single answer
    mds = [""] * (num_turns * slots_per_turn)
    for i in range(end):
        base = i * slots_per_turn
        if i == 0:
            mds[base + 0] = "##### User\n" + question["turns"][i]
        else:
            mds[base + 0] = "##### User's follow-up question \n" + question["turns"][i]
        mds[base + 1] = f"##### Assistant A - Turn {i+1}\n" + post_process_answer(
            ans_a["choices"][0]["turns"][i].strip()
        )

    # Always emit exactly one reference slot (possibly empty) so the number of
    # returned values does not depend on the question.
    ref = question.get("reference") or ["", ""]
    if turn_idx is None:
        ref_q1, ref_q2 = (list(ref) + ["", ""])[:2]
        if ref_q1 != "" or ref_q2 != "":
            reference_md = f"##### Reference Solution\nQ1. {ref_q1}\nQ2. {ref_q2}"
        else:
            reference_md = ""
    else:
        turn_ref = ref[turn_idx] if turn_idx < len(ref) else ""
        if turn_ref:
            reference_md = f"##### Reference Solution - Turn {turn_idx+1}\n{turn_ref}"
        else:
            reference_md = ""
    mds.append(reference_md)

    # The same turn hint is shown above and below the conversation.
    mds = [get_turn_hint_text(end)] + mds + [get_turn_hint_text(end)]
    mds = [state.texts.user_instruction.model_dump()[state.texts.get_lang_code(language_selector)]] + mds
    return mds


# A blank line directly in front of a numbered ("1. ") or dashed ("- ") list
# item. Raw strings keep the regex escapes (\d, \g) out of Python's own string
# escape processing.
newline_pattern1 = re.compile(r"\n\n(\d+\. )")
newline_pattern2 = re.compile(r"\n\n(- )")


def post_process_answer(x):
    """Fix Markdown rendering problems in a model answer.

    Bullet characters (U+2022) are replaced by "- ", and the blank line in
    front of numbered or dashed list items is collapsed into a single newline.

    Args:
        x: Answer text.

    Returns:
        The adjusted text.
    """
    x = x.replace("\u2022", "- ")
    # In the replacement template, \n yields a newline and \g<1> re-inserts the
    # captured list marker.
    x = newline_pattern1.sub(r"\n\g<1>", x)
    x = newline_pattern2.sub(r"\n\g<1>", x)
    return x


def get_question_dropdown(state: State, language: str, use_update: bool, value: None | str = None):
    """Build the (hidden) question dropdown for one language.

    Args:
        state: Application state; ``state.questions[language]`` supplies the choices.
        language: Key into ``state.questions`` (the language label).
        use_update: Accepted for call-site compatibility; it has no effect, a
            new ``gr.Dropdown`` is returned in every case.
        value: Preview text to select; defaults to the first choice.

    Returns:
        A ``gr.Dropdown`` listing the preview text of every question.

    Raises:
        IndexError: if ``value`` is ``None`` and there are no questions.
    """
    choices = [q["preview"] for q in state.questions[language]]
    selected = choices[0] if value is None else value
    return gr.Dropdown(
        choices=choices,
        label="Question",
        value=selected,
        container=False,
        visible=False,
    )


def get_model_dropdown(state: State, language: str, use_update: bool, selected_idx: int = 0):
    """Build a (hidden) model dropdown for one language.

    Args:
        state: Application state; ``state.model_names[language]`` supplies the choices.
        language: Key into ``state.model_names`` (the language label).
        use_update: Accepted for call-site compatibility; it has no effect, a
            new ``gr.Dropdown`` is returned in every case.
        selected_idx: Index of the model name to pre-select.

    Returns:
        A ``gr.Dropdown`` whose value is the model name at ``selected_idx``.

    Raises:
        IndexError: if ``selected_idx`` is out of range for the model list.
    """
    choices = state.model_names[language]
    return gr.Dropdown(
        choices=choices,
        value=choices[selected_idx],
        interactive=True,
        show_label=False,
        container=False,
        visible=False,
    )


def build_demo(is_single_eval: bool, model_list: List[str]):
    """Assemble the Gradio Blocks UI for the evaluator.

    Args:
        is_single_eval: If true, build the single-answer view (one answer
            column, numeric score buttons); otherwise the pairwise view (two
            answer columns, A / B / tie buttons).
        model_list: Model names forwarded to ``get_data`` to restrict which
            models are loaded.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    model_selectors = []
    with gr.Blocks(
        title="MT-Bench-X Evaluator",
        theme=gr.themes.Base(text_size=gr.themes.sizes.text_lg),
        css=block_css,
    ) as demo:
        # All questions, answers and texts are loaded once and kept in the state.
        state = gr.State(get_data(is_single_eval=is_single_eval, model_list=model_list))

        # German is the language shown initially.
        default_language = state.value.texts.lang_label.DE
        default_lang_code = state.value.texts.get_lang_code(default_language)

        # Hidden: the language is fixed to the default value.
        language_selector = gr.Dropdown(
            choices=list(state.value.texts.lang_label.model_dump().values()),
            label="Language",
            value=default_language,
            container=False,
            visible=False,
        )

        num_turns = 2
        num_sides = 2 if not is_single_eval else 1

        with gr.Row():
            question_selector = get_question_dropdown(state.value, default_language, use_update=False)

        # One (hidden) model dropdown per displayed side.
        with gr.Row(elem_id="model_selector_row"):
            for side_idx in range(num_sides):
                model_selectors.append(
                    get_model_dropdown(state.value, default_language, selected_idx=side_idx, use_update=False)
                )

        # Conversation
        user_instruction = gr.Markdown(state.value.texts.user_instruction.model_dump()[default_lang_code])
        turn_hint = gr.Markdown(elem_classes="turn_hint")
        gr.HTML("<hr>")
        chat_mds = []
        # with gr.Row():
        #     for side_idx in range(num_sides):
        #         chat_mds.append(gr.Markdown(elem_id=f"model_name_{side_idx+1}"))
        # Per turn: one user message followed by one answer column per side.
        for i in range(num_turns):
            chat_mds.append(gr.Markdown(elem_classes="user"))
            with gr.Row():
                for j in range(num_sides):
                    with gr.Column(scale=100):
                        chat_mds.append(gr.Markdown(elem_classes="system"))

        # The reference slot always follows the turn slots.
        chat_mds.append(gr.Markdown(elem_id=f"reference"))
        gr.HTML("<hr>")
        turn_hint2 = gr.HTML(elem_classes="turn_hint")
        # Final output order; the display functions must return their strings in this order.
        chat_mds = [user_instruction, turn_hint] + chat_mds + [turn_hint2]
        with gr.Group():
            display_fun = display_pairwise_answer if not is_single_eval else display_single_answer
            # (fn, inputs, outputs) triple shared by the page load, the skip
            # button and the follow-up step of every vote button.
            skip_params = (
                display_fun,
                [state] + model_selectors + [question_selector, language_selector],
                [state] + model_selectors + [question_selector, language_selector] + chat_mds,
            )
            # Buttons

            if not is_single_eval:
                with gr.Row():
                    set_pair_wise_compare_buttons(state, skip_params)
            else:
                gr.HTML('<p style="text-align:left;">Worst<span style="float:right;">Best</span></p>')
                with gr.Row():
                    set_single_compare_buttons(state, skip_params)

        # not needed for anonymous eval
        # for model_selector in model_selectors:
        #     model_selector.change(*skip_params)
        # question_selector.change(*skip_params)
        # language_selector.change(*skip_params)
        # Show the first question as soon as the page has loaded.
        demo.load(*skip_params)

    return demo


def set_pair_wise_compare_buttons(state: gr.State, skip_params: Tuple):
    """Create the pairwise voting buttons and wire up their click events.

    Each vote button first runs its vote handler on the state and then the
    display step described by ``skip_params``; the skip button runs only the
    display step.

    Args:
        state: The Gradio state component passed to and returned by the handlers.
        skip_params: ``(fn, inputs, outputs)`` triple of the display step.
    """
    # All buttons are created first, in display order, before any event is bound.
    vote_buttons = [
        (gr.Button(value="👈  A is better", interactive=True), a_is_better),
        (gr.Button(value="👉  B is better", interactive=True), b_is_better),
        (gr.Button(value="🤝  Tie", interactive=True), tie),
        (gr.Button(value="👎  Both bad", interactive=True), both_bad),
    ]
    skip_btn = gr.Button(value="⏩  Skip", interactive=True)

    for button, handler in vote_buttons:
        vote_event = button.click(handler, inputs=[state], outputs=[state])
        vote_event.then(*skip_params)
    skip_btn.click(*skip_params)


def set_single_compare_buttons(state: gr.State, skip_params: Tuple):
    """Create the score buttons 1..10 plus a skip button and wire up their events.

    Each score button first records its score through a handler from
    ``get_numerical_vote_fun`` and then runs the display step described by
    ``skip_params``; the skip button runs only the display step.

    Args:
        state: The Gradio state component passed to and returned by the handlers.
        skip_params: ``(fn, inputs, outputs)`` triple of the display step.
    """
    for score in range(1, 11):
        score_btn = gr.Button(value=str(score), interactive=True, min_width=1, elem_id=f"button_{score}")
        vote_event = score_btn.click(get_numerical_vote_fun(score), inputs=[state], outputs=[state])
        vote_event.then(*skip_params)

    skip_btn = gr.Button(value="⏩  Skip", interactive=True)
    skip_btn.click(*skip_params)


def fastchat_dataset_to_question_dataset(questions: List[Dict]) -> List[Dict]:
    """Convert conversation-style entries into question-style entries.

    Input entries are expected to carry ``"id"`` and ``"conversations"`` (a
    list of ``{"from": ..., "value": ...}`` messages). Messages from
    ``"human"`` become the question turns; all other messages become the
    reference answers.

    Args:
        questions: Entries to convert. The first entry decides the format: if
            it already has a ``"turns"`` (or ``"turn"``) key, the list is
            returned unchanged, which also makes the conversion safe to apply
            twice.

    Returns:
        A list of dicts with the keys ``"question_id"``, ``"category"`` (always
        ``"val dataset"``), ``"turns"`` and ``"reference"``; or the input list
        itself if it is empty or already in question format.
    """
    if not questions:
        return questions

    first_keys = questions[0].keys()
    if "turns" in first_keys or "turn" in first_keys:
        return questions

    new_questions = []
    for entry in questions:
        turns = []
        references = []
        for message in entry.get("conversations", []):
            if message["from"] == "human":
                turns.append(message["value"])
            else:
                references.append(message["value"])
        new_questions.append(
            {
                "question_id": entry.get("id"),
                "category": "val dataset",
                "turns": turns,
                # The non-human messages are the reference answers.
                "reference": references,
            }
        )
    return new_questions


def filter_common_qa(questions: List[Dict], model_answers: Dict[str, Dict[str, Dict]]):
    """Keep only the questions (and answers) that every model has answered.

    Args:
        questions: Question dicts, each with a ``"question_id"``.
        model_answers: Model name -> question id -> answer.

    Returns:
        ``(questions, model_answers)`` reduced to the question ids that occur
        in ``questions`` and in the answers of every model.

    Raises:
        IndexError: if ``model_answers`` is empty.
    """
    answered_ids_per_model = [set(answers) for answers in model_answers.values()]
    # Ids answered by all models ...
    shared_ids = answered_ids_per_model[0].intersection(*answered_ids_per_model[1:])
    # ... that also belong to a known question.
    shared_ids &= {question["question_id"] for question in questions}

    kept_questions = [question for question in questions if question["question_id"] in shared_ids]

    kept_answers = {}
    for model_name, answers in model_answers.items():
        kept_answers[model_name] = {q_id: answer for q_id, answer in answers.items() if q_id in shared_ids}
    return kept_questions, kept_answers


def get_data(is_single_eval: bool, model_list: List[str]) -> State:
    """Load questions and model answers for every language and build the app state.

    For each language code / label pair in the language labels, questions are
    read from ``data/mt_bench_<code>/question.jsonl`` and answers from
    ``data/mt_bench_<code>/model_answer``. Every question gets a ``"preview"``
    entry (id plus the first 128 characters of the first turn). Model names are
    ordered by the sum of the purely numeric parts of the "-"-separated name.

    Args:
        is_single_eval: Stored in the returned state.
        model_list: If non-empty, only these models are kept.

    Returns:
        The populated ``State``. ``categories`` holds the categories of the
        language processed last.
    """

    def numeric_name_score(name):
        # Sum of all "-"-separated parts that consist of digits only.
        return sum(int(part) for part in name.split("-") if part.isdigit())

    texts = get_multilingual_texts()
    questions, model_answers, model_names = {}, {}, {}

    for lang_code, language in texts.lang_label.model_dump().items():
        bench_dir = f"data/mt_bench_{lang_code}"

        lang_questions = load_questions(f"{bench_dir}/question.jsonl", None, None)
        questions[language] = lang_questions
        for question in lang_questions:
            question["preview"] = f"{question['question_id']}: " + question["turns"][0][:128] + "..."

        lang_answers = load_model_answers(f"{bench_dir}/model_answer")
        categories = list({question["category"] for question in lang_questions})
        lang_models = sorted(lang_answers, key=numeric_name_score)

        if len(model_list) > 0:
            # Restrict both the name list and the answers to the requested models.
            lang_models = [name for name in lang_models if name in model_list]
            lang_answers = {name: answers for name, answers in lang_answers.items() if name in model_list}

        model_answers[language] = lang_answers
        model_names[language] = lang_models

    return State(
        questions=questions,
        model_answers=model_answers,
        model_names=model_names,
        categories=categories,
        texts=texts,
        is_single_eval=is_single_eval,
    )

def get_weights_questions(questions: List[Dict], blacklisted_ids: set, log_file=None, max_count: int = 12):
    """Compute sampling weights that favour questions with few logged votes.

    The vote log is read and the number of entries per question id is counted
    (capped at ``max_count``). A question's weight is the highest capped count
    minus its own count; questions without any log entry get ``max_count``.
    Blacklisted questions and questions with weight 0 are left out.

    If that leaves nothing to choose from (for example when all questions were
    rated equally often), all non-blacklisted questions are returned with equal
    weight; if every question is blacklisted, all questions are returned with
    equal weight. The result is therefore only empty when ``questions`` is
    empty.

    Args:
        questions: Question dicts, each with a ``"question_id"``.
        blacklisted_ids: Question ids that must not be drawn.
        log_file: Path of the JSON-lines vote log; defaults to ``LOGFILE``.
            A missing or empty file is treated as "no votes yet".
        max_count: Upper limit applied to the per-question counts.

    Returns:
        ``(question_ids, weights)`` as two lists of equal length.
    """
    log_path = LOGFILE if log_file is None else log_file
    question_ids = [q["question_id"] for q in questions]

    # Before the first vote the log does not exist yet (or is empty).
    if os.path.isfile(log_path) and os.path.getsize(log_path) > 0:
        df = pd.read_json(log_path, lines=True)
    else:
        df = pd.DataFrame()

    if "question_id" in df.columns:
        question_counts = df["question_id"].value_counts()
    else:
        question_counts = pd.Series(dtype="float64")

    # add missing questions with a count of NaN
    question_counts = question_counts.reindex(question_ids)

    # set upper limit count
    question_counts = question_counts.clip(upper=max_count)

    # make counts to weights; missing entries get max count
    question_weights = (question_counts.max() - question_counts).fillna(max_count)

    # remove blacklisted entries and entries with weight 0
    question_weights.loc[question_weights.index.isin(list(blacklisted_ids))] = 0
    question_positive_weights = question_weights[question_weights > 0]

    if question_positive_weights.empty:
        # Nothing has a positive weight; fall back to a uniform draw so the
        # caller always has something to choose from.
        candidates = [q_id for q_id in question_ids if q_id not in blacklisted_ids] or question_ids
        return candidates, [1.0] * len(candidates)

    return question_positive_weights.index.to_list(), question_positive_weights.to_list()

def get_multilingual_texts() -> AppTexts:
    """Build the texts used by the UI: rater instructions and language labels.

    Returns:
        An ``AppTexts`` holding the instruction text and the language label
        per language code, plus the reverse mapping from label to code.
    """
    # NOTE(review): an EN text is passed here although MultilingualText declares
    # no EN field in this file; presumably pydantic ignores the extra keyword
    # under its default configuration - confirm before relying on it.
    user_instruction = MultilingualText(
        EN="""
# MT-Bench-X Evaluator
Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. 
Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. 
Avoid any positional biases and ensure that the order in which the responses were presented does not influence your decision. 
Do not allow the length of the responses to influence your evaluation. 
Do not favor certain names of the assistants. 
Be as objective as possible. 
Finally, indicate your verdict by clicking one button. 
Please use the tie button sparsely.
""",
        DE="""
# MT-Bench-X Bewerter
Bitte bewerten Sie als unparteiischer Richter die Qualität der Antworten von zwei KI-Assistenten auf die unten angezeigte Benutzerfrage. Sie sollten den Assistenten auswählen, der die Anweisungen des Benutzers befolgt und die Frage des Benutzers besser beantwortet. 
Bei Ihrer Bewertung sollten Sie Faktoren wie Hilfsbereitschaft, Relevanz, Genauigkeit, Tiefe, Kreativität und Detailgenauigkeit der Antworten berücksichtigen. 
Vermeiden Sie jegliche Voreingenommenheit und stellen Sie sicher, dass die Reihenfolge, in der die Antworten präsentiert wurden, Ihre Entscheidung nicht beeinflusst. 
Lassen Sie sich bei Ihrer Bewertung nicht von der Länge der Antworten beeinflussen. 
Seien Sie so objektiv wie möglich.
Die Benutzerfragen und eventuelle Referenzantworten sind maschinell vom Englischen ins Deutsche übersetzt wurden. 
Bitte sehen Sie über eventuelle Übersetzungsungenauigkeiten hinweg.
Geben Sie abschließend Ihr Urteil durch Anklicken einer Schaltfläche an. 
""",
        ES="""
# Evaluador de MT-Bench-X
Por favor, actúe como un juez imparcial y evalúe la calidad de las respuestas proporcionadas por dos asistentes de IA a la pregunta del usuario que se muestra a continuación. Deberá elegir el asistente que siga las instrucciones del usuario y responda mejor a su pregunta. 
En tu evaluación deberás tener en cuenta factores como la utilidad, la relevancia, la precisión, la profundidad, la creatividad y el nivel de detalle de las respuestas. 
Evite cualquier sesgo posicional y asegúrese de que el orden en que se presentaron las respuestas no influye en su decisión. 
No permita que la longitud de las respuestas influya en su evaluación. 
No favorezca determinados nombres de los asistentes. 
Sea lo más objetivo posible. 
Por último, indique su veredicto pulsando un botón. 
Utilice el botón de empate con moderación.
""",
        FR="""
# Évaluateur MT-Bench-X
Veuillez agir en tant que juge impartial et évaluer la qualité des réponses fournies par deux assistants IA à la question de l'utilisateur affichée ci-dessous. Vous devez choisir l'assistant qui suit les instructions de l'utilisateur et répond le mieux à sa question. 
Votre évaluation doit prendre en compte des facteurs tels que l'utilité, la pertinence, l'exactitude, la profondeur, la créativité et le niveau de détail des réponses. 
Évitez tout parti pris et veillez à ce que l'ordre dans lequel les réponses ont été présentées n'influence pas votre décision. 
Ne laissez pas la longueur des réponses influencer votre évaluation. 
Ne privilégiez pas certains noms d'assistants. 
Soyez aussi objectif que possible. 
Enfin, indiquez votre verdict en cliquant sur un bouton. 
Veuillez utiliser le bouton "égalité" avec parcimonie.
""",
        IT="""
# Valutatore di MT-Bench-X
Agisci come un giudice imparziale e valuta la qualità delle risposte fornite da due assistenti AI alla domanda dell'utente visualizzata di seguito. Dovete scegliere l'assistente che segue le istruzioni dell'utente e che risponde meglio alla sua domanda. 
La valutazione deve considerare fattori quali l'utilità, la pertinenza, l'accuratezza, la profondità, la creatività e il livello di dettaglio delle risposte. 
Evitate qualsiasi pregiudizio di posizione e assicuratevi che l'ordine di presentazione delle risposte non influisca sulla vostra decisione. 
Non lasciate che la lunghezza delle risposte influenzi la vostra valutazione. 
Non favorire alcuni nomi di assistenti. 
Siate il più possibile obiettivi. 
Infine, indicate il vostro verdetto facendo clic su un pulsante. 
Si prega di utilizzare il pulsante di spareggio in modo limitato.
""",
    )
    # EN="🇬🇧English"
    # Labels shown in the language dropdown, keyed by language code.
    lang_label = MultilingualText(DE="🇩🇪Deutsch", FR="🇫🇷Français", IT="🇮🇹Italiano", ES="🇪🇸Español")
    # Invert code -> label so that a selected label can be mapped back to its code.
    lang_label_to_lang_code = {value: key for key, value in lang_label.model_dump().items()}
    return AppTexts(
        user_instruction=user_instruction, lang_label=lang_label, lang_label_to_lang_code=lang_label_to_lang_code
    )


if __name__ == "__main__":
    # Command line: network binding, optional model whitelist, public share
    # link and the choice between pairwise (default) and single-answer mode.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--model-list", type=str, nargs="+", default=[])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--single", action="store_true")
    args = parser.parse_args()
    print(args)

    # SECURITY: the login used to be hard-coded here. It can now be overridden
    # with the environment variables QA_BROWSER_AUTH_USER and
    # QA_BROWSER_AUTH_PASSWORD. The previous values remain as defaults only to
    # keep existing deployments working; since they are part of the source
    # code, they should be treated as public and replaced.
    auth_user = os.environ.get("QA_BROWSER_AUTH_USER", "user")
    auth_password = os.environ.get("QA_BROWSER_AUTH_PASSWORD", "z7hx30Kjzzz")

    demo = build_demo(is_single_eval=args.single, model_list=args.model_list)
    demo.queue(max_size=None)
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=200,
        auth=(auth_user, auth_password),
    )
