import atexit
import json
import os
import time

import requests
from msal import PublicClientApplication, SerializableTokenCache

# Optional pre-acquired bearer token. When non-empty, LLMClient sends it as-is
# instead of acquiring a token through MSAL.
TOKEN: str = ""

class LLMClient:
    """Client for an internal completions endpoint, authenticated via MSAL.

    Access tokens come from the MSAL device-code flow and are cached in
    ``.llmapi.bin`` (relative to the current working directory). The cache is
    re-written at interpreter exit if it changed. When the module-level
    ``TOKEN`` is non-empty it is used as the bearer token instead.
    """

    _ENDPOINT = 'https://httpqas26-frontend-qasazap-prod-dsm02p.qas.binginternal.com/completions'
    _SCOPES = ['api://68df66a4-cad9-4bfd-872b-c6ddde00d6b2/access']
    # Token-cache file, resolved relative to the current working directory.
    _CACHE_PATH = '.llmapi.bin'
    # Substring the service uses in its error message when the prompt is too long.
    _CONTEXT_LENGTH_MARKER = "This model's maximum context length is"

    def __init__(self):
        self._cache = SerializableTokenCache()
        # Persist the cache at exit so later runs can reuse tokens silently.
        atexit.register(self._persist_cache)

        self._app = PublicClientApplication('68df66a4-cad9-4bfd-872b-c6ddde00d6b2',
                                            authority='https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47',
                                            token_cache=self._cache)
        if os.path.exists(self._CACHE_PATH):
            with open(self._CACHE_PATH, 'r') as cache_file:
                self._cache.deserialize(cache_file.read())

    def _persist_cache(self):
        """Write the token cache to disk if its state changed."""
        if self._cache.has_state_changed:
            with open(self._CACHE_PATH, 'w') as cache_file:
                cache_file.write(self._cache.serialize())

    def _resolve_token(self):
        """Return the module-level ``TOKEN`` override if set, else an MSAL token."""
        if TOKEN:
            return TOKEN
        return self._get_token()

    @staticmethod
    def _build_headers(model_name, token):
        """Build request headers; the model name travels in ``X-ModelType``."""
        return {
            'Content-Type': 'application/json',
            'Authorization': 'Bearer ' + token,
            'X-ModelType': model_name,
        }

    @staticmethod
    def _iter_sse_events(lines):
        """Yield JSON payloads from raw server-sent-event lines (bytes).

        Only lines starting with ``data: `` are decoded; iteration stops at
        ``data: [DONE]``. Any other line (blank keep-alives, comments,
        ``event:`` fields) is skipped.
        """
        prefix = 'data: '
        for raw in lines:
            text = raw.decode('utf-8')
            if not text.startswith(prefix):
                continue
            payload = text[len(prefix):]
            if payload == '[DONE]':
                return
            yield json.loads(payload)

    def send_request(self, model_name, request):
        """POST ``request`` as JSON and return the decoded JSON response.

        Args:
            model_name: Value sent in the ``X-ModelType`` header.
            request: JSON-serializable request body.

        Returns:
            The parsed JSON body. HTTP error statuses are not raised here; the
            service's error payload is returned as parsed.

        Raises:
            requests.RequestException: on transport failures, or if the
                response body is not valid JSON.
        """
        headers = self._build_headers(model_name, self._resolve_token())
        body = json.dumps(request).encode('utf-8')
        response = requests.post(LLMClient._ENDPOINT, data=body, headers=headers)
        return response.json()

    def send_stream_request(self, model_name, request):
        """POST ``request`` with streaming enabled and yield each SSE event.

        The token is acquired lazily, on the first ``next()``. The HTTP response
        is closed when the stream ends or when the caller closes the generator
        early.
        """
        headers = self._build_headers(model_name, self._resolve_token())
        body = json.dumps(request).encode('utf-8')
        with requests.post(LLMClient._ENDPOINT, data=body, headers=headers,
                           stream=True) as response:
            yield from self._iter_sse_events(response.iter_lines())

    def _get_token(self):
        """Return an access token, from the cache if possible, else via device-code flow.

        Raises:
            ValueError: if the device flow cannot be started or does not
                produce an access token.
        """
        accounts = self._app.get_accounts()
        result = None

        if accounts:
            # No account picker: use the first cached account.
            result = self._app.acquire_token_silent(LLMClient._SCOPES, account=accounts[0])

        if not result:
            # No usable cached token; fall back to the interactive device-code flow.
            flow = self._app.initiate_device_flow(scopes=LLMClient._SCOPES)

            if "user_code" not in flow:
                raise ValueError(
                    "Fail to create device flow. Err: %s" % json.dumps(flow, indent=4))

            # Show MSAL's sign-in instructions to the user.
            print(flow["message"])

            result = self._app.acquire_token_by_device_flow(flow)

        if "access_token" not in result:
            # On failure MSAL returns an error dict instead of a token.
            raise ValueError(
                "Failed to acquire token. Err: %s" % json.dumps(result, indent=4))
        return result["access_token"]

    @staticmethod
    def _context_length_error(response):
        """Return the error message if ``response`` reports a context-length overflow, else None."""
        error = response.get('error') if isinstance(response, dict) else None
        message = error.get('message') if isinstance(error, dict) else None
        if isinstance(message, str) and LLMClient._CONTEXT_LENGTH_MARKER in message:
            return message
        return None

    def run_llm(self, prompt, temperature=0.0, max_tokens=3200, top_p=1, frequency_penalty=0.0,
                presence_penalty=0.0, stop=None, logprobs=None, n=1,
                model_name='dev-gpt-4-turbo', max_retries=None):
        """Send a non-streaming completion request, retrying on failures.

        Args:
            prompt: Prompt text.
            temperature, max_tokens, top_p, frequency_penalty, presence_penalty,
            stop, logprobs, n: Forwarded unchanged in the request body.
            model_name: Model selector passed to ``send_request``. Other values
                used in this file include 'dev-moonshot' and 'dev-ppo'.
            max_retries: Retries allowed after the first failed attempt.
                ``None`` (the default) retries forever.

        Returns:
            The JSON response, or ``None`` if the service reports that the
            prompt exceeds the model's context length. Other error payloads
            are returned unchanged.

        Raises:
            Exception: the last error from ``send_request`` once
                ``max_retries`` is exhausted. KeyboardInterrupt/SystemExit are
                never swallowed.
        """
        request_data = {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "n": n,
            "stream": False,
            "logprobs": logprobs,
            "stop": stop,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty
        }
        attempt = 0
        while True:
            try:
                response = self.send_request(model_name, request_data)
            except Exception as err:  # deliberately broad: any request failure is retried
                if max_retries is not None and attempt >= max_retries:
                    raise
                print(f"retrying with {attempt} times... ({err!r})")
                attempt += 1
                time.sleep(1)
                continue

            message = self._context_length_error(response)
            if message is not None:
                print(message)
                return None
            # Fixed pause after every successful call. Presumably this throttles
            # requests to the endpoint; confirm before removing.
            time.sleep(10)
            return response


def _answer_path(data_path):
    """Return the answers file path: ``<stem>_ans<ext>``, next to ``data_path``.

    Only the file extension is considered, so directory names are never
    altered. A path without an extension gets ``.jsonl``. The result never
    equals ``data_path``, so answers cannot be appended to the input.
    """
    root, ext = os.path.splitext(data_path)
    return root + "_ans" + (ext or ".jsonl")


def _build_eval_request(prompt):
    """Build the greedy, single-line completion request used by ``eval``."""
    return {
        "prompt": prompt + "\r\n",
        "max_tokens": 95,
        "temperature": 0.0,
        "top_p": 1,
        "n": 1,
        "stream": False,
        "logprobs": None,
        # Stop at the first newline: one answer line per prompt.
        "stop": "\n"
    }


def eval(data_path='../data/structureWeb_infer_v1.jsonl', model_name='dev-moonshot', delay=15):
    """Run every prompt in a JSONL file through the model and append the answers.

    Each non-blank input line must be a JSON object with ``id`` and ``prompt``
    keys. Answers are appended to the file returned by ``_answer_path`` (for
    ``x.jsonl`` that is ``x_ans.jsonl``), one ``{"id", "completion"}`` JSON
    object per line.

    Note: this name shadows the builtin ``eval`` inside this module; it is kept
    for backward compatibility.

    Args:
        data_path: Input JSONL path. It is read as ``utf-8-sig`` so a leading
            BOM is tolerated.
        model_name: Model selector passed to ``LLMClient.send_request``.
        delay: Seconds to sleep after each request.
    """
    ans_path = _answer_path(data_path)
    # Parse the whole input up front so a malformed line fails before any
    # slow, throttled requests are made.
    with open(data_path, 'r', encoding="utf-8-sig") as f:
        records = [json.loads(line) for line in f if line.strip()]

    llm_client = LLMClient()
    # Append mode: re-runs add to earlier answers instead of overwriting them.
    with open(ans_path, 'a', encoding='utf-8') as writer:
        for cnt, prompt_data in enumerate(records):
            print(f"processing file {cnt}...", end="")
            response = llm_client.send_request(model_name, _build_eval_request(prompt_data["prompt"]))
            ans = response['choices'][0]['text']
            writer.write(json.dumps({"id": prompt_data["id"], "completion": ans}) + "\n")
            # Flush per record so finished answers survive a crash mid-run.
            writer.flush()
            print("sleeping...")
            time.sleep(delay)


if __name__ == "__main__":
    # Batch mode (disabled): evaluate a whole JSONL file instead of the single
    # demo prompt below. Re-enabling it also needs `import argparse`.
    # parser = argparse.ArgumentParser()
    # parser.add_argument('--data_path', type=str, default='1_shot_metadata_sample200_2.jsonl')
    # args = parser.parse_args()
    # eval(args.data_path)
    # Demo: send one few-shot code-review/repair prompt (a TensorFlow problem,
    # a candidate solution, and its execution output) through run_llm.
    prompt = "Help me rewrite the code. I will provide the PROBLEM description, the code for this PROBLEM, and the execution result of this code. Help me rewrite it into the correct code to solve this PROBLEM.\nProblem:\nI'm using tensorflow 2.10.0.\nI would like to generate 10 random integers as a tensor in TensorFlow but I don't which command I should use. In particular, I would like to generate from a uniform random variable which takes values in {1, 2, 3, 4}. I have tried to look among the distributions included in tensorflow_probability but I didn't find it.\nPlease set the random seed to 10 with tf.random.ser_seed().\nThanks in advance for your help.\n\nA:\n<code>\nimport tensorflow as tf\n\ndef f(seed_x=10):\n    # return the solution in this function\n    # result = f(seed_x)\n    ### BEGIN SOLUTION\n--------------------\nHere is a code snippet that may contain errors in solving the above PROBLEM:\n\n    # set the random seed\n    tf.random.set_seed(seed_x)\n    # generate 10 random integers from 1 to 4\n    result = tf.random.uniform(shape=(10,), minval=1, maxval=5, dtype=tf.int32)\n    ### END SOLUTION\n    return result\n\nresult = f(seed_x=10)\nprint(result)\n\n--------------------\nThis is the code that GPT4 generated for me, here are the inputs as well as the execution results. You need to determine if the code is correct and suggest changes if it is not.\n\nThe input is:\n10\n\nUpon executing the above code, the following results were obtained:\ntf.Tensor([3 2 3 3 3 1 1 4 2 2], shape=(10,), dtype=int32)\n\n\n\n\n\n-------------------\nThere are some rules that you must follow for rewriting the code:\n+ Is the code execution result the right answer to the PROBLEM?? If not, please rewrite the code, if yes, please do not return any code.\n+ If you need to rewrite the code:\n    + You need to first explain why the original code is incorrect.\n    + You only need to output codes that can complete the code snippet. 
You do not need to output the codes before and after the [insert] block or BEGIN SOLUTION block.\n    + You should answer only one code snippet, not more than one!\n+ If you do not need to rewrite the code, do not return any code, return `<original code>`.\n\n-------------------\nHere are some examples, you need to learn the format of completing the code from these examples.\n<example>\nProblem:\nHow can I delete the last row in a numpy array?\n\nA:\n<code>\nimport numpy as np\n\narray = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n\ndef f(a=array):\n    # return the solution in this function    \n    # result = f(a)                           \n    ### BEGIN SOLUTION\n\n<code>\n### BEGIN SOLUTION\n    # Delete the last row\n    result = np.delete(a, -1, 0)\n    return result\n### END SOLUTION\n\n</example>\n\n\n<example>\nProblem:\nHow can I delete the last row in a numpy array?\n\nA:\n<code>\nimport numpy as np\n\narray = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n\n</code>\nresult = ... # put solution in this variable\nBEGIN SOLUTION\n<code>\n\n<code>\n### BEGIN SOLUTION\n# Delete the last row\nresult = np.delete(a, -1, 0)\n### END SOLUTION\n\n</example>                                \n-------------------\nAgain, the PROBLEM is as follows:\n\nProblem:\nI'm using tensorflow 2.10.0.\nI would like to generate 10 random integers as a tensor in TensorFlow but I don't which command I should use. In particular, I would like to generate from a uniform random variable which takes values in {1, 2, 3, 4}. I have tried to look among the distributions included in tensorflow_probability but I didn't find it.\nPlease set the random seed to 10 with tf.random.ser_seed().\nThanks in advance for your help.\n\nA:\n<code>\nimport tensorflow as tf\n\ndef f(seed_x=10):\n    # return the solution in this function\n    # result = f(seed_x)\n    ### BEGIN SOLUTION\n\n<code>\n### BEGIN SOLUTION"
    print(prompt)
    llm_client = LLMClient()
    # run_llm returns None when the service reports the prompt exceeds the
    # model's context length; otherwise it returns the parsed JSON response.
    response = llm_client.run_llm(prompt)
    print(response)
