def forillustration(model="NousResearch/Meta-Llama-3-70B"):
    """Run a tiny two-author, one-shot logprob classification and print its internals.

    Builds a split from ``tasks.get_imdb_ds()`` with two authors, one shot per
    author and one test row. It classifies the test row with
    ``algorithms.logprob.logprob_classification``. It then separately scores the
    test text against each author's prompt with ``algorithms.logprob.get_logprob``,
    so the raw per-author log-probabilities can be printed next to the classifier
    output.

    Args:
        model: Model name sent as ``completion_kwargs["model"]`` with every
            completion request. Defaults to the previously hard-coded model.

    Environment:
        VLLM_ENDPOINT: base URL passed to ``openai.OpenAI``.
        VLLM_API_KEY: API key passed to ``openai.OpenAI``.

    Raises:
        KeyError: if either environment variable is unset (``VLLM_ENDPOINT``
            is looked up first).
    """
    import os

    import openai

    from . import tasks, algorithms

    ds = tasks.get_imdb_ds()
    # NOTE(review): the meaning of the positional ``1`` (seed?) is not visible
    # from here -- confirm against tasks.make_n_author_m_shot_ds.
    train_ds, test_ds = tasks.make_n_author_m_shot_ds(
        ds,
        1,
        num_authors=2,
        num_shots=1,
        num_test_rows=1,
    )

    def prompt_composer(author_ds):
        """Join an author's example texts with blank lines and append a continuation cue."""
        return (
            "\n\n".join(author_ds["text"].to_list())
            + "\n\nHere is the text from the same author:\n\n"
        )

    endpoint = os.environ["VLLM_ENDPOINT"]
    api_key = os.environ["VLLM_API_KEY"]
    client = openai.OpenAI(api_key=api_key, base_url=endpoint)

    # A fresh kwargs dict is built for every call so that no callee can mutate
    # a dict shared with another request.
    out_ds = algorithms.logprob.logprob_classification(
        train_ds,
        test_ds,
        client,
        prompt_composer=prompt_composer,
        completion_kwargs={"model": model},
    )

    # Score the test text against each author's prompt individually, so the
    # raw log-probabilities can be inspected next to the classifier output.
    # (``partition_by`` suggests a polars DataFrame -- TODO confirm.)
    prompts = [
        prompt_composer(single_author_ds)
        for single_author_ds in train_ds.partition_by("authorId")
    ]
    text = test_ds["text"].to_list()[0]  # the split has exactly one test row
    logprobs = [
        algorithms.logprob.get_logprob(
            client,
            [prompt],
            text,
            completion_kwargs={"model": model},
        )
        for prompt in prompts
    ]

    print(out_ds)
    print(text)
    for prompt, logprob in zip(prompts, logprobs):
        print(logprob)
        print(prompt)


def run_all_exps():
    """Run the configured experiments one after another via ``run_exp``.

    Experiments run in this order: ``config.ablation_exp1`` through
    ``config.ablation_exp6``, then ``config.compare_exp1``,
    ``config.compare_exp2`` and ``config.large_exp1``. Each configuration is
    looked up on ``config`` only when its turn comes.
    """
    from . import config, run_exp

    # The debug configurations (config.debug_exp1 .. config.debug_exp4) are
    # intentionally skipped; prepend their names here to include them.
    experiment_names = (
        "ablation_exp1",
        "ablation_exp2",
        "ablation_exp3",
        "ablation_exp4",
        "ablation_exp5",
        "ablation_exp6",
        "compare_exp1",
        "compare_exp2",
        "large_exp1",
    )
    for experiment_name in experiment_names:
        run_exp(getattr(config, experiment_name))


def run_benchmark():
    """Run the ``config.compare_exp3`` benchmark with request concurrency throttled.

    Before running, two module-level settings are overwritten in place and are
    not restored afterwards:

    * ``pool.req_resource_dict["vllm"]`` is replaced with ``pool.Resource(1)``.
      Other keys of ``req_resource_dict`` are left untouched.
    * every value in ``client_config.max_concurrency_dict`` is set to ``1``.
    """
    from . import pool

    # We host a single vLLM instance started with --max-num-seqs=1. It handles one
    # sequence at a time and queues any other requests. Limiting the "vllm"
    # resource (presumably to one in-flight request) stops requests from waiting
    # in that server-side queue, so queueing time is not counted in the
    # measurements. Azure and OpenAI can serve concurrent requests, so their
    # req_resource_dict entries keep their normal values.
    pool.req_resource_dict["vllm"] = pool.Resource(1)

    from . import client_config

    # This applies to EVERY key of max_concurrency_dict, not only vllm.
    # NOTE(review): confirm that throttling all entries to 1 is intended for
    # the benchmark.
    for concurrency_key in client_config.max_concurrency_dict:
        client_config.max_concurrency_dict[concurrency_key] = 1

    from . import config, run_exp

    run_exp(config.compare_exp3)


if __name__ == "__main__":
    import sys

    # Select the entry point with an optional first CLI argument instead of
    # commenting/uncommenting calls. With no argument, all experiments run,
    # which matches the previous default.
    _entry_points = {
        "all": run_all_exps,
        "illustration": forillustration,
        "benchmark": run_benchmark,
    }
    _choice = sys.argv[1] if len(sys.argv) > 1 else "all"
    if _choice not in _entry_points:
        sys.exit(f"usage: {sys.argv[0]} [{'|'.join(_entry_points)}]")
    _entry_points[_choice]()
