import argparse
import json
import time
from request_model import request_model
from utils.verification_prompt_organize import generate_prompt


def parse_arguments(argv=None):
    """
    Parse command line arguments.

    Args:
        argv (list[str] | None): Argument strings to parse. When None
            (the default) argparse reads them from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Namespace with the attributes ``model_name``,
        ``input_file``, ``output_file``, ``re_retrieve_file``, ``url`` and
        ``max_tokens``.
    """
    parser = argparse.ArgumentParser(description='Run model prediction on tasks.')
    parser.add_argument('--model_name', required=True, help='Model name to use for prediction')
    parser.add_argument('--input_file', required=True, help='Path to the input file')
    parser.add_argument('--output_file', required=True, help='Path to the output file')
    parser.add_argument(
        '--re_retrieve_file',
        required=True,
        help='Path to the output file for tasks that need re-retrieval',
    )
    parser.add_argument('--url', default="http://192.168.80.2:8085/v1/completions", help='URL of the model server')
    parser.add_argument('--max_tokens', type=int, default=1024, help='Maximum number of tokens for the model response')
    return parser.parse_args(argv)


def save_re_retrieve_question(input_fn, output_fn):
    """
    Process each task in the input file and save questions that need re-retrieval to the output file.

    Each non-blank input line is decoded as a JSON object. Its
    ``"model_predict"`` field is decoded as JSON again; when key ``"3"`` of
    that result equals the string ``"false"``, the value under key ``"4"`` is
    stored on the task as ``"question_2rd"`` and the task is written out as
    one JSON line. Lines that cannot be decoded or lack the expected keys are
    reported on stdout and skipped.

    Args:
        input_fn (str): Path to the input file containing tasks (JSON lines, UTF-8).
        output_fn (str): Path to the output file to write tasks requiring re-retrieval.

    Returns:
        None
    """
    # Explicit UTF-8: tasks are written with ensure_ascii=False, so relying on
    # the platform default encoding can fail or garble text on some systems.
    with open(input_fn, "r", encoding="utf-8") as fin, \
            open(output_fn, "w", encoding="utf-8") as fout:
        for idx, line in enumerate(fin):
            if not line.strip():
                # Blank lines (e.g. a trailing newline) carry no task.
                continue

            try:
                task = json.loads(line)
                verification_out = json.loads(task["model_predict"])
                nli = verification_out["3"]
                revise = verification_out["4"]
            except (ValueError, KeyError, TypeError) as e:
                # ValueError covers malformed JSON; KeyError/TypeError cover
                # missing keys and payloads of an unexpected shape.
                print(f"Error processing task at line {idx + 1}: {e}")
                continue

            # NOTE: only the exact string "false" matches here; a JSON boolean
            # false in the model output does not.
            if nli == "false":
                task["question_2rd"] = revise
                fout.write(json.dumps(task, ensure_ascii=False) + "\n")


def run_prediction(model, input_fn, output_fn, re_retrieve_fn, url, max_tokens=1024):
    """
    Run prediction for each task in the input file and write the output to the output file.

    For every non-blank input line a prompt is built with ``generate_prompt``
    and sent through ``request_model``, trying up to three times. The response
    (an empty string if every attempt failed) is stored on the task as
    ``"model_predict"`` and the task is written as one JSON line. After all
    tasks are written, ``save_re_retrieve_question`` is run on the output file.

    Args:
        model (str): The model name to use for prediction.
        input_fn (str): Path to the input file (JSON lines, UTF-8).
        output_fn (str): Path to the output file.
        re_retrieve_fn (str): Path to the Re-retrieve question file.
        url (str): URL of the model server.
        max_tokens (int): Maximum number of tokens for the model response.
    """
    max_attempts = 3

    # Explicit UTF-8 keeps this file readable by save_re_retrieve_question
    # regardless of the platform default encoding.
    with open(input_fn, "r", encoding="utf-8") as fin, \
            open(output_fn, "w", encoding="utf-8") as fout:
        for idx, line in enumerate(fin):
            if not line.strip():
                # Blank lines carry no task; skipping avoids a JSON decode crash.
                continue

            task = json.loads(line)
            # perf_counter is monotonic, so the elapsed time cannot be skewed
            # by wall-clock adjustments.
            start = time.perf_counter()

            prompt = generate_prompt(task)

            response = ""
            for _ in range(max_attempts):
                try:
                    response = request_model(prompt, model=model, url=url, max_tokens=max_tokens)
                    break
                except Exception as e:
                    # Best-effort boundary: whatever request_model raises is
                    # reported and the request is retried.
                    print(f"Error on task {idx + 1}: {e}")
            else:
                # Every attempt raised; make the empty prediction explicit.
                print(f"Task {idx + 1}: all {max_attempts} attempts failed, recording empty prediction")

            elapsed = time.perf_counter() - start
            print(f"Task {idx + 1}: response={response}, cost={elapsed:.2f}s")

            task["model_predict"] = response
            fout.write(json.dumps(task, ensure_ascii=False) + "\n")

    # The output file is closed (and flushed) before it is read back.
    save_re_retrieve_question(output_fn, re_retrieve_fn)


if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the prediction pass.
    args = parse_arguments()
    # Attribute names must match the option names declared in
    # parse_arguments (--model_name, --input_file, --output_file);
    # argparse does not create shortened aliases on the namespace.
    run_prediction(
        args.model_name,
        args.input_file,
        args.output_file,
        args.re_retrieve_file,
        args.url,
        args.max_tokens,
    )

