import json
import os
import re
import string

import tiktoken
from openai import AzureOpenAI

from qa_utils import cal_token, excel_address_to_coords, col_letter_to_num, parse_excel_row


# Azure OpenAI client shared by every model call in this file.
#
# The endpoint and key can be supplied through the AZURE_OPENAI_ENDPOINT and
# AZURE_OPENAI_API_KEY environment variables.  When a variable is unset the
# value that used to be hard-coded here is used, so existing runs are unchanged.
# NOTE(review): the fallback key is a credential committed to source control;
# rotate it and drop the fallback once the environment variable is in place.
client_phase2 = AzureOpenAI(
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", ""),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY", "7a10f29c4ba74b6a96383ee4b937808a"),
    api_version="2024-02-15-preview",
)

def convert_range(range_str):
    """Expand a comma-separated list of row/column tokens into single items.

    Each comma-separated token is either a single item or a ``start:end``
    range.  A range whose two ends are both digits expands to the decimal
    strings ``start..end`` inclusive; a range whose two ends are both
    alphabetic expands to the Excel column names between them (``Y:AB`` ->
    ``Y, Z, AA, AB``).  A token without ``:`` is stripped and upper-cased.

    Args:
        range_str: Text such as ``"1:3"``, ``"A:C"`` or ``"1:2, 5, b"``.

    Returns:
        List of strings in input order.

    Raises:
        ValueError: if the two ends of a range are not both digits or both
            letters, or if a token does not split into exactly two parts
            on ``:``.
    """
    def letters_to_index(letters):
        # Bijective base-26: A=1 ... Z=26, AA=27, ...
        return sum(
            (ord(ch.upper()) - ord('A') + 1) * 26 ** power
            for power, ch in enumerate(reversed(letters))
        )

    def index_to_letters(index):
        # Inverse of letters_to_index; digits are produced least-significant first.
        digits = []
        while index > 0:
            index, offset = divmod(index - 1, 26)
            digits.append(chr(65 + offset))
        return "".join(reversed(digits))

    def expand(first, last):
        if first.isdigit() and last.isdigit():
            return [str(n) for n in range(int(first), int(last) + 1)]
        if first.isalpha() and last.isalpha():
            lo, hi = letters_to_index(first), letters_to_index(last)
            return [index_to_letters(n) for n in range(lo, hi + 1)]
        raise ValueError("Invalid range format")

    expanded = []
    for token in range_str.split(','):
        if ':' not in token:
            expanded.append(token.strip().upper())
            continue
        first, last = token.split(':')
        expanded.extend(expand(first.strip(), last.strip()))
    return expanded

def get_new_qa_prompt(new_rows, question):
    """Build the single-shot QA prompt for a table and a question.

    The prompt is the fixed instruction text, then the question, then the
    table rows written one per line, then the ``##`` terminator.

    Args:
        new_rows: Iterable of serialized table rows (strings).
        question: The question text.

    Returns:
        The complete prompt string.
    """
    # The instruction literal is a raw string continued with a backslash at the
    # end of its first line, so the backslash, the newline and the indentation
    # of the second line are all part of the text that is sent.
    instruction = r"Instruction: Given an input that is a string denoting data of cells in a table and a question about this table. The answer to the question can be found in the table. The input table includes many pairs, and each pair consists of a cell address and the text in that cell with a ',' in between, like 'A1,Year'. Cells are separated by '|' like 'A1,Year|A2,Profit'. The text can be empty so the cell data is like 'A1, |A2,Profit'. The cells are organized in row-major order. The answer to the input question is contained in the input table and can be represented by cell address. \
        I need you to find the cell address of the answer in the given table based on the given question description, and return the cell ADDRESS of the answer like '{[B3]}' or '{[SUM(A2:A10)]}'. DON'T ADD ANY OTHER WORDS."

    table_text = "".join([row + "\n" for row in new_rows])
    pieces = [instruction, "\nQUESTION: ", question, "\nInput: ", table_text, "\n\n##\n\n"]
    return "".join(pieces)
    
def phase2_infer(new_prompt):
    """Send ``new_prompt`` to the chat deployment and return the reply text.

    The request uses a fixed system message, temperature 0 and at most 300
    completion tokens.

    The reply is read from ``completion.choices[0].message.content``.  The
    previous implementation sliced it out of ``str(completion)`` between
    ``content=`` and ``, role='assistant'``, which depends on the field order
    of the response object's repr and yields repr-formatted text rather than
    the reply itself.

    Args:
        new_prompt: Full user prompt.

    Returns:
        The assistant's reply as a string; an empty string if the response
        carries no content.
    """
    messages = [
        {"role": "system", "content": "You are an AI assistant that helps people find information."},
        {"role": "user", "content": new_prompt},
    ]
    completion = client_phase2.chat.completions.create(
        model="va_nfs_fmt0_4k-ft-gpt4-v4",
        messages=messages,
        temperature=0,
        max_tokens=300,
        top_p=0.99,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None,
    )

    content = completion.choices[0].message.content
    # Callers run regexes and string comparisons on the result, so never hand
    # them None.
    return content if content is not None else ""

def phase2(question, table_rows):
    """Answer ``question`` over ``table_rows`` with one model call.

    Builds the QA prompt with ``get_new_qa_prompt`` and returns whatever
    ``phase2_infer`` returns for it, unparsed.
    """
    return phase2_infer(get_new_qa_prompt(table_rows, question))

# Alias: header-row detection (get_header_rows) calls the same inference
# function as the QA prompts.
gpt_infer_header = phase2_infer

def get_header_rows(table_rows):
    """Ask the model which of ``table_rows`` are header rows.

    The rows are written one per line under a fixed instruction and sent
    through ``gpt_infer_header``.  Every integer in the reply is read as a
    spreadsheet row number and converted to a 0-based position in
    ``table_rows`` by subtracting the row number of the first cell address
    found in ``table_rows[0]``.

    Args:
        table_rows: Serialized rows, each a '|'-separated run of
            'ADDRESS,text' cells.

    Returns:
        0-based indices into ``table_rows`` in reply order.  Duplicates and
        positions that fall outside ``table_rows`` are dropped, so a caller
        can index with them safely.  Empty input gives an empty list without
        calling the model.

    Raises:
        ValueError: if no cell address can be found in the first row.
    """
    if not table_rows:
        return []

    instruction = r"Instruction: Given an input that is a string denoting data of cells in a table or part of table. The input includes many pairs, and each pair consists of a cell address and the text in that cell with a ',' in between, like 'A1,Year'. Cells are separated by '|' like 'A1,Year|A2,Profit'. The text can be empty so the cell data is like 'A1, |A2,Profit'. The cells are organized in row-major order. I need you to find which ROWS in the input table are header rows and return the answer like '{[1,2,3]}'. For example, if the first and second rows of the input table are header rows, you should return '{[1,2]}'. DON'T ADD ANY OTHER WORDS."
    table_text = "".join(table_row + "\n" for table_row in table_rows)
    prompt = instruction + "\nInput: " + table_text + "\n\n##\n\n"
    ans = gpt_infer_header(prompt)
    print(ans)
    row_numbers = [int(num) for num in re.findall(r'\d+', ans)]

    # Row number of the first row: take the first cell whose leading field
    # looks like an address.  Skipping non-address fragments copes with rows
    # that start with '|' (empty first fragment).
    first_row_number = None
    for cell in table_rows[0].split("|"):
        found = re.match(r"\s*[A-Z]+(\d+)\s*$", cell.split(",", 1)[0])
        if found:
            first_row_number = int(found.group(1))
            break
    if first_row_number is None:
        raise ValueError("no cell address found in first table row: %r" % (table_rows[0],))

    index_list = []
    for row_number in row_numbers:
        position = row_number - first_row_number
        # The model only saw these rows; anything else is a bad answer, and a
        # negative value would silently index from the end of the list.
        if 0 <= position < len(table_rows) and position not in index_list:
            index_list.append(position)
    return index_list

def chunk_infer(rows, question):
    """Ask the model whether ``question`` can be answered from ``rows``.

    The prompt is the fixed instruction, the rows written one per line, the
    ``##`` terminator and finally the question.  The instruction tells the
    model to reply ``{[-1]}`` when the answer is not in the given rows.

    Args:
        rows: Serialized table rows for this chunk.
        question: The question text.

    Returns:
        The value returned by ``phase2_infer`` for the prompt, unparsed.
    """
    instruction = r"Instruction: Given an input that is a string denoting data of cells in a table or part of table and a question about this table. The input table includes many pairs, and each pair consists of a cell address and the text in that cell with a ',' in between, like 'A1,Year'. Cells are separated by '|' like 'A1,Year|A2,Profit'. The text can be empty so the cell data is like 'A1, |A2,Profit'. The cells are organized in row-major order. The answer to the input question is contained in the input table and can be represented by cell address. First, I need you to judge whether the answer to the question can be found in the input. Second, if it can be found, you should find the cell address of the answer in the given table based on the given problem description, and return the answer like '{[B3]}' or '{[SUM(A2:A10)]}'; if not found, return '{[-1]}'. DON'T ADD ANY OTHER WORDS."
    table_text = "".join([row + "\n" for row in rows])
    pieces = [instruction, "\nInput: ", table_text, "\n\n##\n\n", "\nQUESTION: ", question]
    return phase2_infer("".join(pieces))

def split(question, qa_table_rows, chunk_size=3, header_probe_rows=6):
    """Answer ``question`` chunk by chunk when the table is too large.

    Header rows are detected among the first ``header_probe_rows`` rows with
    ``get_header_rows``.  The remaining rows are cut into groups of
    ``chunk_size``; each group is sent to ``chunk_infer`` together with the
    header rows, and every reply that is not the "not found" sentinel is
    collected.

    Args:
        question: The question text.
        qa_table_rows: All serialized table rows.
        chunk_size: Number of non-header rows per model call (default 3).
        header_probe_rows: How many leading rows are shown to the header
            detector (default 6).

    Returns:
        List of raw model replies, one per chunk that produced an answer;
        empty when ``qa_table_rows`` is empty.

    Raises:
        ValueError: if ``chunk_size`` is 0 (raised by ``range``).
    """
    if not qa_table_rows:
        return []

    # Slicing already copes with tables shorter than the probe size.
    detected = get_header_rows(qa_table_rows[:header_probe_rows])
    # Indices come from a model reply: keep only those that address a real row.
    header_positions = [i for i in detected if 0 <= i < len(qa_table_rows)]
    header_lookup = set(header_positions)

    header_rows = [qa_table_rows[i] for i in header_positions]
    remaining_rows = [
        row for i, row in enumerate(qa_table_rows) if i not in header_lookup
    ]

    res = []
    for start in range(0, len(remaining_rows), chunk_size):
        qa_chunk = header_rows + remaining_rows[start:start + chunk_size]
        ans = chunk_infer(qa_chunk, question)
        # Search for the sentinel instead of testing exact equality: the reply
        # may carry surrounding quotes or whitespace, in which case an exact
        # comparison with "{[-1]}" never matches and no chunk is filtered out.
        if re.search(r"\{\[\s*-1\s*\]\}", ans):
            continue
        res.append(ans)

    return res

def find(file_name):
    """Look up the pre-built table text for ``file_name``.

    Scans ``QAcode/qadata_method2+3.jsonl`` for the first record whose
    "file_name" matches and returns the part of its second message that
    follows the marker sentence "DON'T ADD OTHER WORDS OR EXPLANATION.".

    Args:
        file_name: Spreadsheet file name to look up.

    Returns:
        The table text, or ``-1`` when the record's "now_length" is not
        "<4k" or when no record for ``file_name`` exists.  (Previously a
        missing record fell off the end and returned ``None``, which the
        caller's ``== -1`` check did not catch.)

    Raises:
        IndexError: if the marker sentence is absent from the message.
    """
    marker = "DON'T ADD OTHER WORDS OR EXPLANATION."
    # Forward slash: accepted on Windows and POSIX alike.  The old
    # "QAcode\q..." spelling relied on the invalid escape sequence "\q".
    with open("QAcode/qadata_method2+3.jsonl", 'r', encoding='utf-8') as file:
        for line in file:
            data = json.loads(line)
            if data["file_name"] != file_name:
                continue
            if data["now_length"] != "<4k":
                return -1
            return data["messages"][1]["content"].split(marker)[1]
    return -1


def coordinate_mapping(address, file_name):
    """Translate ``address`` through the mapping recorded for ``file_name``.

    Reads ``QAcode/qadata_coordinate_mapping.jsonl`` and uses the
    "reflection" dict of the record whose "file_name" matches.  If several
    records match, the last one in the file is used.

    Args:
        address: Key to look up in the record's "reflection" dict.
        file_name: Spreadsheet file name selecting the record.

    Returns:
        ``reflection[address]`` for the selected record.

    Raises:
        KeyError: if no record exists for ``file_name`` (previously an
            UnboundLocalError), or if ``address`` is not in the mapping.
    """
    not_found = object()
    position_dic = not_found
    # Forward slash: accepted on Windows and POSIX alike.  The old
    # "QAcode\q..." spelling relied on the invalid escape sequence "\q".
    with open("QAcode/qadata_coordinate_mapping.jsonl", 'r', encoding='utf-8') as file:
        for line in file:
            data = json.loads(line)
            if data["file_name"] == file_name:
                position_dic = data["reflection"]

    if position_dic is not_found:
        raise KeyError("no coordinate mapping recorded for file %r" % (file_name,))
    return position_dic[address]

def _select_cells(parsed_rows, row_ids, col_ids):
    """Return, per row, the '|'-joined cells whose address is in both sets."""
    address_re = re.compile(r"\s*([A-Za-z]+)(\d+)\s*$")
    selected = []
    for cells in parsed_rows:
        kept = ""
        for cell in cells:
            # A cell is 'ADDRESS,text'; split the address into column letters
            # and row number so that 'AA12' is column 'AA', row '12'.
            parts = address_re.match(cell.split(",", 1)[0])
            if parts is None:
                continue
            if parts.group(2) in row_ids and parts.group(1).upper() in col_ids:
                kept += "|" + cell
        if kept:
            selected.append(kept + "\n")
    return selected


def get_related_rows(question, qa_table_rows, file_name):
    """Reduce a large table to the cells the model says are relevant.

    The table text stored for ``file_name`` (see ``find``) is sent to the
    model together with ``question``; the reply is expected to contain two
    ``{[...]}`` groups, rows first and columns second, each a comma list
    and/or ``start:end`` ranges.  Cells of ``qa_table_rows`` whose row number
    and column letters both appear in those groups are kept.

    Fixes over the previous version: the address is parsed as letters+digits
    instead of looking at its first two characters (which only worked for
    single-letter columns and single-digit rows); list items are stripped
    individually; rows of different lengths are handled; rows with no
    selected cell are omitted so that an empty selection triggers the ``-1``
    fallback.

    NOTE(review): the model is shown the stored table text while the filter is
    applied to the addresses in ``qa_table_rows``; whether those two use the
    same coordinates is not visible here — confirm.

    Args:
        question: The question text.
        qa_table_rows: Serialized rows, parsed with ``parse_excel_row``.
        file_name: Spreadsheet file name used to look up the stored table.

    Returns:
        List of strings, one per row with at least one selected cell, each
        ``"|cell|cell...\\n"``; or ``-1`` when no stored table is usable, the
        prompt exceeds 4096 tokens, the reply cannot be interpreted, or
        nothing is selected.
    """
    table_input = find(file_name)
    if table_input is None or table_input == -1:
        return -1

    # Raw string: the "\n" before "I will give you" is a literal backslash
    # followed by "n".  The text is kept exactly as it was.
    instruction = r"Instruction: Given an input that is a string denoting data of cells in a Excel spreadsheet. The input spreadsheet contains many tuples, describing the cells with content in the spreadsheet. Each tuple consists of two elements separated by a '|': the cell content and the cell address/region, like (Year|A1), ( |A1) or (IntNum|A1:B3). The content in some cells such as '#,##0'/'d-mmm-yy'/'H:mm:ss',etc., represents the CELL DATA FORMATS of Excel. The content in some cells such as 'IntNum'/'DateData'/'EmailData',etc., represents a category of data with the same format and similar semantics. For example, 'IntNum' represents integer type data, and 'ScientificNum' represents scientific notation type data. 'A1:B3' represents a region in spreadsheet, from the first row to the third row and from column A to column B. Some cells with empty content in the spreadsheet are not entered.\nI will give you a question, please find the rows and columns in the table related to the answer to the question. Use column name and row number to represent the relevant rows and columns, like'{[1,2,3]} AND {[A,B,C]}' or '{[1:3]} AND {[A:C]}'. DON'T ADD ANY OTHER WORDS."
    new_prompt = instruction + "\nQUESTION: " + question + table_input
    if len(tiktoken.get_encoding("cl100k_base").encode(new_prompt)) > 4096:
        return -1

    ans = phase2_infer(new_prompt)
    matches = re.findall(r'\{\[(.*?)\]\}', ans)
    print(matches)
    if len(matches) != 2:
        return -1

    # convert_range handles plain comma lists as well as ranges and strips
    # each item, so "1, 2, 3" and "1:3" both work.  A malformed range in the
    # reply is treated like any other unusable answer.
    try:
        row_ids = {item for item in convert_range(matches[0]) if item}
        col_ids = {item for item in convert_range(matches[1]) if item}
    except ValueError:
        return -1
    if not row_ids or not col_ids:
        return -1

    parsed_rows = [parse_excel_row(row) for row in qa_table_rows]
    res_row = _select_cells(parsed_rows, row_ids, col_ids)
    if not res_row:
        return -1
    return res_row
    
def main(start_index=146):
    """Run the QA pipeline over the intermediate file and write predictions.

    Every line of ``QAcode/intermediate_data_method1+2.jsonl`` is a JSON
    record with "file_name", "question", "answer" and "qa_table_rows".  For
    each processed record:

    * if the table fits in 4096 tokens (``cal_token``), it is answered
      directly with ``phase2``;
    * otherwise ``get_related_rows`` tries to shrink it; if that fails or the
      result is still too large, ``split`` answers it chunk by chunk.

    One JSON line per record is written to
    ``QAcode/output/qa_output_method_0605_1.jsonl``.  The output file is
    opened in 'w' mode, so its previous content is replaced.

    Args:
        start_index: 1-based line number of the first record to process;
            earlier lines are skipped without being parsed.  Defaults to 146,
            the value that was previously hard-coded.
    """
    # Forward slashes work on Windows and POSIX; the old "QAcode\i..."
    # spelling of the input path relied on the invalid escape sequence "\i".
    input_path = "QAcode/intermediate_data_method1+2.jsonl"
    output_path = "QAcode/output/qa_output_method_0605_1.jsonl"
    with open(input_path, 'r', encoding='utf-8') as input_file, \
            open(output_path, 'w', encoding='utf-8') as output_file:
        for count, line in enumerate(input_file, start=1):
            print(count)
            if count < start_index:
                continue

            data = json.loads(line)
            file_name = data["file_name"]
            question = data["question"]
            print("question:", question)
            answer = data["answer"]
            print("groundtruth:", answer)
            qa_table_rows = data["qa_table_rows"]

            if cal_token(qa_table_rows) > 4096:
                related_qa_rows = get_related_rows(question, qa_table_rows, file_name)
                if related_qa_rows == -1 or cal_token(related_qa_rows) > 4096:
                    # Returns a list of per-chunk replies rather than one string.
                    ans = split(question, qa_table_rows)
                else:
                    ans = phase2(question, related_qa_rows)
            else:
                ans = phase2(question, qa_table_rows)
            print(ans)

            item = {
                "question": question,
                "prediction": ans,
                # Key spelling left unchanged; readers of the output file may
                # depend on it — verify before renaming.
                "groudtruth": answer,
                "file_name": file_name,
            }
            output_file.write(json.dumps(item) + "\n")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()