import openpyxl
import json
import re
import tiktoken
import config
from utils import parse_excel_row
from utils import get_table_nfs_input
from utils import get_table_cell_input
from utils import col_num_to_letter
from utils import rows_to_new_input



class MyCustomException(Exception):
    """Raised by the compression pipeline when an internal consistency check fails."""

def find_boundary(file_name):
    """Look up the stored boundaries for ``file_name``.

    Scans the JSON-lines file at ``config.FIND_FILE_PATH`` and returns the
    ``"boundarys"`` value of the first record whose ``"file_name"`` equals
    ``file_name``. Returns an empty list when no record matches.
    """
    with open(config.FIND_FILE_PATH, 'r', encoding='utf-8') as find_file:
        # map() is lazy, so lines after the first match are never parsed.
        for record in map(json.loads, find_file):
            if record["file_name"] == file_name:
                return record["boundarys"]
    return []

def row_delete(rows, row_boundarys, delta=None):
    """Keep only the rows that lie within ``delta`` rows of a boundary.

    Args:
        rows: Serialized rows. The first row must look like
            ``"|<address>,<value>|..."`` where ``<address>`` is a cell
            address such as ``B7``; its row number is the sheet row of
            ``rows[0]``.
        row_boundarys: Sheet row numbers of the boundaries to keep.
        delta: Number of neighbouring rows to keep on each side of a
            boundary. Defaults to ``config.DELTA``.

    Returns:
        The kept rows, in their original order. Empty when ``rows`` is empty.

    Raises:
        ValueError: If the first cell address of ``rows[0]`` cannot be parsed.
    """
    if not rows:
        return []
    if delta is None:
        delta = config.DELTA

    head = rows[0].split("|")
    if len(head) < 2:
        raise ValueError(f"row has no '|'-separated cells: {rows[0]!r}")
    first_address = head[1].split(",")[0]
    matched = re.match(r"([A-Z]+)(\d+)", first_address)
    if matched is None:
        # Previously this path printed a message and then crashed on an
        # unbound local; fail with an explicit, catchable error instead.
        raise ValueError(f"cannot parse cell address: {first_address!r}")
    # Offset that converts a sheet row number into a 1-based position in rows.
    offset = int(matched.group(2)) - 1

    total = len(rows)
    keep = [False] * (total + 1)  # index 0 is unused; positions are 1-based
    for boundary in row_boundarys:
        # Boundaries past the last row are pulled back onto the last row.
        rb = min(boundary - offset, total)
        if rb >= 1:
            # Guarded so that a boundary above the first row cannot wrap
            # around through negative indexing and keep a row at the end.
            keep[rb] = True
        # Window [rb - delta, rb + delta], clipped to the valid positions.
        for position in range(max(1, rb - delta), min(total, rb + delta) + 1):
            keep[position] = True

    return [row for position, row in enumerate(rows, start=1) if keep[position]]

def col_delete(rows, col_boundarys, delta=None):
    """Keep only the columns that lie within ``delta`` columns of a boundary.

    Args:
        rows: Serialized rows. The first row must look like
            ``"|<address>,<value>|..."`` where ``<address>`` is a cell
            address such as ``B7``; its column letters give the sheet column
            of the first cell. Each row is split into cells with
            ``parse_excel_row``.
        col_boundarys: Sheet column numbers of the boundaries to keep.
        delta: Number of neighbouring columns to keep on each side of a
            boundary. Defaults to ``config.DELTA``.

    Returns:
        One string per input row, built by prefixing every kept cell with
        ``"|"``. Empty when ``rows`` is empty.

    Raises:
        ValueError: If the first cell address of ``rows[0]`` cannot be parsed.
    """
    if not rows:
        return []
    if delta is None:
        delta = config.DELTA

    head = rows[0].split("|")
    if len(head) < 2:
        raise ValueError(f"row has no '|'-separated cells: {rows[0]!r}")
    first_address = head[1].split(",")[0]
    matched = re.match(r"([A-Z]+)(\d+)", first_address)
    if matched is None:
        # Previously this path printed a message and then crashed on an
        # unbound local; fail with an explicit, catchable error instead.
        raise ValueError(f"cannot parse cell address: {first_address!r}")
    # Offset that converts a sheet column number into a 1-based cell position.
    offset = openpyxl.utils.column_index_from_string(matched.group(1)) - 1

    # The keep-mask is sized from the first row; rows are expected to have
    # no more cells than that.
    width = len(parse_excel_row(rows[0]))
    keep = [False] * (width + 1)  # index 0 is unused; positions are 1-based
    for boundary in col_boundarys:
        # Boundaries past the last column are pulled back onto the last column.
        cb = min(boundary - offset, width)
        if cb >= 1:
            # Guarded so that a boundary left of the first column cannot wrap
            # around through negative indexing and keep a column at the end.
            keep[cb] = True
        # Window [cb - delta, cb + delta], clipped to the valid positions.
        for position in range(max(1, cb - delta), min(width, cb + delta) + 1):
            keep[position] = True

    res = []
    for row in rows:
        cells = parse_excel_row(row)
        res.append(
            "".join(
                "|" + cell
                for position, cell in enumerate(cells, start=1)
                if keep[position]
            )
        )
    return res
        
def txt_compress(prompt, row_boundarys, col_boundarys):
    """Filter the cell rows extracted from ``prompt`` down to the boundary areas.

    The rows come from ``get_table_cell_input(prompt)``; they are filtered by
    ``row_delete`` first and by ``col_delete`` second.
    """
    kept_rows = row_delete(get_table_cell_input(prompt), row_boundarys)
    return col_delete(kept_rows, col_boundarys)

def nfs_compress(prompt, row_boundarys, col_boundarys):
    """Filter the NFS rows extracted from ``prompt`` down to the boundary areas.

    The rows come from ``get_table_nfs_input(prompt)``; they are filtered by
    ``row_delete`` first and by ``col_delete`` second.
    """
    kept_rows = row_delete(get_table_nfs_input(prompt), row_boundarys)
    return col_delete(kept_rows, col_boundarys)

def get_row_tag(rows):
    """Return one flag per row: 1 if every cell value in it is empty, else 0.

    A cell is ``"<address>,<value>"``; the value is whatever follows the
    first comma.
    """

    def is_blank(row):
        # NOTE(review): "blank" here means the empty string, whereas
        # get_col_tag compares against a single space -- confirm which is
        # intended.
        return all(
            cell.split(",", 1)[1] == "" for cell in parse_excel_row(row)
        )

    return [1 if is_blank(row) else 0 for row in rows]

def get_col_tag(rows):
    """Return one flag per column: 1 if every value in it is ``" "``, else 0.

    The number of columns is taken from the first row. A cell is
    ``"<address>,<value>"``; the value is whatever follows the first comma.
    """

    def blank_columns(matrix):
        width = len(matrix[0])
        flags = [1] * width
        # Row-major scan; every row is indexed up to ``width``, so a row
        # shorter than the first one raises IndexError.
        for values in matrix:
            for col in range(width):
                if values[col] != " ":
                    flags[col] = 0
        return flags

    value_matrix = [
        [cell.split(",", 1)[1] for cell in parse_excel_row(row)]
        for row in rows
    ]
    return blank_columns(value_matrix)

def find_consecutive_ones_intervals(nums):
    """Return ``(start, end)`` index pairs for every run of two or more 1s.

    Both indices are inclusive. Runs consisting of a single 1 are ignored.
    """
    size = len(nums)
    intervals = []
    run_start = None
    for idx, value in enumerate(nums):
        if value == 1:
            if run_start is None:
                run_start = idx
            continue
        # A non-1 value closes the current run, if there is one.
        if run_start is not None and idx - run_start >= 2:
            intervals.append((run_start, idx - 1))
        run_start = None
    # A run that reaches the end of the sequence is closed here.
    if run_start is not None and size - run_start >= 2:
        intervals.append((run_start, size - 1))
    return intervals

def coordinate_rearrangement(rows):
    """Renumber every cell address so the table starts at ``A1`` with no gaps.

    The cell at position ``j`` (1-based) of row ``i`` (1-based) gets the
    address ``col_num_to_letter(j) + str(i)``; its value is left unchanged.

    Returns:
        A tuple ``(new_rows, new_to_old, old_to_new)``: the re-serialized
        rows, a dict mapping each new address to the original one, and the
        reverse dict.
    """
    new_rows = []
    new_to_old = {}
    old_to_new = {}
    for row_no, row in enumerate(rows, start=1):
        pieces = []
        for col_no, cell in enumerate(parse_excel_row(row), start=1):
            parts = cell.split(',', 1)
            old_address = parts[0]
            value = parts[1]
            new_address = col_num_to_letter(col_no) + str(row_no)
            new_to_old[new_address] = old_address
            old_to_new[old_address] = new_address
            pieces.append("|" + new_address + "," + value)
        new_rows.append("".join(pieces))
    return new_rows, new_to_old, old_to_new

def groundtruth_coordinate_rearrangement(my_dic_r, labels):
    """Rewrite every ``"start:end"`` range in ``labels`` through ``my_dic_r``.

    Args:
        my_dic_r: Dict mapping an original cell address to its new address.
        labels: List of lists of ``"start:end"`` range strings.

    Returns:
        The same ``labels`` object; its inner lists are modified in place.

    Raises:
        KeyError: If an address of a range is missing from ``my_dic_r``.
    """
    for group in labels:
        for pos, cell_range in enumerate(group):
            ends = cell_range.split(":")
            group[pos] = ":".join((my_dic_r[ends[0]], my_dic_r[ends[1]]))
    return labels

def delete_space(txt_rows, nfs_rows):
    """Drop blank rows from both row lists, keeping the edges of inner gaps.

    Rows of ``txt_rows`` flagged blank by ``get_row_tag`` are removed. The
    exception is a run of two or more blank rows that touches neither the
    first nor the last row: its first and last row are kept. The same row
    mask is applied by index to ``nfs_rows``.

    Returns:
        A tuple ``(new_cell_rows, new_nfs_rows)``.
    """

    def mark_inner_runs(tags):
        # Re-tag both ends of every run that does not touch an edge with 2
        # so that the ``!= 1`` filter below keeps them.
        last = len(tags) - 1
        for start, end in find_consecutive_ones_intervals(tags):
            if start != 0 and end != last:
                tags[start] = 2
                tags[end] = 2
        return tags

    row_tag = mark_inner_runs(get_row_tag(txt_rows))

    # NOTE(review): the column tags are computed and marked the same way,
    # but nothing below uses them -- no column is ever removed here.
    mark_inner_runs(get_col_tag(txt_rows))

    new_cell_rows = [row for i, row in enumerate(txt_rows) if row_tag[i] != 1]
    new_nfs_rows = [row for j, row in enumerate(nfs_rows) if row_tag[j] != 1]
    return new_cell_rows, new_nfs_rows

def change_length(data):
    """Classify a record by the token count of its first two messages.

    Tokens are counted with tiktoken's ``cl100k_base`` encoding over
    ``data["messages"][0]["content"]`` and ``data["messages"][1]["content"]``.

    Returns:
        ``"<4k"``, ``"4-32k"`` or ``"32k"``.
    """
    encoder = tiktoken.get_encoding("cl100k_base")
    messages = data["messages"]
    token_total = sum(
        len(encoder.encode(messages[idx]["content"])) for idx in (0, 1)
    )
    # NOTE(review): the 250 subtracted from each limit looks like reserved
    # headroom -- confirm what it accounts for.
    small_limit = 4096 - 250
    large_limit = 32768 - 250
    if token_total < small_limit:
        return "<4k"
    if token_total < large_limit:
        return "4-32k"
    return "32k"

def _label_boundaries(labels):
    """Collect the row and column numbers of both ends of every label range."""
    address_pattern = re.compile(r"([A-Z]+)(\d+)")
    row_boundary = []
    col_boundary = []
    for label in labels:
        for cell_range in label:
            adds = cell_range.split(":")
            for address in (adds[0], adds[1]):
                matched = address_pattern.match(address)
                if matched:
                    row_boundary.append(int(matched.group(2)))
                    col_boundary.append(
                        openpyxl.utils.column_index_from_string(matched.group(1))
                    )
    return row_boundary, col_boundary


def compress_layout(input_file_path, output_file_path, mapping_file_path):
    """Compress every record of a JSON-lines file around its boundaries.

    Each input line is a JSON object with a ``"file_name"`` and a
    ``"messages"`` list, where ``messages[1]["content"]`` is the prompt and
    ``messages[2]["content"]`` is the labels (a list of lists of
    ``"start:end"`` ranges).

    For every record:

    * If ``find_boundary`` has no boundaries for the file, the record is
      written to the output file unchanged.
    * Otherwise the rows and columns near the label and stored boundaries are
      kept, blank rows are dropped, addresses are renumbered, and the
      rewritten record (with an added ``"now_length"``) is written to the
      output file. The address mappings go to the mapping file.

    A record that raises is reported on stdout and skipped; processing
    continues with the next line.

    Args:
        input_file_path: Path of the JSON-lines file to read.
        output_file_path: Path of the JSON-lines file to write records to.
        mapping_file_path: Path of the JSON-lines file to write address
            mappings to.
    """
    count = 0
    with open(input_file_path, 'r', encoding='utf-8') as input_file, \
            open(output_file_path, 'w', encoding='utf-8') as output_file, \
            open(mapping_file_path, 'w', encoding='utf-8') as mapping_file:
        for line in input_file:
            # Reset per line so the error handlers never hit an unbound name
            # (first line fails to parse) or report the previous record's name.
            file_name = None
            try:
                data = json.loads(line)
                prompt = data["messages"][1]["content"]
                file_name = data["file_name"]
                labels = data["messages"][2]["content"]

                # Parsed before the boundary lookup so that malformed labels
                # skip the record even when it has no stored boundaries.
                row_boundary, col_boundary = _label_boundaries(labels)

                boundary_list = find_boundary(file_name)
                if len(boundary_list) == 0:
                    print("!!!!!!!!!!!!")
                    output_file.write(json.dumps(data) + '\n')
                    continue

                # Each stored boundary is "row,row,col,col".
                for boundary in boundary_list:
                    parts = boundary.split(",")
                    row_boundary.append(int(parts[0].strip()))
                    row_boundary.append(int(parts[1].strip()))
                    col_boundary.append(int(parts[2].strip()))
                    col_boundary.append(int(parts[3].strip()))
                row_boundarys = sorted(set(row_boundary))
                col_boundarys = sorted(set(col_boundary))

                txt_res = txt_compress(prompt, row_boundarys, col_boundarys)
                nfs_res = nfs_compress(prompt, row_boundarys, col_boundarys)

                txt_res, nfs_res = delete_space(txt_res, nfs_res)
                if len(txt_res) != len(nfs_res):
                    # Message reads: "condition 1 met, raising exception".
                    raise MyCustomException("1条件满足，抛出异常")

                new_txt_rows, my_dic, my_dic_r = coordinate_rearrangement(txt_res)
                new_nfs_rows, _, _ = coordinate_rearrangement(nfs_res)
                if len(new_txt_rows) != len(new_nfs_rows):
                    # Message reads: "condition 2 met, raising exception".
                    raise MyCustomException("2条件满足，抛出异常")

                prompt_change = rows_to_new_input(new_txt_rows, new_nfs_rows)
                new_labels = groundtruth_coordinate_rearrangement(my_dic_r, labels)
                data["messages"][1]["content"] = prompt_change
                data["messages"][2]['content'] = new_labels
                data["now_length"] = change_length(data)
                output_file.write(json.dumps(data) + '\n')
                count += 1
                print(count)
                item = {
                    "file_name": file_name,
                    "reflection": my_dic,
                    "reflection_r": my_dic_r,
                }
                mapping_file.write(json.dumps(item) + '\n')
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON: {str(e)}")
                print(file_name)
                continue
            except Exception as e:
                # Best-effort batch run: report the record and move on.
                print(f"Error processing data: {str(e)}")
                print(file_name)
                continue

if __name__ == "__main__":
    # Script entry point: all three paths come from the project's config module.
    compress_layout(
        config.INPUT_FILE_PATH,
        config.OUTPUT_FILE_PATH,
        config.MAPPING_FILE_PATH,
    )


