import json
import re
import sys
import config
import CaseJudge
import openpyxl
from utils import get_table_cell_input
from utils import get_table_nfs_input
from utils import parse_excel_row_value
from utils import tuple_to_excel_cell
from utils import excel_address_to_coords
from utils import get_table_fmt_input
from utils import parse_excel_row


# Raise the interpreter-wide recursion limit (CPython's default is normally 1000).
# TableDataAggregation.process_file below still guards each record with an
# `except RecursionError`, so hitting even this higher limit skips a record
# instead of aborting the run.
sys.setrecursionlimit(10000)



class MyCustomException(Exception):
    """Module-specific exception type.

    NOTE(review): nothing in the visible part of this file raises or catches
    this exception; confirm it is still needed before relying on it.
    """
    pass


class TableDataAggregation:

    def __init__(self, input_file_path, output_file_path):
        """Store the paths used later by ``process_file``.

        Args:
            input_file_path: Path of the file that ``process_file`` reads
                line by line.
            output_file_path: Path of the file that ``process_file`` writes
                the rewritten records to.
        """
        self.input_file_path, self.output_file_path = (
            input_file_path,
            output_file_path,
        )

 
    @staticmethod
    def check_and_process_string(input_str):
        
        if all(char.isdigit() or char == ',' or char == '.' or (char in '+-' and input_str.index(char) == 0) for char in input_str):
            return input_str.replace(',', '')

        else:
            return input_str

    
    def identify_number(self, s):
        """Classify a cell string and return its numeric type code.

        The checks run in a fixed order and the first match wins:

        * 1 -- optionally negative integer (digits only)
        * 2 -- optionally negative decimal with digits on both sides of the dot
        * 3 -- ``CaseJudge.is_percentage`` accepts it
        * 4 -- ``CaseJudge.is_scientific_notation`` accepts it
        * 5 -- ``CaseJudge.is_date`` accepts it
        * 6 -- ``CaseJudge.is_time`` accepts it
        * 7 -- ``CaseJudge.is_currency`` accepts it
        * 8 -- ``CaseJudge.is_email`` accepts it
        * 9 -- none of the above

        Args:
            s: The cell text to classify.

        Returns:
            The integer code described above.
        """
        if re.match(r'^-?\d+$', s):
            return 1
        if re.match(r'^-?\d+\.\d+$', s):
            return 2

        # Predicates are looked up by name only when reached, so the order
        # and laziness of the CaseJudge calls stay the same.
        ordered_checks = (
            (3, "is_percentage"),
            (4, "is_scientific_notation"),
            (5, "is_date"),
            (6, "is_time"),
            (7, "is_currency"),
            (8, "is_email"),
        )
        for code, predicate_name in ordered_checks:
            if getattr(CaseJudge, predicate_name)(s):
                return code
        return 9


    def get_type(self, nfs_cell, cell):
        """Return the type tag used to group a cell with its neighbours.

        Resolution order:

        1. An empty ``cell`` always yields ``-1``.
        2. When ``config.NFS_TAG`` is truthy, any ``nfs_cell`` other than the
           string ``"None"`` is returned as the tag.
        3. When ``config.NFS_TAG`` is falsy, an ``nfs_cell`` found in
           ``config.DIC`` yields the mapped value.
        4. Otherwise the tag is inferred from the text: ``0`` for a
           four-digit string starting with 19 or 20, else the code from
           ``identify_number`` (after ``check_and_process_string``), with
           code ``9`` replaced by the cell text itself.

        Args:
            nfs_cell: The number-format string paired with the cell.
            cell: The cell text.

        Returns:
            ``-1``, an integer type code, a value taken from the number
            format, or the cell text.
        """
        if cell == "":
            return -1

        # The number-format information wins when it is usable.
        if config.NFS_TAG:
            if nfs_cell != "None":
                return nfs_cell
        elif nfs_cell in config.DIC:
            return config.DIC[nfs_cell]

        # Fall back to inferring the type from the cell text.
        if re.match(r'^19\d{2}$|^20\d{2}$', cell):
            return 0
        code = self.identify_number(self.check_and_process_string(cell))
        return cell if code == 9 else code


    def aggregate_similar_areas(self, input, nfs_input):
        """Group 4-connected cells of the same type and return their bounding boxes.

        The grid is scanned in row-major order. Every non-empty, not yet
        visited cell starts a flood fill over its up/down/left/right
        neighbours whose ``get_type`` result equals the start cell's type.
        The fill uses an explicit stack instead of recursion, so the size of
        a region is not limited by the interpreter's recursion limit.

        Args:
            input: Grid (list of rows) of cell strings; the column count is
                taken from the first row.
            nfs_input: Grid of number-format strings, indexed like ``input``.

        Returns:
            A list of ``((top, left), (bottom, right), val_type)`` tuples, one
            per connected region, in the order their first cell is met.
            Coordinates are 0-based indices into ``input``. The box is the
            region's bounding rectangle, so boxes of non-rectangular regions
            may cover cells of other types.
        """
        rows = len(input)
        cols = len(input[0]) if input else 0
        visited = [[False] * cols for _ in range(rows)]
        neighbour_steps = ((-1, 0), (1, 0), (0, -1), (0, 1))

        areas = []
        for r in range(rows):
            for c in range(cols):
                # Empty cells never start a region.
                if visited[r][c] or input[r][c] == "":
                    continue

                val_type = self.get_type(nfs_input[r][c], input[r][c])
                top, left, bottom, right = r, c, r, c

                # Cells are marked when pushed so each is stacked only once.
                visited[r][c] = True
                pending = [(r, c)]
                while pending:
                    cur_r, cur_c = pending.pop()
                    top, bottom = min(top, cur_r), max(bottom, cur_r)
                    left, right = min(left, cur_c), max(right, cur_c)

                    for dr, dc in neighbour_steps:
                        new_r, new_c = cur_r + dr, cur_c + dc
                        if not (0 <= new_r < rows and 0 <= new_c < cols):
                            continue
                        if visited[new_r][new_c]:
                            continue
                        cell_type = self.get_type(nfs_input[new_r][new_c],
                                                  input[new_r][new_c])
                        if val_type == cell_type:
                            visited[new_r][new_c] = True
                            pending.append((new_r, new_c))

                areas.append(((top, left), (bottom, right), val_type))

        return areas


    def change_data(self, begin, end, val_type, input):
        """Replace the value part of every cell in a rectangle with a type label.

        Each affected cell is an ``"address,value"`` string; only the text
        after the first comma is replaced. The label is ``val_type`` itself
        when it is a string, the fixed name for integer codes 0-8, and the
        cell's own value (i.e. no change) for any other code.

        Args:
            begin: Address of the rectangle's first corner, converted with
                ``excel_address_to_coords``.
            end: Address of the rectangle's opposite corner, converted the
                same way.
            val_type: String label or integer type code.
            input: Grid of ``"address,value"`` strings; modified in place.

        Returns:
            The same ``input`` grid object.

        NOTE(review): the converted coordinates are compared directly with
        row/column indices of ``input`` -- confirm callers pass addresses
        that are relative to this grid.
        """
        first = excel_address_to_coords(begin)
        last = excel_address_to_coords(end)

        code_labels = (
            (0, "YearData"),
            (1, "IntNum"),
            (2, "FloatNum"),
            (3, "PercentageNum"),
            (4, "SentificNum"),
            (5, "DateData"),
            (6, "TimeData"),
            (7, "CurrencyData"),
            (8, "EmailData"),
        )
        # ``None`` means "keep the cell's own value".
        if isinstance(val_type, str):
            label = val_type
        else:
            label = next(
                (name for code, name in code_labels if val_type == code),
                None,
            )

        for i, row in enumerate(input):
            if not (first[0] <= i <= last[0]):
                continue
            # The column count comes from the first row, as for the rest of the grid code.
            for j in range(len(input[0])):
                if not (first[1] <= j <= last[1]):
                    continue
                pieces = row[j].split(",", 1)
                cell_address = pieces[0]
                cell_value = pieces[1]
                new_value = cell_value if label is None else label
                row[j] = cell_address + "," + new_value
        return input


    def get_new_input(self, input, fmt=None):
        """Serialise the cell grid into the prompt's ``Input:`` section.

        Every row becomes one line in which each cell is prefixed with ``|``.
        The column count is taken from the first row. The lines follow the
        header ``"\\nInput: "`` and are terminated by ``"\\n\\n###\\n\\n"``.

        Args:
            input: Grid (list of rows) of cell strings.
            fmt: Accepted for interface compatibility; not used here.

        Returns:
            The assembled prompt text.
        """
        width = len(input[0]) if len(input) else 0
        rendered_rows = [
            "".join("|" + row[j] for j in range(width))
            for row in input
        ]
        body = "".join(rendered + "\n" for rendered in rendered_rows)
        return "\nInput: " + body + "\n\n###\n\n"


    def process_file(self):
        """Rewrite every JSON-lines record of the input file into the output file.

        For each input line the method:

        1. parses the JSON record and reads ``messages[1]["content"]``;
        2. extracts the cell rows and number-format rows from that prompt
           with the ``utils`` helpers;
        3. groups same-typed neighbouring cells via ``aggregate_similar_areas``;
        4. shifts each area by the table's first cell address and stores the
           ranges under ``"areas"`` (multi-cell areas whose type is not 9)
           and ``"areas1"`` (all areas);
        5. replaces the values inside the ``"areas"`` ranges with type labels
           (``change_data``) and stores the rendered grid
           (``get_new_input``) back into ``messages[1]["content"]``;
        6. writes the record as one JSON line.

        A record that fails at any step is reported on stdout and skipped;
        processing continues with the next line.
        """
        def record_name(record):
            # Best-effort label for log output. The record may not have been
            # parsed, may not be a JSON object, or may lack "file_name".
            if isinstance(record, dict):
                return str(record.get("file_name", "<unknown>"))
            return "<unknown>"

        with open(self.input_file_path, 'r', encoding='utf-8') as input_file, \
                open(self.output_file_path, 'w', encoding='utf-8') as output_file:
            for line_no, line in enumerate(input_file, 1):
                # Reset per line so the error handlers never see an unbound
                # name or the previous line's record.
                data = None
                try:
                    data = json.loads(line)
                    prompt = data["messages"][1]["content"]

                    cell_rows = get_table_cell_input(prompt)
                    nfs_rows = get_table_nfs_input(prompt)

                    fmt = None
                    if config.FMT1_TAG or config.FMT3_TAG:
                        fmt = get_table_fmt_input(prompt)

                    cell_input = [parse_excel_row_value(row) for row in cell_rows]
                    nfs_input = [parse_excel_row_value(row) for row in nfs_rows]

                    areas = self.aggregate_similar_areas(cell_input, nfs_input)

                    # The first cell's address gives the offset of the table
                    # within the sheet (converted to 0-based row/column).
                    first_address = parse_excel_row(cell_rows[0])[0].split(",", 1)[0]
                    matches = re.match(r"([A-Z]+)(\d+)", first_address)
                    if matches is None:
                        raise ValueError(
                            f"cannot parse first cell address {first_address!r}")
                    row_l = int(matches.group(2)) - 1
                    col_l = openpyxl.utils.column_index_from_string(matches.group(1)) - 1

                    new_areas = []
                    new_areas1 = []
                    for area in areas:
                        begin_address = tuple_to_excel_cell(
                            [area[0][0] + row_l, area[0][1] + col_l])
                        end_address = tuple_to_excel_cell(
                            [area[1][0] + row_l, area[1][1] + col_l])
                        cell_range = begin_address + ":" + end_address

                        new_areas1.append((cell_range, area[2]))
                        # Only multi-cell areas are compressed into a label.
                        if begin_address != end_address and area[2] != 9:
                            new_areas.append((cell_range, area[2]))

                    data["areas"] = new_areas
                    data["areas1"] = new_areas1

                    # Re-parse the rows, this time keeping "address,value" cells.
                    cell_input = [parse_excel_row(row) for row in cell_rows]
                    # NOTE(review): the ranges above are shifted by
                    # row_l/col_l, while change_data compares the converted
                    # coordinates with indices of this local grid. Confirm
                    # this is correct for tables that do not start at A1.
                    for cell_range, area_type in new_areas:
                        begin_address, end_address = cell_range.split(":")[:2]
                        cell_input = self.change_data(
                            begin_address, end_address, area_type, cell_input)

                    new_prompt = self.get_new_input(cell_input, fmt=fmt)

                    data["messages"][1]["content"] = new_prompt
                    output_file.write(json.dumps(data) + '\n')

                except RecursionError:
                    # Kept separate so recursion-depth failures are reported
                    # distinctly from other errors.
                    print("RecursionError: " + record_name(data))
                    continue
                except Exception as e:
                    print(f"An error occurred on line {line_no}: {e}, file_name: "
                          + record_name(data))
                    continue


if __name__ == "__main__":
    # Script entry point: aggregate the file named by the config module's
    # AGGREGATION_* paths and write the result.
    aggregator = TableDataAggregation(
        config.AGGREGATION_INPUT_FILE_PATH,
        config.AGGREGATION_OUTPUT_FILE_PATH,
    )
    aggregator.process_file()