from dotenv import load_dotenv
load_dotenv()

from . import config, client_config
from . import tasks, algorithms, pool
from .algorithms import llm_cache, timed_llm

import os
from functools import partial
import numpy as np
import openai
import polars as pl
import tqdm
import time, json
import functools

import multiprocessing
from multiprocessing import Process, Queue
from concurrent.futures import ThreadPoolExecutor, TimeoutError


@functools.cache
def _get_ds(ds_name, **filter_config):
    if ds_name == "imdb":
        ds = tasks.get_imdb_ds(**filter_config)
    elif ds_name == "blog":
        ds = tasks.get_blog_ds(**filter_config)
    return ds

def _get_clinet(config):
    """Build an OpenAI-compatible client for the model named in ``config``.

    The backend is picked by ``client_config.get_client_type`` applied to
    ``config["completion_kwargs"]["model"]``:

    * ``"openai"``: public OpenAI API, keyed by ``OPENAI_API_KEY``.
    * ``"vllm"``: OpenAI-compatible server at ``VLLM_ENDPOINT`` using
      ``VLLM_API_KEY``. The server is probed once; an unreachable endpoint
      or a model that is not listed only produces a warning.
    * ``"azure"``: endpoint and key read from ``<MODEL>_ENDPOINT`` and
      ``<MODEL>_API_KEY``, where ``<MODEL>`` is the model name with ``-``
      replaced by ``_``.

    Args:
        config: Experiment config; only
            ``config["completion_kwargs"]["model"]`` is read.

    Returns:
        An ``openai.OpenAI`` client.

    Raises:
        KeyError: If a required config key or environment variable is missing.
        ValueError: If the model maps to an unsupported client type.
    """
    import warnings

    from .client_config import get_client_type

    model_str = config["completion_kwargs"]["model"]
    client_type = get_client_type(model_str)

    if client_type == "openai":
        return openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    if client_type == "vllm":
        endpoint = os.environ["VLLM_ENDPOINT"]
        api_key = os.environ["VLLM_API_KEY"]
        client = openai.OpenAI(api_key=api_key, base_url=endpoint)

        # Best-effort probe: problems are reported as warnings, not raised,
        # so the client is still returned.
        try:
            served_models = {item.id for item in client.models.list().data}
        except Exception as e:
            warnings.warn(
                "Can't access vllm. Make sure the endpoint and api_key are "
                f"correct ({e!r})"
            )
        else:
            if model_str not in served_models:
                warnings.warn(
                    f"Can't find model {model_str} in vllm. Make sure the model is deployed"
                )
        return client

    if client_type == "azure":
        env_prefix = model_str.replace("-", "_")
        endpoint = os.environ[f"{env_prefix}_ENDPOINT"]
        api_key = os.environ[f"{env_prefix}_API_KEY"]
        return openai.OpenAI(api_key=api_key, base_url=endpoint)

    raise ValueError(
        f"Unsupported client type {client_type!r} for model {model_str!r}"
    )


def _get_clasify_fn(config, client):
    if config["method"] == "logprob_classification":
        classify_fn = partial(
            algorithms.logprob.logprob_classification,
            client=client,
            prompt_composer=config["prompt_composer"],
            completion_kwargs=config["completion_kwargs"],
        )
    elif config["method"] == "qa":
        classify_fn = partial(
            algorithms.qa.qa_classification,
            client=client,
            prompt_composer=config["prompt_composer"],
            response_parser=config["response_parser"],
            completion_kwargs=config["completion_kwargs"],
        )

    return classify_fn


def worker(q, classify_fn, train_ds, test_ds):
    """Run a single classification job and enqueue its inputs and output.

    Args:
        q: Queue that receives one ``(train, test, output)`` tuple, each frame
            converted with ``to_pandas()``.
        classify_fn: Callable invoked as ``classify_fn(train_ds, test_ds)``;
            it must return an eager frame, not a ``pl.LazyFrame``.
        train_ds: Training split passed to ``classify_fn``.
        test_ds: Test split passed to ``classify_fn``.
    """
    predictions = classify_fn(train_ds, test_ds)
    assert not isinstance(predictions, pl.LazyFrame)
    payload = tuple(frame.to_pandas() for frame in (train_ds, test_ds, predictions))
    q.put(payload)


def _reap_futures(futures, non_blocking):
    """Re-raise failures of finished jobs and return the still-running ones.

    Args:
        futures: Futures returned by ``ThreadPoolExecutor.submit``, in
            submission order.
        non_blocking: If true, futures that are not done yet are skipped and
            kept; otherwise every future is waited on.

    Returns:
        The futures that were not done when inspected (always empty when
        ``non_blocking`` is false).

    Raises:
        Exception: The first exception raised by a finished job, unchanged.
    """
    pending = []
    for fut in futures:
        if non_blocking and not fut.done():
            pending.append(fut)
            continue
        # Check ``done()`` explicitly rather than catching ``TimeoutError``:
        # since Python 3.11 ``concurrent.futures.TimeoutError`` is the builtin
        # ``TimeoutError``, so a job that itself timed out would otherwise be
        # swallowed and silently dropped.
        fut.result()
    return pending


def put_worker(q, seeds, config):
    """Schedule one classification job per seed and stream results into ``q``.

    For each seed a deterministic train/test split seed and algorithm seed are
    drawn from ``np.random.RandomState(seed)``, and ``worker`` is submitted to
    a thread pool sized by ``pool.get_max_occ(config)``. Submission waits on
    the resource returned by ``pool.get_occ_resource(config)``; while waiting,
    finished jobs are polled so that a failing job aborts the run promptly.

    Args:
        q: Queue passed to every ``worker``; receives one item per seed.
        seeds: Iterable of integer seeds, one job each.
        config: Experiment config. Reads ``ds_name``, ``filter_config``
            (optional), ``use_cache``, ``num_authors``, ``num_shots`` and
            ``num_test_rows``, plus the keys used by the client/classifier
            factories.

    Raises:
        Exception: The first exception raised by any job.
    """
    client = _get_clinet(config)
    classify_fn = _get_clasify_fn(config, client)
    ds = _get_ds(config["ds_name"], **config.get("filter_config", {}))
    # Process-wide switch read by the LLM cache.
    llm_cache.use_cache[0] = config["use_cache"]
    futures = []

    with ThreadPoolExecutor(max_workers=pool.get_max_occ(config)) as executor:
        for seed in seeds:
            rng = np.random.RandomState(seed)
            seed_for_ds_split = rng.randint(0, 2**32)
            train_ds, test_ds = tasks.make_n_author_m_shot_ds(
                ds,
                seed_for_ds_split,
                config["num_authors"],
                config["num_shots"],
                config["num_test_rows"],
            )
            assert not isinstance(train_ds, pl.LazyFrame)
            assert not isinstance(test_ds, pl.LazyFrame)
            seed_for_algorithm = rng.randint(0, 2**32)

            # Hand-expanded version of ``res.wait_till_available()`` so that
            # job failures are surfaced while we wait for capacity.
            # NOTE(review): relies on the resource's private ``_cond`` and
            # ``_value`` attributes; confirm against ``pool`` if it changes.
            res = pool.get_occ_resource(config)
            with res._cond:
                while res._value <= 0:
                    futures = _reap_futures(futures, non_blocking=True)
                    res._cond.wait(timeout=1)

            futures.append(
                executor.submit(
                    worker,
                    q,
                    partial(classify_fn, seed=seed_for_algorithm),
                    train_ds,
                    test_ds,
                )
            )
            futures = _reap_futures(futures, non_blocking=True)

        # Wait for the tail of the jobs and propagate any failure.
        futures = _reap_futures(futures, non_blocking=False)
        assert len(futures) == 0
        assert len(futures) == 0


def save_worker(q, num, csv_save_path):
    """Consume up to ``num`` results from ``q`` and write them to a CSV file.

    Each queue item is a ``(train, test, output)`` tuple of pandas frames. The
    output frame gets an extra ``train_textIds`` column holding the sorted
    training ``textId`` values rendered as a string. The first item truncates
    ``csv_save_path`` and writes the header; later items are appended without
    one. A ``None`` item stops consumption early.

    Args:
        q: Queue to read results from.
        num: Maximum number of items to consume (also the progress-bar total).
        csv_save_path: Destination CSV path.
    """
    is_first = True
    with tqdm.tqdm(total=num) as bar:
        for _ in range(num):
            item = q.get()
            if item is None:
                break
            train_pd, _test_pd, out_pd = item
            train_ds = pl.from_pandas(train_pd)
            out_ds = pl.from_pandas(out_pd)

            train_text_ids = train_ds["textId"].sort().to_list()
            out_ds = out_ds.with_columns(train_textIds=pl.lit(str(train_text_ids)))

            # Reopen per item so each finished job reaches disk even if the
            # run crashes later. Pin UTF-8 and disable newline translation so
            # the CSV does not depend on the platform locale / line endings.
            mode = "w" if is_first else "a"
            with open(csv_save_path, mode, encoding="utf-8", newline="") as f:
                out_ds.write_csv(f, include_header=is_first)
            is_first = False
            bar.update(1)


def run_exp_one(config):
    """Run a single experiment and persist its results and timings.

    Result rows are written by ``save_worker`` to ``data.csv``, and total and
    LLM elapsed times to ``data.json``. Both files are in
    ``../data/<config["name"]>/`` relative to this module's directory.

    Args:
        config: Experiment config. Reads ``name``, ``seed`` and ``num_try``
            here and forwards the whole config to ``put_worker``.

    Raises:
        Exception: Anything raised by ``put_worker`` or ``save_worker``.
    """
    print(f'Run {config["name"]}')
    out_dir = os.path.join(os.path.dirname(__file__), "..", "data", config["name"])
    csv_save_path = os.path.join(out_dir, "data.csv")
    json_save_path = os.path.join(out_dir, "data.json")
    os.makedirs(out_dir, exist_ok=True)

    timed_llm.clear_elapsed_time()
    time_start = time.time()

    seeds = np.random.RandomState(config["seed"]).randint(0, 2**32, config["num_try"])
    q = Queue()
    with ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(save_worker, q, len(seeds), csv_save_path)
        try:
            put_worker(q, seeds, config)
        except BaseException:
            # save_worker is blocked on q.get(); the sentinel lets it exit so
            # leaving the ``with`` (which joins the executor) cannot hang.
            # BaseException so that Ctrl-C is covered too.
            future.cancel()
            q.put(None)
            raise
        future.result()

    time_end = time.time()
    json_data = {
        "total_time_elapsed": time_end - time_start,
        "llm_time_elapsed": timed_llm.get_elapsed_time(),
    }
    with open(json_save_path, "w") as f:
        json.dump(json_data, f)


def run_exp(config):
    """Run one experiment config, or every config in a (possibly nested) list.

    Args:
        config: A single config ``dict``, or a ``list``/``tuple`` whose
            elements are themselves accepted by ``run_exp``.

    Raises:
        TypeError: If ``config`` (or a nested element) is neither a dict nor a
            list/tuple, instead of silently skipping it.
    """
    if isinstance(config, dict):
        run_exp_one(config)
    elif isinstance(config, (list, tuple)):
        for c in config:
            run_exp(c)
    else:
        raise TypeError(
            f"config must be a dict or a list/tuple of configs, got {type(config).__name__}"
        )
