import argparse
import json
import os
import re
from collections import Counter, defaultdict
from itertools import groupby, product
from operator import itemgetter

import google.generativeai as genai
import natsort
from sentence_transformers import SentenceTransformer, util

# Skill vocabulary shared by every ensemble in this module. The config
# path is relative, so it is resolved against the current working
# directory: launch the script from a directory whose parent holds
# Task_Data/.
with open('../Task_Data/vh_config.json', 'r') as f:
    config = json.load(f)
    VH_SKILL = config["VH_SKILL"]


# Gemini safety configuration: each of these four harm categories is given
# the "BLOCK_NONE" threshold, i.e. the request asks the API not to block
# responses in any of them.
SAFETY_SETTINGS = [
    {"category": _category, "threshold": "BLOCK_NONE"}
    for _category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]


def load_reward_file(prompt):
    """Read and return the parsed JSON stored at ``output/<prompt>.json``.

    The path is relative to the current working directory.
    """
    reward_path = f'output/{prompt}.json'
    with open(reward_path, 'r') as handle:
        return json.load(handle)

def load_reward_dir_files(file_dir):
    """Load every ``*.json`` file found directly inside *file_dir*.

    Only regular files whose name ends in ``.json`` are read. A mere
    substring match would also pick up ``foo.json.bak``, ``*.jsonl`` or a
    directory called ``x.json``. Files are visited in sorted name order
    because ``os.listdir`` order is arbitrary. Callers merge these payloads
    and break voting ties by first occurrence, so a stable order keeps
    results reproducible.

    Args:
        file_dir: directory to scan (not recursive).

    Returns:
        A list holding the parsed JSON payload of each file, in sorted
        filename order.
    """
    top_dir = str(file_dir)
    payloads = []

    for name in sorted(os.listdir(top_dir)):
        if not name.endswith(".json"):
            continue
        path = os.path.join(top_dir, name)
        if not os.path.isfile(path):
            continue
        with open(path, 'r') as f:
            payloads.append(json.load(f))

    return payloads


def temporal_prompt1(ins, skill_plan):
    """Build the true/false feasibility prompt for a candidate skill plan.

    Each skill is rendered as a numbered step ("1. ..., 2. ...") after
    light rewriting into natural word order:

    * 4-token skills ``a b c d`` become ``a c b d``. Presumably this turns
      ``put in apple fridge`` into ``put apple in fridge``.
    * 3-token skills whose verb contains "put" (``put apple sink``) become
      ``put apple on sink``.
    * Anything else is kept verbatim. Only the verb token is checked for
      "put", so object names such as "computer" no longer trigger the
      3-token unpacking.

    Args:
        ins: natural-language instruction to achieve.
        skill_plan: sequence of skill strings, one per step.

    Returns:
        The prompt text, ending with the instruction to answer only true
        or false.
    """
    steps = []
    for idx, skill in enumerate(skill_plan):
        tokens = skill.split()
        if len(tokens) == 4:
            put, how, obj1, obj2 = tokens
            skill = f"{put} {obj1} {how} {obj2}"
        elif len(tokens) == 3 and "put" in tokens[0]:
            put, obj1, obj2 = tokens
            skill = f"{put} {obj1} on {obj2}"
        steps.append(f"{idx+1}. {skill}")

    sequence = ", ".join(steps)

    return \
f"""Action List: close bathroomcabinet, close fridge, close microwave, find bathroomcounter, find bathtub, find bed, find bookshelf, find closetdrawer, find coffeetable, find desk, find fridge, find kitchentable, find microwave, find sink, find sofa, find toaster, find wallshelf, grab apple, grab bananas, grab book, grab breadslice, grab cat, grab cellphone, grab cereal, grab clothespants, grab creamybuns, grab pillow, grab salmon, grab toothbrush, grab toothpaste, open bathroomcabinet, open fridge, open microwave, put apple in fridge, put apple in microwave, put apple on breadslice, put apple on sink, put bananas in fridge, put bananas on breadslice, put bananas on sink, put breadslice on kitchentable, put cat on bathtub, put cat on bed, put cat on desk, put cellphone in fridge, put cereal in fridge, put cereal on kitchentable, put clothespants on bathroomcounter, put clothespants on bed, put creamybuns in fridge, put creamybuns on kitchentable, put salmon in fridge, put salmon in microwave, put toothpaste in bathroomcabinet, sit bed, sit sofa, switch on microwave, switch on toaster
From the list of actions provided above, I selected several to form an action sequence like "{sequence}". If this sequence of actions is executed in order, is it possible to achieve "{ins}"?
Answer only true or false.
"""

def temporal_prompt2(ins, skill_plan):
    """Build the prompt asking which step(s) of *skill_plan* are out of order.

    The plan is shown as a 1-based numbered list ("1. a, 2. b") so that the
    requested "number" answer refers to an explicit step label. Callers in
    this module subtract 1 from the reply to index the plan.

    Args:
        ins: natural-language instruction the plan is meant to achieve.
        skill_plan: sequence of skill strings, one per step.

    Returns:
        The prompt text.
    """
    sequence = ", ".join(f"{idx+1}. {skill}" for idx, skill in enumerate(skill_plan))

    return \
f"""Choose the wrong order to achieve {ins} on {sequence}. Answer only number."""

def _generate_text(model, prompt_text):
    """Send *prompt_text* to *model* at temperature 0 and return the reply text."""
    completion = model.generate_content(
        prompt_text,
        generation_config=genai.types.GenerationConfig(temperature=0),
        safety_settings=SAFETY_SETTINGS,
    )
    return completion.text


def _parse_step_numbers(reply, num_steps):
    """Turn a reply such as "2", "1, 3" or "Step 2." into valid 0-based indices.

    Every run of digits is read as a 1-based step number. Numbers outside
    ``1..num_steps`` are dropped. Without this, "0" would wrap to index -1
    and large numbers would raise IndexError. Duplicates are removed and
    first-seen order is kept.
    """
    indices = []
    for token in re.findall(r"\d+", reply):
        idx = int(token) - 1
        if 0 <= idx < num_steps and idx not in indices:
            indices.append(idx)
    return indices


def temporal_ensemble(prompt):
    """Penalise temporally inconsistent top-scored skills using Gemini.

    Steps:

    1. Load ``output/{prompt}.json`` via ``load_reward_file``.
    2. Group the records by ``task_id`` in order of history length.
    3. For each record (step), collect the skills whose answer value is
       exactly 2 and that appear in ``VH_SKILL``. Use ``"nothing"`` when
       there are none.
    4. Ask Gemini (``temporal_prompt1``) whether each combination of one
       skill per step can achieve the instruction. This is one or two API
       calls per combination.
    5. When the reply is not "true", ask which steps are wrong
       (``temporal_prompt2``). Set each named step's skill to -999 in every
       record of that task with the same history length.

    The Gemini API key is read from the ``GOOGLE_API_KEY`` environment
    variable. An earlier revision hard-coded a key in source; that key
    should be considered leaked and revoked.

    Args:
        prompt: reward file stem under ``output/``.

    Returns:
        The loaded records in their original file order. The record dicts
        are mutated in place.

    Raises:
        RuntimeError: if ``GOOGLE_API_KEY`` is not set.
    """
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Set the GOOGLE_API_KEY environment variable to run the temporal method.")
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel('gemini-pro')

    backward_rewards = load_reward_file(prompt)
    # Sorting creates a new list of the *same* dicts, so edits made through
    # it are visible in backward_rewards, which keeps the file order.
    backward_reward = sorted(backward_rewards, key=lambda x: (x['task_id'], len(x['history'])))

    # Every record sharing (task_id, history length) receives the same
    # penalty, so index them once instead of rescanning per correction.
    records_by_step = defaultdict(list)
    for record in backward_reward:
        records_by_step[(record['task_id'], len(record['history']))].append(record)

    for task_id, group in groupby(backward_reward, key=itemgetter('task_id')):
        planning = []
        planning_idx = []
        ins = None
        for line in group:
            ins = line["instruction"]
            answer_top = [key for key, value in line["answer"].items() if value == 2 and key in VH_SKILL]
            planning.append(answer_top or ["nothing"])
            planning_idx.append((task_id, len(line["history"])))

        for combo in product(*planning):
            subplan = list(combo)
            verdict = _generate_text(model, temporal_prompt1(ins, subplan))
            if verdict.strip().rstrip(".").lower() == "true":
                continue

            reply = _generate_text(model, temporal_prompt2(ins, subplan))
            for i in _parse_step_numbers(reply, len(subplan)):
                rem_action = subplan[i]
                if rem_action == "nothing":
                    # Placeholder for "no candidate skill", not a real key.
                    continue
                for record in records_by_step[planning_idx[i]]:
                    record["answer"][rem_action] = -999

    return backward_rewards

def voting(data, skills=None):
    """Collapse each record's per-skill vote lists into one majority value.

    Args:
        data: list of records whose ``"answer"`` maps a skill name to a list
            of candidate values, as produced by ``aggregate_answers``. A
            bare (non-list) value is treated as a single vote rather than
            iterated. Iterating it would raise TypeError for ints or count
            characters for strings.
        skills: allowed skill names; defaults to the module-level
            ``VH_SKILL``. Answer keys outside this collection are dropped.

    Returns:
        *data* itself. Each record's ``"answer"`` is replaced by
        ``{skill: most common value}``. On ties, ``Counter.most_common``
        returns the value encountered first. Skills with an empty vote
        list are omitted.
    """
    allowed = set(VH_SKILL if skills is None else skills)
    for task in data:
        new_answer = {}
        for key, values in task["answer"].items():
            if key not in allowed:
                continue
            if not isinstance(values, (list, tuple)):
                values = [values]
            if not values:
                continue
            new_answer[key] = Counter(values).most_common(1)[0][0]
        task["answer"] = new_answer
    return data

def aggregate_answers(data):
    """Merge records that share the same task and action history.

    Records are grouped by ``(task_id, tuple(history))`` in first-seen
    order. Each group yields a shallow copy of its first record. In that
    copy, ``"answer"`` maps every action to the list of values it received
    across the group, in record order. Other fields come from the first
    record only.

    Single-record groups get the same shape (one-element lists) instead of
    being passed through with scalar answers. Every merged answer therefore
    has uniform list values for voting, and the input records are never
    aliased in the output.

    Args:
        data: iterable of records with ``task_id``, ``history`` (hashable
            elements) and ``answer`` (dict) keys.

    Returns:
        A new list of merged records, one per group.
    """
    grouped_data = defaultdict(list)
    for item in data:
        grouped_data[(item['task_id'], tuple(item['history']))].append(item)

    results = []
    for items in grouped_data.values():
        merged_item = items[0].copy()
        answer = defaultdict(list)
        for item in items:
            for action, value in item['answer'].items():
                answer[action].append(value)
        merged_item['answer'] = dict(answer)
        results.append(merged_item)

    return results

def contextual_ensemble(file_dir):
    """Majority-vote rewards across every JSON reward file in *file_dir*.

    The records of all files are concatenated, merged per
    (task_id, history) by ``aggregate_answers``, and reduced to one value
    per skill by ``voting``. Records keep their merged (first-seen) order.
    """
    merged_records = [
        record
        for file_records in load_reward_dir_files(file_dir)
        for record in file_records
    ]
    return voting(aggregate_answers(merged_records))

def _canonical_rationale(text):
    """Natural-sort the comma-separated parts of *text*.

    Text containing a comma is returned as ``str(natsorted(parts))``, for
    example ``"b2,a1"`` becomes ``"['a1', 'b2']"``. Anything else is
    returned unchanged.
    """
    if "," in text:
        return str(natsort.natsorted(text.split(",")))
    return text


def _find_expert_label(rationale_expert, ins, history):
    """Return ``answer`` of the first expert record matching *ins* and *history*, else None."""
    for ration in rationale_expert:
        if ration["instruction"] == ins and ration["history"] == history:
            return ration["answer"]
    return None


def structural_ensemble(rationale_expert, prompt, threshold=0.5):
    """Keep reward records whose rationale is semantically close to the expert's.

    For each record in ``output/{prompt}.json``, the predicted
    ``rationale`` is compared with the matching expert record's
    ``answer``. A record matches when its ``instruction`` and ``history``
    are equal. Both texts are canonicalised first: comma lists are
    natural-sorted, and any prediction containing "none" becomes "none".
    Then the cosine similarity of their paraphrase-MiniLM-L6-v2 embeddings
    is compared with *threshold*.

    Records with no matching expert entry are dropped. Without a reference
    they cannot be checked, and they previously crashed (unbound label) or
    were compared with the previous record's label.

    Args:
        rationale_expert: list of expert records with ``instruction``,
            ``history`` and ``answer`` keys.
        prompt: reward file stem under ``output/``.
        threshold: minimum cosine similarity required to keep a record.

    Returns:
        The kept records, in file order.
    """
    rag_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
    reward_file = load_reward_file(prompt)
    new_data = []

    for line in reward_file:
        rationale_label = _find_expert_label(rationale_expert, line["instruction"], line["history"])
        if rationale_label is None:
            continue

        rationale = line["rationale"]
        if "none" in rationale:
            rationale = "none"
        # Sort both sides identically so list order does not affect similarity.
        rationale = _canonical_rationale(rationale)
        rationale_label = _canonical_rationale(rationale_label)

        pred_emb = rag_model.encode(rationale)
        label_emb = rag_model.encode(rationale_label)
        similarity = float(util.pytorch_cos_sim(label_emb, pred_emb)[0][0])

        if similarity >= threshold:
            new_data.append(line)

    return new_data

def majority_ensemble(file_dir):
    """Majority-vote all reward files in *file_dir*, ordered by task and step.

    Same pipeline as the contextual ensemble (concatenate, merge per
    (task_id, history), vote), but the merged records are first ordered by
    ``task_id`` and then by history length. The sort is stable.
    """
    merged_records = []
    for file_records in load_reward_dir_files(file_dir):
        merged_records.extend(file_records)

    grouped = aggregate_answers(merged_records)
    grouped.sort(key=lambda rec: (rec['task_id'], len(rec['history'])))
    return voting(grouped)


def main():
    """Command-line entry point: run one ensemble method and write the result as JSON.

    Methods:
        contextual: vote over every reward file in ``--file_dir``.
            Writes ``output/contextual.json``.
        temporal: Gemini-based temporal pruning of ``output/<prompt>.json``.
            Writes ``output/backward_<prompt>.json``.
        structural: rationale-similarity filtering of
            ``output/<prompt>.json`` against ``Data/MDP_dataset.json``.
            Writes ``output/structural_<prompt>.json``.
        majority / ensemble: majority vote over ``--file_dir``, sorted by
            task and step. Writes ``output/<ensemble>.json``. "majority"
            was an accepted choice without a branch, which left
            ``output_path`` unbound; it is now an alias.

    Raises:
        ValueError: when the argument a method needs is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, required=True, choices=["contextual", "temporal", "structural", "majority", "ensemble"])
    parser.add_argument("--prompt", type=str, help="Prompt for the consistency method")
    parser.add_argument("--ensemble", type=str, help="Output file name (without .json) for the majority/ensemble method")
    parser.add_argument("--file_dir", type=str, default="YOUR FILE DIR",
                        help="Directory of reward JSON files for the contextual and majority/ensemble methods")

    args = parser.parse_args()

    if args.method == "contextual":
        ensemble_reward = contextual_ensemble(args.file_dir)
        output_path = 'output/contextual.json'

    elif args.method == "temporal":
        if not args.prompt:
            raise ValueError("The --prompt argument is required for the temporal method.")
        ensemble_reward = temporal_ensemble(args.prompt)
        output_path = f'output/backward_{args.prompt}.json'

    elif args.method == "structural":
        if not args.prompt:
            raise ValueError("The --prompt argument is required for the structural method.")
        with open('Data/MDP_dataset.json', 'r') as f:
            rationale = json.load(f)
        ensemble_reward = structural_ensemble(rationale, args.prompt)
        output_path = f'output/structural_{args.prompt}.json'

    else:
        # argparse restricts --method to the choices above, so this branch
        # handles exactly "majority" and "ensemble".
        if not args.ensemble:
            raise ValueError(f"The --ensemble argument is required for the {args.method} method.")
        ensemble_reward = majority_ensemble(args.file_dir)
        output_path = f'output/{args.ensemble}.json'

    # The output directory may not exist for methods that do not read from it.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w') as f:
        json.dump(ensemble_reward, f, indent=4)

if __name__ == "__main__":
    main()