import json
import time
import argparse
from request_model import request_model


def parse_arguments(argv=None):
    """
    Build the command-line interface and parse the arguments.

    Args:
        argv (list[str] | None): Argument strings to parse. When None (the
            default), argparse falls back to ``sys.argv[1:]``, i.e. the normal
            command-line invocation. Passing a list allows programmatic use
            and testing.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes
        ``model_name``, ``input_file``, ``output_file``, ``url`` and
        ``max_tokens`` (an int).
    """
    parser = argparse.ArgumentParser(description='Run model prediction on tasks.')
    parser.add_argument('--model_name', required=True, help='Model name to use for prediction')
    parser.add_argument('--input_file', required=True, help='Path to the input file')
    parser.add_argument('--output_file', required=True, help='Path to the output file')
    parser.add_argument('--url', default="http://192.168.80.2:8085/v1/completions", help='URL of the model server')
    parser.add_argument('--max_tokens', type=int, default=2048, help='Maximum number of tokens for the model response')
    return parser.parse_args(argv)


def _predict_with_retries(prompt, model, url, max_tokens, task_no, max_attempts, retry_delay):
    """
    Query the model server for a single prompt, retrying on failure.

    Args:
        prompt: Value passed as the first argument to ``request_model``.
        model (str): The model name to use for prediction.
        url (str): URL of the model server.
        max_tokens (int): Maximum number of tokens for the model response.
        task_no (int): 1-based task number, used only in log messages.
        max_attempts (int): Number of attempts; at least one is always made.
        retry_delay (float): Base delay in seconds between attempts. The
            delay grows linearly with the attempt number. 0 disables sleeping.

    Returns:
        Whatever ``request_model`` returned on the first successful attempt,
        or "" if every attempt raised.
    """
    attempts = max(1, max_attempts)
    for attempt in range(1, attempts + 1):
        try:
            return request_model(prompt, model=model, url=url, max_tokens=max_tokens)
        except Exception as e:
            # request_model's failure modes are not visible here, so every
            # exception is treated as retryable; report it and try again.
            print(f"Error on task {task_no}: {e}")
            if attempt < attempts and retry_delay > 0:
                # Wait before retrying instead of hammering a failing server.
                time.sleep(retry_delay * attempt)
    # Make the fallback visible: otherwise an empty prediction is
    # indistinguishable from a genuinely empty model response.
    print(f"Task {task_no}: all {attempts} attempts failed, writing empty prediction")
    return ""


def run_prediction(model, input_fn, output_fn, url, max_tokens=1024, max_attempts=3, retry_delay=1.0):
    """
    Run prediction for each task in the input file and write the output to the output file.

    The input is read line by line; each non-blank line is parsed as a JSON
    object. The prompt is taken from its "instruction" field, falling back to
    "prompt". The model response is stored under "model_predict". Each task
    is written as one JSON line to the output file, in input order.

    Args:
        model (str): The model name to use for prediction.
        input_fn (str): Path to the UTF-8 encoded JSON-lines input file.
        output_fn (str): Path to the output file (overwritten, UTF-8 encoded).
        url (str): URL of the model server.
        max_tokens (int): Maximum number of tokens for the model response.
        max_attempts (int): How many times to try each request before giving
            up and recording an empty prediction.
        retry_delay (float): Base delay in seconds between failed attempts.

    Raises:
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
    """
    # Explicit UTF-8: output is serialized with ensure_ascii=False, so
    # non-ASCII text would fail to encode under a non-UTF-8 locale default.
    with open(input_fn, "r", encoding="utf-8") as fin, open(output_fn, "w", encoding="utf-8") as fout:
        for idx, line in enumerate(fin):
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # crashing in json.loads.
                continue
            task = json.loads(line)
            start = time.perf_counter()

            # NOTE(review): if neither key is present, None is sent to
            # request_model; confirm how the server/client handles that.
            prompt = task.get("instruction") or task.get("prompt")
            response = _predict_with_retries(prompt, model, url, max_tokens, idx + 1, max_attempts, retry_delay)

            cost = time.perf_counter() - start
            print(f"Task {idx + 1}: response={response}, cost={cost:.2f}s")

            task["model_predict"] = response
            fout.write(json.dumps(task, ensure_ascii=False) + "\n")
            # Each task costs a network round trip, so flushing per line is
            # cheap and keeps completed predictions on disk if the job dies.
            fout.flush()


if __name__ == "__main__":
    # Script entry point: read CLI options and hand them to the batch runner.
    cli = parse_arguments()
    run_prediction(
        model=cli.model_name,
        input_fn=cli.input_file,
        output_fn=cli.output_file,
        url=cli.url,
        max_tokens=cli.max_tokens,
    )
