from typing import Callable, Optional, Union
import json
import numpy as np
import polars as pl
import copy

from . import llm_utils
from experiments import pool

# A prompt composer is called by qa_classification with three positional
# arguments: the training DataFrame covering all candidate authors, a dict
# representing one row of the test dataset, and the ordered list of candidate
# author IDs. It returns either a string prompt or a list of chat message dicts.
PROMPT_COMPOSER = Callable[
    [pl.DataFrame, dict, list[str]], Union[str, list[dict]]
]


def extract_selected_authors(train_ds: pl.DataFrame) -> list[str]:
    """Return the distinct values of the "authorId" column as a Python list.

    Duplicates are dropped with ``maintain_order=True``, so the result keeps
    the column's ordering rather than an arbitrary one.
    """
    author_column = train_ds["authorId"]
    return author_column.unique(maintain_order=True).to_list()


def default_prompt_composer(
    train_ds, test_row, selected_authors: Optional[list[str]] = None
) -> list[dict]:
    """Compose the chat messages asking an LLM who wrote ``test_row["text"]``.

    Training texts are grouped by the position of their author in
    ``selected_authors`` and serialized as JSON, so the model answers with an
    integer index instead of a raw author ID.

    Args:
        train_ds: object exposing ``iter_rows(named=True)`` that yields dicts
            with "authorId" and "text" keys (a polars DataFrame in practice).
        test_row: dict holding the query text under the "text" key.
        selected_authors: ordered candidate author IDs; derived from
            ``train_ds`` via ``extract_selected_authors`` when omitted.

    Returns:
        A two-element list: the system message, then the user message.

    Raises:
        ValueError: if a training row's author is not in ``selected_authors``.
    """
    #  source https://arxiv.org/pdf/2403.08213v1
    #  https://github.com/baixianghuang/authorship-llm/blob/main/code/github-attribution-gpt.ipynb
    system_msg = """Respond with a JSON object including two key elements:
{
  "analysis": Reasoning behind your answer.
  "answer": The query text's author ID.
}"""
    prompt1 = "Given a set of texts with known authors and a query text, determine the author of the query text. "
    prompt4 = (
        prompt1
        + "Analyze the writing styles of the input texts, disregarding the differences in topic and content. Focus on linguistic features such as phrasal verbs, modal verbs, punctuation, rare words, affixes, quantities, humor, sarcasm, typographical errors, and misspellings. "
    )
    prompt_input = prompt4  # according to the source, prompt4 is the best

    query_text = test_row["text"]
    if selected_authors is None:
        selected_authors = extract_selected_authors(train_ds)

    # Resolve each author's position once instead of running a linear
    # list.index() scan for every training row. setdefault keeps the first
    # position when an ID is repeated, which is what list.index() returned.
    author_position: dict = {}
    for position, author in enumerate(selected_authors):
        author_position.setdefault(author, position)

    train_dict: dict[int, list] = {}
    for row in train_ds.iter_rows(named=True):
        author = row["authorId"]
        if author not in author_position:
            # Same exception type list.index() raised, with a clearer message.
            raise ValueError(f"authorId {author!r} is not in selected_authors")
        train_dict.setdefault(author_position[author], []).append(row["text"])
    # Integer keys are emitted as JSON strings, sorted by author position.
    example_texts = json.dumps(train_dict, sort_keys=True)

    prompt = (
        prompt_input
        + f"""The input texts are delimited with triple backticks. ```\n\nQuery text: {query_text} \n\nTexts from potential authors: {example_texts}\n\n```"""
    )
    return [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": prompt},
    ]


# A response parser is called by qa_classification with four positional
# arguments: the training DataFrame covering all candidate authors, a dict
# representing one row of the test dataset, the string response from the LLM,
# and the ordered list of candidate author IDs.
# It returns a dict, {"pred_authorId": str, **other}
RESPONSE_PARSER = Callable[[pl.DataFrame, dict, str, list[str]], dict]


def default_response_parser(
    train_ds, test_row, response, selected_authors: Optional[list[str]] = None
) -> dict:
    """Translate the LLM's JSON reply into a predicted author ID.

    The reply is expected to be a JSON object whose "Answer" (checked first)
    or "answer" value is an index into ``selected_authors``. Any reply that
    cannot be interpreted that way yields the "FailWrongFormat" sentinel
    instead of raising, so one bad completion does not abort a whole run.

    Args:
        train_ds: training DataFrame; only used to derive ``selected_authors``
            when that argument is omitted.
        test_row: the test row being classified (unused here).
        response: raw response text returned by the LLM.
        selected_authors: ordered candidate author IDs the answer indexes into.

    Returns:
        ``{"pred_authorId": <author ID or "FailWrongFormat">, "analysis": ""}``.
        The "analysis" field is always the empty string.
    """
    failure = {"pred_authorId": "FailWrongFormat", "analysis": ""}
    if selected_authors is None:
        selected_authors = extract_selected_authors(train_ds)

    try:
        parsed = json.loads(response)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError; TypeError covers a
        # non-string response such as None.
        return failure
    if not isinstance(parsed, dict):
        # Valid JSON, but not an object (e.g. a bare list or number).
        return failure

    # "Answer" takes precedence over "answer" when both are present.
    answer_key = next((key for key in ("Answer", "answer") if key in parsed), None)
    if answer_key is None:
        return failure

    try:
        author_index = int(parsed[answer_key])
    except (TypeError, ValueError, OverflowError):
        # null / list / non-numeric string / NaN / Infinity answers.
        return failure
    # Reject out-of-range indices explicitly; a negative value would otherwise
    # wrap around and silently pick an author from the end of the list.
    if not 0 <= author_index < len(selected_authors):
        return failure
    return {"pred_authorId": selected_authors[author_index], "analysis": ""}


def add_seed_to_kwargs(kwargs, seed):
    """Return a deep copy of ``kwargs`` with ``completion_kwargs["seed"]`` set.

    The caller's dict (and anything nested in it) is left untouched; a
    "completion_kwargs" entry is created in the copy when it is missing.
    """
    seeded = copy.deepcopy(kwargs)
    completion_options = seeded.setdefault("completion_kwargs", {})
    completion_options["seed"] = seed
    return seeded


def qa_classification(
    train_ds,
    test_ds,
    client,
    # called as prompt_composer(train_ds, test_row, selected_authors)
    prompt_composer: PROMPT_COMPOSER,
    # called as response_parser(train_ds, test_row, response_str, selected_authors)
    response_parser: RESPONSE_PARSER,
    seed: Optional[int] = None,
    **kwargs,
):
    """Classify every row of ``test_ds`` by querying the LLM once per row.

    For each test row a prompt is composed, sent through
    ``llm_utils.generate_qa`` with a per-row seed drawn from a generator
    seeded with ``seed``, and the reply is handed to ``response_parser``.

    Args:
        train_ds: training DataFrame passed to the composer and parser.
        test_ds: DataFrame of rows to classify.
        client: LLM client forwarded to ``llm_utils.generate_qa``.
        prompt_composer: builds a string prompt or a list of chat messages.
        response_parser: turns the raw reply into a result dict.
        seed: seed for the generator that produces the per-request seeds.
        **kwargs: forwarded to ``llm_utils.generate_qa`` (with the per-row
            seed written into ``completion_kwargs``) and to the ``pool``
            resource lookups.

    Returns:
        out_ds: pl.DataFrame, ``test_ds`` horizontally concatenated with the
        parser results (e.g. a "pred_authorId" column).

    Raises:
        ValueError: if ``prompt_composer`` returns neither a list nor a str.
    """
    total_rows = test_ds.height
    pool.get_occ_resource(kwargs).acquire(n=total_rows)
    released_rows = 0
    try:
        selected_authors = extract_selected_authors(train_ds)
        results = []
        rng = np.random.default_rng(seed)
        for test_row in test_ds.iter_rows(named=True):
            seed_for_llm = int(rng.integers(0, 2**32))
            prompt = prompt_composer(train_ds, test_row, selected_authors)
            input_kwargs = {}
            if isinstance(prompt, list):
                # should be a list[dict], but isinstance() argument 2 cannot be a parameterized generic
                input_kwargs["messages"] = prompt
            elif isinstance(prompt, str):
                input_kwargs["prompt"] = prompt
            else:
                raise ValueError(
                    "prompt_composer must return a list of message dicts or a string"
                )

            # Sequential requests are fine when len(test_ds) == 1 (the usual
            # case); a larger test_ds would benefit from ThreadPoolExecutor(),
            # which is deliberately not used here to avoid its overhead.
            with pool.get_req_resource(kwargs):
                response_str = llm_utils.generate_qa(
                    client, **input_kwargs, **add_seed_to_kwargs(kwargs, seed_for_llm)
                )
            pool.get_occ_resource(kwargs).release(n=1)
            released_rows += 1
            result = response_parser(
                train_ds, test_row, response_str, selected_authors
            )
            results.append(result)
    finally:
        # If a row failed, hand back the slots acquired up front that were
        # never released, so an exception does not leak occupancy.
        # NOTE(review): assumes release(n=k) undoes acquire(n=k) - confirm in pool.
        unreleased_rows = total_rows - released_rows
        if unreleased_rows > 0:
            pool.get_occ_resource(kwargs).release(n=unreleased_rows)
    out_ds = pl.concat(
        [
            test_ds,
            pl.DataFrame(results),
        ],
        how="horizontal",
    )
    return out_ds


def test():
    """Manual smoke test for ``llm_utils.generate_qa`` and ``qa_classification``.

    Builds an OpenAI client from the ``Llama_3_70B_Instruct_*`` environment
    variables (loaded via dotenv), prints the reply to a single chat message,
    then prints the result of classifying a tiny two-author dataset.
    """
    import os

    from dotenv import load_dotenv

    # Populate the environment before the client reads its credentials.
    load_dotenv()
    from openai import OpenAI

    client = OpenAI(
        api_key=os.environ["Llama_3_70B_Instruct_API_KEY"],
        base_url=os.environ["Llama_3_70B_Instruct_ENDPOINT"],
    )

    # 1) One raw chat request.
    opening_line = "\n1. Once upon a time"
    raw_reply = llm_utils.generate_qa(
        client,
        messages=[{"role": "user", "content": opening_line}],
        completion_kwargs={"model": "Llama-3-70B-Instruct", "seed": 12},
    )
    print(raw_reply)

    # 2) End-to-end classification on a toy dataset with two authors.
    train_ds = pl.DataFrame(
        {
            "authorId": ["ca", "ca", "cb", "cb"],
            "text": ["aa", "a ", "bb", "b "],
        }
    )
    test_ds = pl.DataFrame(
        {
            "authorId": ["ca", "cb"],
            "text": ["a a", "b b"],
        }
    )
    classified = qa_classification(
        train_ds,
        test_ds,
        client,
        prompt_composer=default_prompt_composer,
        response_parser=default_response_parser,
        seed=12,
        completion_kwargs={"model": "Llama-3-70B-Instruct", "seed": 12},
    )
    print(classified)


#  test()
