from typing import Callable, Optional

from . import llm_utils
from experiments import pool

import numpy as np
import polars as pl

from concurrent.futures import ThreadPoolExecutor


def _normalize(s):
    return s.replace("▁", " ")


def _len_common_suffix(a, b):
    i = 0
    for ca, cb in zip(reversed(list(a)), reversed(list(b))):
        if ca == cb:
            i += 1
        else:
            break
    return i


def _fix_tokens_dp(text: str, tokens: list[str]):
    if len(text) == 0:
        return tokens
    if len(tokens) == 0:
        return None
    token = tokens[-1]
    if len(token) == 0:
        fixed_tokens = _fix_tokens_dp(text, tokens[:-1])
        if fixed_tokens is None:
            return None
        else:
            return fixed_tokens + [token]
    max_matched_len = _len_common_suffix(text, token)
    for matched_len in reversed(list(range(1, max_matched_len + 1))):
        sub_text = text[: len(text) - matched_len]
        sub_token = token[len(token) - matched_len :]
        fixed_tokens = _fix_tokens_dp(sub_text, tokens[:-1])
        if fixed_tokens is not None:
            return fixed_tokens + [sub_token]
    return None


def _fix_tokens(text: str, tokens: list[str]):
    """Normalize *text* and *tokens*, then make the tokens end with *text*.

    Args:
        text: The string the token sequence is expected to end with.
        tokens: Token strings, in order.

    Returns:
        The normalized tokens unchanged when their concatenation already
        ends with the normalized text; otherwise whatever
        ``_fix_tokens_dp`` returns for the normalized inputs (``None`` when
        it finds no alignment).
    """
    import sys

    # Keep headroom of roughly one stack frame per token in case the
    # alignment search recurses per token. The limit applies to the whole
    # interpreter and this function can run on several worker threads at
    # once, so only ever raise it: lowering it could break another thread
    # that is currently relying on a higher value.
    needed_depth = max(1500, len(tokens) + 100)
    if sys.getrecursionlimit() < needed_depth:
        sys.setrecursionlimit(needed_depth)

    text = _normalize(text)
    tokens = [_normalize(t) for t in tokens]
    if "".join(tokens).endswith(text):
        # Already aligned, nothing to fix.
        return tokens
    return _fix_tokens_dp(text, tokens)


def safe_start_indices_by_count_backward(text, tokens):
    """Find the start index of the trailing tokens that cover ``len(text)`` chars.

    Tokens are consumed from the end, subtracting each token's length from
    the number of characters still to cover, until that number is no longer
    positive or the tokens run out. The index of the last token consumed is
    returned (``len(tokens)`` when nothing had to be consumed).
    """
    chars_left = len(text)
    index = len(tokens)
    while index > 0 and chars_left > 0:
        index -= 1
        chars_left -= len(tokens[index])
    return index


def safe_start_indices(text: str, tokens: list[str]) -> int:
    """Return the index in *tokens* where the trailing *text* begins.

    The tokens are first repaired with ``_fix_tokens`` so that their
    concatenation ends with *text*; the index is then found by counting
    characters backwards over the repaired tokens. If the repair fails, a
    warning is issued and the same backward count is applied to the
    original tokens as a rough guess.
    """
    repaired = _fix_tokens(text, tokens)
    if repaired is not None:
        joined = "".join(repaired)
        assert _normalize(joined).endswith(_normalize(text)), (text, repaired)
        start = safe_start_indices_by_count_backward(text, repaired)
        # The tokens from `start` onwards must still contain the whole text.
        assert _normalize("".join(repaired[start:])).endswith(_normalize(text))
        return start

    import warnings

    warnings.warn("Cannot fix tokens, will use a basic guess")
    return safe_start_indices_by_count_backward(text, tokens)


def get_logprob(client, prompts, text, **kwargs):
    """Score *text* as a continuation of each prompt.

    Every prompt is concatenated with *text* and sent to
    ``llm_utils.generate_logprob`` while holding the request resource from
    ``pool``. For each returned token list, the token index where *text*
    begins is located and the token log-probabilities from that index on
    are summed.

    Args:
        client: Passed through to ``llm_utils.generate_logprob``.
        prompts: Prompt strings; *text* is appended to each one.
        text: The continuation whose log-probability is wanted.
        **kwargs: Forwarded to ``llm_utils.generate_logprob`` and also used
            to look up the ``pool`` resources.

    Returns:
        logprob: list[float], one summed log-probability per returned token
        list (presumably one per prompt; depends on
        ``llm_utils.generate_logprob``).
    """
    ts = [p + text for p in prompts]
    try:
        with pool.get_req_resource(kwargs):
            _, raw_text_offsets, token_logprobs, tokens = llm_utils.generate_logprob(
                client,
                ts,
                **kwargs,
            )
    finally:
        # Give the occupancy slot back even when the request fails, so a
        # failed call does not leave it held.
        pool.get_occ_resource(kwargs).release(n=1)

    # The start of `text` is located by aligning token strings rather than
    # by using raw_text_offsets (first offset >= len(prompt)), because the
    # offsets are unreliable due to a bug in vllm:
    # https://github.com/vllm-project/vllm/issues/5334
    start_indices = [safe_start_indices(text, token) for token in tokens]
    return [
        sum(row_logprobs[start:])
        for start, row_logprobs in zip(start_indices, token_logprobs)
    ]




def logprob_classification(
    train_ds,
    test_ds,
    client,
    # input is a DataFrame for the same author
    prompt_composer: Callable[[pl.DataFrame], str],
    # input is a dict representing a row in the test dataset
    text_composer: Callable[[dict], str] = lambda x: x["text"],
    seed: Optional[int] = None,  # will not be used, since logprob is deterministic
    **kwargs,
):
    """Attribute each test row to the author whose prompt scores it highest.

    ``train_ds`` is partitioned by "authorId" and one prompt is composed
    per partition. For every test row, the composed text is scored against
    every author prompt with ``get_logprob`` (run on a thread pool); the
    author with the largest log-probability is the prediction.

    Args:
        train_ds: Training frame with an "authorId" column.
        test_ds: Frame whose rows are classified.
        client: Passed through to ``get_logprob``.
        prompt_composer: Builds the prompt from one author's frame.
        text_composer: Builds the text to score from one test row (a dict).
        seed: Unused; accepted for interface compatibility.
        **kwargs: Forwarded to ``get_logprob`` and used to look up the
            ``pool`` resources.

    Returns:
        out_ds: pl.DataFrame, test_ds with an additional column
        "pred_authorId", and "rank" (the string form of the author ids
        ordered from highest to lowest log-probability).

    Raises:
        ValueError: If partitioning ``train_ds`` yields no author groups.
    """
    single_author_dss = train_ds.partition_by("authorId")
    if not single_author_dss:
        raise ValueError("train_ds must contain at least one author")

    n_requests = test_ds.height * len(single_author_dss)
    pool.get_occ_resource(kwargs).acquire(n=n_requests)
    prompts = [
        prompt_composer(single_author_ds) for single_author_ds in single_author_dss
    ]

    authorIds = [
        single_author_ds.select("authorId")[0, 0]
        for single_author_ds in single_author_dss
    ]

    pending = []
    # ThreadPoolExecutor rejects max_workers <= 0, which an empty test_ds
    # would otherwise produce.
    with ThreadPoolExecutor(max_workers=max(1, n_requests)) as executor:
        for test_row in test_ds.iter_rows(named=True):
            text = text_composer(test_row)
            row_futures = []
            pending.append(row_futures)
            for prompt in prompts:
                pool.get_req_resource(kwargs).wait_till_available()
                row_futures.append(
                    executor.submit(
                        get_logprob,
                        client,
                        [prompt],
                        text,
                        **kwargs,
                    )
                )
        # Each call scores a single prompt, so take the only element.
        logprobss = [
            [future.result()[0] for future in row_futures]
            for row_futures in pending
        ]
    results = [authorIds[np.argmax(logprobs)] for logprobs in logprobss]
    # Stable sort on the negated score: ties keep the authorIds order.
    tops = [
        [
            author
            for author, _ in sorted(
                zip(authorIds, logprobs), key=lambda pair: -pair[1]
            )
        ]
        for logprobs in logprobss
    ]
    out_ds = (
        test_ds.clone()
        .with_columns(pl.Series(name="pred_authorId", values=results))
        .with_columns(pl.Series(name="rank", values=[str(r) for r in tops]))
    )
    return out_ds

    


def test():
    """Manual smoke test against a live endpoint.

    Loads environment variables via dotenv, builds an OpenAI client from
    VLLM_API_KEY and VLLM_ENDPOINT, and uses the first listed model to
    print the result of one ``get_logprob`` call and one
    ``logprob_classification`` run on a tiny two-author dataset.
    """
    from dotenv import load_dotenv

    load_dotenv()
    import os
    from openai import OpenAI

    api_key = os.environ["VLLM_API_KEY"]
    endpoint = os.environ["VLLM_ENDPOINT"]
    client = OpenAI(api_key=api_key, base_url=endpoint)

    sample_prompts = ["\n1. Once upon a time"]
    continuation = " there was a"
    model_id = client.models.list().data[0].id
    score = get_logprob(
        client,
        sample_prompts,
        continuation,
        completion_kwargs={"model": model_id},
    )
    print(score)

    train_ds = pl.DataFrame(
        {
            "authorId": ["ca", "ca", "cb", "cb"],
            "text": ["aa", "a ", "bb", "b "],
        }
    )
    test_ds = pl.DataFrame(
        {
            "authorId": ["ca", "cb"],
            "text": ["a a", "b b"],
        }
    )

    def compose_prompt(author_frame):
        # The prompt is the printed list of the author's texts plus a blank line.
        return str(author_frame["text"].to_list()) + "\n\n"

    predictions = logprob_classification(
        train_ds,
        test_ds,
        client,
        compose_prompt,
        completion_kwargs={"model": model_id},
    )
    print(predictions)


#  test()
