import json
import re
    
# Load the task configuration once, at import time. The parsing functions
# below read the "VH_SKILL" and "VH_SKILL_FOR_PROMPT" entries from it.
# NOTE: the path is relative to the current working directory, so this module
# must be imported/run from a directory where '../Task_Data' resolves.
with open('../Task_Data/vh_config.json', 'r') as f:
    config = json.load(f)


def cot(low_output):
    """Convert chain-of-thought answers into per-skill score tables.

    Each entry's ``"answer"`` text is scanned for fragments of the form
    ``"<score> point(s): ... [action, action, ...]"``. Every bracketed action
    that is a known skill (``config["VH_SKILL"]``) receives that score. Skills
    that were not mentioned get a default: ``-1`` when the skill name contains
    ``"find"``, ``-2`` otherwise. The text following ``"Relevant objects: "``
    becomes the entry's rationale.

    Args:
        low_output: List of dicts, each holding the raw answer string under
            ``"answer"``. The entries are modified in place.

    Returns:
        The same list. Each entry's ``"answer"`` is replaced by a
        ``{skill: score}`` dict and ``"rationale"`` is set to the
        relevant-objects string (``""`` when the answer has none).
    """
    VH_SKILL = config["VH_SKILL"]

    # Compiled once: both patterns are invariant across entries.
    # Captures the score and the bracketed, comma-separated action list.
    score_regex = re.compile(r"(-?\d+) points?:.*?\[([^\]]+)\]")
    relevant_regex = re.compile(r"Relevant objects: (.+)")

    for entry in low_output:
        low_answer = entry["answer"]
        actions_scores = {}

        for score, actions in score_regex.findall(low_answer):
            for action in actions.split(','):
                action = action.strip()

                # Four-word actions ("put X on Y" / "put X in Y") are rewritten
                # to the skill naming scheme: "put in X Y" or "put X Y".
                words = action.split()
                if len(words) == 4:
                    put, obj1, how, obj2 = words
                    if how == "in":
                        action = f"{put} {how} {obj1} {obj2}"
                    else:
                        action = f"{put} {obj1} {obj2}"

                # Skip empty fragments and anything that is not a known skill.
                if action and action in VH_SKILL:
                    actions_scores[action] = int(score)

        # Skills the answer did not score fall back to a default.
        for vh_skill in VH_SKILL:
            actions_scores.setdefault(vh_skill, -1 if "find" in vh_skill else -2)

        # Reset per entry: without this, an answer lacking the
        # "Relevant objects:" line either raised UnboundLocalError (first
        # entry) or inherited the previous entry's rationale.
        relevant_match = relevant_regex.search(low_answer)
        relevant_objects = relevant_match.group(1).strip() if relevant_match else ""

        entry["answer"] = actions_scores
        entry["rationale"] = relevant_objects

    return low_output

def naive(low_output):
    """Parse single-skill answers into a score, a rationale and a skill name.

    Each entry carries the raw model answer under ``"answer"`` and the skill
    that was asked about under ``"skill"``. The first line of the answer is
    expected to hold the score and the optional second line the relevant
    objects. Entries whose (normalised) skill is not in
    ``config["VH_SKILL"]`` are left untouched.

    Score extraction from the first line:
      * no colon: the text is cleaned (the substring "point" and any
        "N. " list numbering are removed) and parsed as an integer; if that
        fails, the first signed integer found in the text is used;
      * with colon(s): when the skill name occurs in the answer the score is
        taken from the text before the first colon, otherwise from the text
        after it; if the preferred side holds no integer the other side is
        tried;
      * if no integer can be found at all the score defaults to ``-2``.

    Args:
        low_output: List of dicts with string values under ``"answer"`` and
            ``"skill"``. The entries are modified in place.

    Returns:
        The same list. For entries with a known skill, ``"answer"`` becomes
        the integer score, ``"rationale"`` the list of relevant objects
        (``[]`` when the answer has no second line) and ``"action"`` the
        normalised skill name.
    """
    VH_SKILL = config["VH_SKILL"]
    signed_int = re.compile(r"-?\d+")
    default_score = -2

    def first_int(text):
        # First signed integer in *text*, or None when there is none.
        found = signed_int.search(text)
        return int(found.group()) if found else None

    for data in low_output:
        line = data["answer"]
        action = data["skill"]

        parts = line.split('\n')
        score_part = parts[0].split(':')

        if len(score_part) == 1:
            text = score_part[0]
            if "point" in text.lower():
                text = text.replace("point", "")
            # Drop list numbering such as "1. ".
            text = re.sub(r"\d+\.\s*", "", text)

            try:
                score = int(text.strip())
            except ValueError:
                # Report the entry that needed the fallback, then salvage an
                # integer from leftovers such as "3 s" (from "3 points").
                print(data)
                score = first_int(text)
        else:
            head = score_part[0]
            tail = ':'.join(score_part[1:])
            if action in line:
                print(line)
                print("=====")
                print(score_part)
                preferred, other = head, tail
            else:
                print(line)
                preferred, other = tail, head

            # Signed match keeps negative scores ("Score: -2" is -2, not 2).
            score = first_int(preferred)
            if score is None:
                score = first_int(other)

        if score is None:
            score = default_score

        # Relevant objects live on the second line, optionally after a
        # "label:" prefix. Without a second line there are none.
        if len(parts) > 1:
            segments = parts[1].split(':')
            listing = segments[0] if len(segments) == 1 else segments[1]
            relevant_objects = listing.strip().split(', ')
        else:
            relevant_objects = []

        # "put X on Y" -> "put X Y"; "put X in Y" -> "put in X Y".
        # Only four-word skills are rewritten so other shapes cannot raise.
        words = action.split()
        if "put" in action and len(words) == 4:
            put, obj1, how, obj2 = words
            if how == "on":
                action = f"{put} {obj1} {obj2}"
            else:
                action = f"{put} {how} {obj1} {obj2}"

        if action in VH_SKILL:
            # Add the parsed values to the result entry.
            data["rationale"] = relevant_objects
            data["answer"] = score
            data["action"] = action

    return low_output
    
def processing(low_output):
    """Parse "action: score" style answers into per-skill score tables.

    The answer text is lower-cased, every ``"robot: "`` occurrence is removed,
    and each line is inspected:

      * a line matching ``".*: -?\\d+"`` is split at its last ``": "`` into an
        action and a score; the score is stored when it is a plain integer;
      * a line starting with ``"relevant objects:"`` supplies the rationale.

    When the first score line is seen, every skill in
    ``config["VH_SKILL_FOR_PROMPT"]`` is given a default score (``-1`` if its
    name contains ``"find"``, ``-2`` otherwise). Finally four-word keys
    ("put X on Y" / "put X in Y") are rewritten to the ``config["VH_SKILL"]``
    naming scheme and only keys present in ``VH_SKILL`` are kept.

    Args:
        low_output: List of dicts, each holding the raw answer string under
            ``"answer"``. The entries are modified in place.

    Returns:
        The same list. Each entry's ``"answer"`` is replaced by a
        ``{skill: score}`` dict (empty when no score line was found) and
        ``"rationale"`` is set to the relevant-objects string (``""`` when
        absent).
    """
    VH_SKILL_FOR_PROMPT = config["VH_SKILL_FOR_PROMPT"]
    VH_SKILL = config["VH_SKILL"]
    score_line = re.compile(r".*: -?\d+")

    for data in low_output:
        answer = {}
        relevant_objects = ""

        # str.replace is a no-op when the prefix is absent.
        data["answer"] = data["answer"].lower().replace('robot: ', '')

        for line in data["answer"].split("\n"):
            if score_line.match(line):
                print(line)
                # The pattern guarantees at least one ": ". Splitting once
                # from the right keeps lines with several ": " from raising
                # on the two-name unpacking.
                action, score = line.rsplit(": ", 1)

                for vh_skill in VH_SKILL_FOR_PROMPT:
                    answer.setdefault(vh_skill, -1 if "find" in vh_skill else -2)

                try:
                    answer[action.strip()] = int(score.strip())
                except ValueError:
                    # Trailing text after the number (e.g. "3 points"):
                    # the line is skipped.
                    pass

            elif line.startswith("relevant objects:"):
                relevant_objects = line.split("relevant objects:")[1].strip()

        new_dict = {}
        for key, val in answer.items():
            # "put X in Y" -> "put in X Y"; other four-word keys -> "put X Y".
            words = key.split()
            if len(words) == 4:
                put, obj1, how, obj2 = words
                if how == "in":
                    key = f"{put} {how} {obj1} {obj2}"
                else:
                    key = f"{put} {obj1} {obj2}"

            if key in VH_SKILL:
                new_dict[key] = val

        data["answer"] = new_dict
        data["rationale"] = relevant_objects

    return low_output

  
  
import argparse
if __name__ == "__main__":
    # Usage: --prompt <name> reads output/<name>.json, converts the raw
    # answers, and writes the result to output/<name>_processing.json.
    parser = argparse.ArgumentParser(description="palm")
    parser.add_argument(
        "--prompt", type=str, default="icl"
    )
    args = parser.parse_args()

    with open(f'output/{args.prompt}.json', 'r') as f:
        low_output = json.load(f)

    # Exactly one parser must run. With two independent `if` statements,
    # "--prompt cot" ran cot() and then fell into the `else` branch, calling
    # processing() on entries whose "answer" was already a dict.
    if args.prompt == "cot":
        output = cot(low_output)
    elif args.prompt == "naive":
        output = naive(low_output)
    else:
        output = processing(low_output)

    with open(f'output/{args.prompt}_processing.json', 'w') as f:
        json.dump(output, f, indent=4)