import json
import re
import openpyxl
from qa_utils import cal_token
from openai import AzureOpenAI


# Azure OpenAI client used by phase2_infer for every model call.
# NOTE(review): the endpoint and API key are empty strings here — presumably
# placeholders to be filled in before running; confirm how credentials are
# meant to be supplied.
client_phase2 = AzureOpenAI(
    azure_endpoint ="",
    api_key = "", 
    api_version = "2024-02-15-preview"
)


def get_new_qa_prompt(new_rows, question):
    """Assemble the single-shot QA prompt for a whole table.

    Args:
        new_rows: Iterable of row strings; each one is appended to the prompt
            followed by a newline.
        question: The question text, placed before the table data.

    Returns:
        The instruction, the question, the serialized rows and the
        ``##`` terminator concatenated into one prompt string.
    """
    # NOTE: this is a raw string, so the backslash at the end of the first
    # line and the newline + indentation that follow it are part of the
    # prompt text (it is not a line continuation). Kept byte-for-byte.
    instruction_text = r"Instruction: Given an input that is a string denoting data of cells in a table and a question about this table. The answer to the question can be found in the table. The input table includes many pairs, and each pair consists of a cell address and the text in that cell with a ',' in between, like 'A1,Year'. Cells are separated by '|' like 'A1,Year|A2,Profit'. The text can be empty so the cell data is like 'A1, |A2,Profit'. The cells are organized in row-major order. The answer to the input question is contained in the input table and can be represented by cell address. \
        I need you to find the cell address of the answer in the given table based on the given question description, and return the cell ADDRESS of the answer like '{[B3]}' or '{[SUM(A2:A10)]}'. DON'T ADD ANY OTHER WORDS."

    # One row per line, each terminated by a newline.
    table_text = "".join(row_text + "\n" for row_text in new_rows)

    pieces = (
        instruction_text,
        "\nQUESTION: ",
        question,
        "\nInput: ",
        table_text,
        "\n\n##\n\n",
    )
    return "".join(pieces)
    
def phase2_infer(new_prompt):
    """Send *new_prompt* to the fine-tuned chat model and return its reply text.

    Args:
        new_prompt: The full user prompt (instruction + table data + question).

    Returns:
        The assistant message content as a string. An empty string is
        returned when the response carries no choices or no text content.
    """
    message_text = [
        {"role": "system", "content": "You are an AI assistant that helps people find information."},
        {"role": "user", "content": new_prompt},
    ]
    completion = client_phase2.chat.completions.create(
        model="va_nfs_fmt0_4k-ft-gpt4-v4",
        messages=message_text,
        temperature=0,
        max_tokens=300,
        top_p=0.99,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None,
    )

    # Read the reply from the response object instead of slicing its repr:
    # the repr wraps the text in quote characters (and escapes newlines), and
    # its field layout is not a stable interface to parse against.
    if not completion.choices:
        return ""
    content = completion.choices[0].message.content
    return content if content is not None else ""

def phase2(question, table_rows):
    """Answer *question* over the whole table with a single model call.

    Builds the QA prompt from ``table_rows`` and ``question`` and returns the
    model reply unchanged (no post-processing of the answer string).
    """
    return phase2_infer(get_new_qa_prompt(table_rows, question))

# Header-row detection goes through the same model call as phase 2; this alias
# lets get_header_rows refer to it under its own name.
gpt_infer_header = phase2_infer

def _first_sheet_row_number(row_text):
    """Return the row number found in a cell address of *row_text*.

    The cell at position 1 of the '|'-separated row is probed first (the
    position this code has always used); the remaining cells are only tried
    as a fallback when that one is absent or is not an address.

    Raises:
        ValueError: if no cell in the row starts with an address like 'A1'.
    """
    address_pattern = r"([A-Z]+)(\d+)"
    cells = row_text.split("|")
    for cell in cells[1:] + cells[:1]:
        found = re.match(address_pattern, cell.split(",")[0])
        if found:
            return int(found.group(2))
    raise ValueError(f"no cell address found in table row: {row_text!r}")


def get_header_rows(table_rows):
    """Ask the model which rows of *table_rows* are header rows.

    The numbers in the model reply are treated as row numbers in the same
    numbering as the cell addresses, and are converted to 0-based positions
    in ``table_rows`` by subtracting the row number of the first row.

    Args:
        table_rows: List of row strings ('|'-separated 'ADDRESS,text' cells).

    Returns:
        List of 0-based indices into ``table_rows``. Numbers that fall
        outside the given rows are dropped, so callers can index
        ``table_rows`` with every returned value. Empty input yields ``[]``
        without calling the model.

    Raises:
        ValueError: if the first row contains no recognisable cell address.
    """
    if not table_rows:
        return []

    # Resolve the row offset before spending a model call on a malformed table.
    first_row_number = _first_sheet_row_number(table_rows[0])

    instruction = r"Instruction: Given an input that is a string denoting data of cells in a table or part of table. The input includes many pairs, and each pair consists of a cell address and the text in that cell with a ',' in between, like 'A1,Year'. Cells are separated by '|' like 'A1,Year|A2,Profit'. The text can be empty so the cell data is like 'A1, |A2,Profit'. The cells are organized in row-major order. I need you to find which ROWS in the input table are header rows and return the answer like '{[1,2,3]}'. For example, if the first and second rows of the input table are header rows, you should return '{[1,2]}'. DON'T ADD ANY OTHER WORDS."
    table_text = "".join(table_row + "\n" for table_row in table_rows)
    prompt = instruction + "\nInput: " + table_text + "\n\n##\n\n"

    ans = gpt_infer_header(prompt)
    print(ans)

    reported_rows = [int(num) for num in re.findall(r"\d+", ans)]

    index_list = []
    for row_number in reported_rows:
        index = row_number - first_row_number
        # A negative index would silently wrap around when used on a list and
        # a too-large one would raise; neither refers to a row we were given.
        if 0 <= index < len(table_rows):
            index_list.append(index)
    return index_list

def chunk_infer(rows, question):
    """Query the model about one chunk of the table.

    The prompt asks the model to return the answer's cell address, or
    '{[-1]}' when the answer is not in this chunk. Unlike the whole-table
    prompt, the question is placed after the '##' separator.

    Args:
        rows: Row strings forming the chunk (each gets a trailing newline).
        question: The question text.

    Returns:
        The model reply, unmodified.
    """
    instruction_text = r"Instruction: Given an input that is a string denoting data of cells in a table or part of table and a question about this table. The input table includes many pairs, and each pair consists of a cell address and the text in that cell with a ',' in between, like 'A1,Year'. Cells are separated by '|' like 'A1,Year|A2,Profit'. The text can be empty so the cell data is like 'A1, |A2,Profit'. The cells are organized in row-major order. The answer to the input question is contained in the input table and can be represented by cell address. First, I need you to judge whether the answer to the question can be found in the input. Second, if it can be found, you should find the cell address of the answer in the given table based on the given problem description, and return the answer like '{[B3]}' or '{[SUM(A2:A10)]}'; if not found, return '{[-1]}'. DON'T ADD ANY OTHER WORDS."
    chunk_text = "".join(line + "\n" for line in rows)
    prompt_parts = (
        instruction_text,
        "\nInput: ",
        chunk_text,
        "\n\n##\n\n",
        "\nQUESTION: ",
        question,
    )
    return phase2_infer("".join(prompt_parts))



def split(question, qa_table_rows, chunk_size=3, header_probe_rows=6):
    """Answer *question* on a table too large for one prompt, chunk by chunk.

    Header rows are detected from the first ``header_probe_rows`` rows, then
    the remaining rows are sent ``chunk_size`` at a time, each chunk prefixed
    with the header rows. Replies that signal "not found" are discarded.

    Args:
        question: The question text.
        qa_table_rows: List of row strings for the whole table.
        chunk_size: Number of non-header rows per model call (default 3).
        header_probe_rows: How many leading rows are shown to the header
            detector (default 6).

    Returns:
        List of model replies for the chunks in which an answer was reported,
        in table order. Empty when the table is empty or nothing was found.

    Raises:
        ValueError: if ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")
    if not qa_table_rows:
        return []

    # Slicing already copes with tables shorter than the probe window.
    detected = get_header_rows(qa_table_rows[:header_probe_rows])

    # Ignore indices that do not address a real row: a negative value would
    # wrap around to the end of the table and a too-large one would raise.
    header_index_set = {i for i in detected if 0 <= i < len(qa_table_rows)}
    header_rows = [qa_table_rows[i] for i in detected if i in header_index_set]
    remaining_rows = [
        row for i, row in enumerate(qa_table_rows) if i not in header_index_set
    ]

    res = []
    for start in range(0, len(remaining_rows), chunk_size):
        qa_chunk = header_rows + remaining_rows[start:start + chunk_size]
        ans = chunk_infer(qa_chunk, question)
        # The reply may carry surrounding quotes or whitespace, so look for
        # the "not found" marker inside it instead of requiring an exact match.
        if "{[-1]}" not in "".join(ans.split()):
            res.append(ans)

    return res

def main(
    input_path="QAcode/intermediate_data_method2+3_1.jsonl",
    output_path="QAcode/output/qa_output_method2+3_1.jsonl",
    token_limit=4096,
):
    """Run table QA over a JSONL file and write one prediction per record.

    Each input line is a JSON object with the keys ``file_name``,
    ``question``, ``answer`` and ``qa_table_rows``. Tables whose token count
    (per ``cal_token``) exceeds ``token_limit`` are answered chunk by chunk
    via ``split``; the others go through ``phase2`` in a single call.

    Args:
        input_path: JSONL file to read. Written with '/' so the literal has
            no backslash escape and matches the style of the output path.
        output_path: JSONL file to (over)write with the predictions.
        token_limit: Token count above which the table is split into chunks.
    """
    with open(input_path, "r", encoding="utf-8") as input_file, \
            open(output_path, "w", encoding="utf-8") as output_file:
        count = 0
        for line in input_file:
            # Tolerate blank lines instead of failing in json.loads.
            if not line.strip():
                continue
            data = json.loads(line)
            count += 1
            print(count)
            file_name = data["file_name"]
            question = data["question"]
            print("question:", question)
            answer = data["answer"]
            print("groundtruth:", answer)
            qa_table_rows = data["qa_table_rows"]

            if cal_token(qa_table_rows) > token_limit:
                ans = split(question, qa_table_rows)
            else:
                ans = phase2(question, qa_table_rows)
            print(ans)

            # NOTE(review): the "groudtruth" key spelling is kept as-is because
            # consumers of the output file may read it under this name.
            item = {
                "question": question,
                "prediction": ans,
                "groudtruth": answer,
                "file_name": file_name,
            }
            output_file.write(json.dumps(item) + "\n")


# Only run the batch job when executed as a script, not when imported.
if __name__ == "__main__":
    main()