import torch
from base import simulator
import random
import torch
import os
from sentence_transformers import SentenceTransformer
import torch
class msgr_simulator(simulator):
    """Trajectory-collecting simulator built on the project's ``simulator`` base.

    The methods below read and update attributes such as ``training_env``,
    ``observationProcessor``, ``pathSolver``, ``envCopier``, ``trajectory``,
    ``stateContainer``, ``envContainer`` and ``trainingState``. None of them
    is created in this class -- presumably the base class sets them up;
    confirm against ``base.simulator``.
    """

    def __init__(self):
        super().__init__()

    def run_episode_ground_truth_pathSolver(self, episode, newTask=False):
        """Step the environment with actions chosen by ``self.pathSolver``.

        Does nothing when ``self.done`` is already truthy. Otherwise runs at
        most ``episode`` environment steps, stopping early as soon as the
        environment reports ``done``.

        For every step this records: the simplified pre-step state
        (``stateContainer``), the trajectory state / action / reward
        (``trajectory``), a copy of the environment (``envContainer``) and
        the post-step state (``trainingState``). The stored reward is the
        environment reward scaled by 100 plus
        ``observationProcessor.process_reward`` of the post-step state.

        Args:
            episode: Maximum number of environment steps to take.
            newTask: Forwarded unchanged to ``envCopier.deep_copy``.
        """
        if self.done:
            return

        processor = self.observationProcessor
        history = self.trajectory
        # Labels used only for verbose printing of the chosen action index.
        action_labels = ("w", "s", "a", "d", "")

        with torch.no_grad():
            for _ in range(episode):
                # --- observe and pick the solver's action -------------------
                state_before = processor.generate_state(self.training_env)
                self.stateContainer.append(processor.simplifyState(state_before))
                self.pathSolver.update(state_before)
                action = self.pathSolver.get_action()

                # --- record the pre-step snapshot ---------------------------
                history["state"].append(
                    processor.generate_trajectory_state(self.obs)
                )
                self.envContainer.append(
                    self.envCopier.deep_copy(self.training_env, newTask)
                )
                history["action"].append(action)

                if self.verbose:
                    print("Step: ", self.trainingStep + 1)
                    print("Map:")
                    self.pathSolver.print_map()
                    print(processor.generate_grid(self.obs))
                    print(state_before)
                    print("action: ", action_labels[action])

                # --- act, then calculate the new current state; the reward
                #     is generated from that post-step state ----------------
                self.obs, self.reward, self.done, _ = self.training_env.step(action)
                state_after = processor.generate_state(self.training_env)

                if self.verbose:
                    print(state_after)
                    print("After Step")
                    print(processor.generate_grid(self.obs))

                self.reward = self.reward * 100 + processor.process_reward(state_after)
                self.totalReward = self.totalReward + self.reward
                history["reward"].append(self.reward)
                # self.subgoal.append(processor.generateSubgoal(state_after))

                # The post-step state is generated again here (not reused
                # from ``state_after``), exactly as the code has always done.
                self.trainingState.append(
                    processor.generate_state(self.training_env)
                )
                self.trainingStep += 1

                if self.done:
                    break

    def run_episode_EMMA_train(self, episode, newTask=False):
        """Step the environment with uniformly random actions.

        The model-driven action selection is commented out; the action is
        drawn with ``random.randint(0, 4)``. Bookkeeping mirrors
        ``run_episode_ground_truth_pathSolver``: trajectory state / action /
        reward, an environment copy, the simplified pre-step state and the
        post-step state are appended on every step, and the loop ends early
        once the environment reports ``done``.

        Args:
            episode: Maximum number of environment steps to take.
            newTask: Forwarded unchanged to ``envCopier.deep_copy``.
        """
        if self.done:
            return

        processor = self.observationProcessor
        history = self.trajectory

        # self.buffer.reset(self.obs)
        for _ in range(episode):
            # Disabled model policy:
            # with torch.no_grad():
            #     action = self.model(self.buffer.get_obs(), self.manual)
            history["state"].append(processor.generate_trajectory_state(self.obs))
            self.envContainer.append(
                self.envCopier.deep_copy(self.training_env, newTask)
            )
            action = random.randint(0, 4)
            history["action"].append(action)

            if self.verbose:
                print("Step: ", self.trainingStep + 1)
                print(processor.generate_grid(self.obs))

            state_before = processor.generate_state(self.training_env)
            self.stateContainer.append(processor.simplifyState(state_before))

            self.obs, self.reward, self.done, _ = self.training_env.step(action)
            if self.verbose:
                print(processor.generate_grid(self.obs))
            # self.buffer.update(self.obs)

            state_after = processor.generate_state(self.training_env)
            self.reward = self.reward * 100 + processor.process_reward(state_after)
            self.totalReward = self.totalReward + self.reward
            history["reward"].append(self.reward)
            # self.subgoal.append(processor.generateSubgoal(state_after))
            self.trainingState.append(state_after)
            self.trainingStep += 1

            if self.done:
                break

    def run_trajectory(self, episode, newTask=False):
        """Collect one trajectory using the path solver only.

        Resets the non-expert counters and runs the path-solver rollout for
        up to 50 steps. ``episode`` is currently unused: it is only
        referenced by the disabled mixed expert / non-expert schedule kept
        below for reference.

        Args:
            episode: Unused by the active code path.
            newTask: Forwarded to ``run_episode_ground_truth_pathSolver``.
        """
        self.nonExpertTimes = 0
        self.nonExpertTick = 0
        # Both locals are consumed only by the disabled schedule below. The
        # random draw is kept because it advances the global ``random`` state.
        stride = 2
        random_number = random.random()

        self.run_episode_ground_truth_pathSolver(50, newTask)

        # Disabled mixed schedule:
        # print("debug: ", random_number)
        # if random_number < self.args.entire_expert_probability:
        #     print("This is entire expert trajectory")
        #     self.run_episode_ground_truth_pathSolver(50, newTask)
        # else:
        #     for _ in range(episode):
        #         if self.done:
        #             break
        #         if self.nonExpertTick < 5 and self.random_training():
        #             self.nonExpertTimes += stride
        #             self.nonExpertTick += 1
        #             self.run_episode_EMMA_train(stride, newTask)
        #             self.run_episode_ground_truth_pathSolver(6, newTask)
        #         else:
        #             self.run_episode_ground_truth_pathSolver(3, newTask)
        # # Add this step so the observation processor can reflect on the
        # # last step.
        # currentState = self.observationProcessor.generate_state(self.training_env)
        # self.stateContainer.append(
        #     self.observationProcessor.simplifyState(currentState)
        # )
        # print("Total Reward: ", self.totalReward)

def process_data_sentenceBert(trajectory, name, idx, encoder=None, embedding_dim=768):
    """Add returns-to-go and sentence embeddings to ``trajectory`` and save it.

    The trajectory dict is modified in place and then written with
    ``torch.save(trajectory, name)``.

    Each element of ``trajectory["languages"][idx]`` is either empty or a
    mapping indexed as ``entry[category][budget]["human"]``, where the
    categories read here are ``"hindsight positive"``,
    ``"hindsight negative"`` and ``"foresight positive"`` and the budgets
    are ``"1000"``, ``"500"``, ``"200"`` and ``"100"``. An empty entry
    contributes empty sentences. (Previously the empty-entry fallback
    substituted plain strings, which were then indexed like dicts and
    raised ``TypeError``.)

    Keys written to ``trajectory``:
        ``return_to_go``: per-step cumulative future reward (discount 1).
        ``encoded_manual``: encoding of ``trajectory["manual"]`` with a
            leading dimension of size 1 added.
        ``rh_embedding`` / ``rf_embedding``: hindsight / foresight
            embeddings for budget ``"1000"``, shape ``(steps, 1, dim)``.
        ``rhf_embedding``: mean of the two above.
        ``hf_embedding_500`` / ``_200`` / ``_100``: mean of hindsight and
            foresight embeddings for the respective budget.

    Args:
        trajectory: Dict with at least ``"reward"``, ``"languages"`` and
            ``"manual"``.
        name: Destination path passed to ``torch.save``.
        idx: Index into ``trajectory["languages"]``.
        encoder: Object with an ``encode(list_of_str)`` method. Defaults to
            the module-level ``model`` created by the script entry point.
        embedding_dim: Width of one sentence embedding (768 by default).

    Raises:
        KeyError: If a non-empty language entry lacks an expected key.
        RuntimeError: If the encoder output cannot be reshaped to
            ``(8, steps, 1, embedding_dim)``.
    """
    if encoder is None:
        # Backward-compatible default: the global set up under ``__main__``.
        encoder = model

    # Returns-to-go: accumulate rewards from the last step backwards.
    gamma = 1
    returns = []
    running = 0
    for reward in reversed(trajectory["reward"]):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    trajectory["return_to_go"] = returns

    def human_text(entry, category, budget):
        # An empty entry means "no language for this step".
        if not entry:
            return ""
        return entry[category][budget]["human"]

    entries = trajectory["languages"][idx]
    budgets = ("1000", "500", "200", "100")

    # Sentence layout: for each budget, all hindsight sentences followed by
    # all foresight sentences, so the encoded batch can be reshaped into
    # 2 * len(budgets) groups of len(entries) rows.
    sentences = []
    for budget in budgets:
        sentences.extend(
            human_text(entry, "hindsight positive", budget)
            + human_text(entry, "hindsight negative", budget)
            for entry in entries
        )
        sentences.extend(
            human_text(entry, "foresight positive", budget) for entry in entries
        )

    with torch.no_grad():  # inference only, no gradients needed
        grouped = torch.tensor(encoder.encode(sentences)).reshape(
            2 * len(budgets), len(entries), 1, embedding_dim
        )
        manual_embedding = torch.unsqueeze(
            torch.tensor(encoder.encode(trajectory["manual"])), dim=0
        )

    hindsight = {budget: grouped[2 * i] for i, budget in enumerate(budgets)}
    foresight = {budget: grouped[2 * i + 1] for i, budget in enumerate(budgets)}

    trajectory["encoded_manual"] = manual_embedding
    trajectory["rhf_embedding"] = (hindsight["1000"] + foresight["1000"]) / 2
    trajectory["rh_embedding"] = hindsight["1000"]
    trajectory["rf_embedding"] = foresight["1000"]
    trajectory["hf_embedding_500"] = (hindsight["500"] + foresight["500"]) / 2
    trajectory["hf_embedding_200"] = (hindsight["200"] + foresight["200"]) / 2
    trajectory["hf_embedding_100"] = (hindsight["100"] + foresight["100"]) / 2
    torch.save(trajectory, name)

if __name__ == "__main__":
    # Script entry point: run the path-solver rollout for 500 runs (seeds
    # cycle through 0..99), record (done, step count) per seed, and save the
    # result to "expert_trajectory_length.pth".
    import traceback  # local import: only needed for error reporting here

    trajectory_dir = '/data/Dataset_Messenger_template'
    os.makedirs(trajectory_dir, exist_ok=True)

    messengerSimulator = msgr_simulator()
    # Must remain a module-level name ``model``: process_data_sentenceBert
    # reads it as a global.
    model = SentenceTransformer("sentence-transformers/paraphrase-TinyBERT-L6-v2")
    model.eval()

    augment_number = 1
    newTask = False
    expert_efficiency = {str(i): [] for i in range(1000)}

    # Alternative ranges used previously:
    # for x in range(messengerSimulator.start, messengerSimulator.end):
    # for x in range(0, 1):
    for x in range(0, 500):
        seed = x % 100
        # seed = x % 5
        # seed = x
        try:
            messengerSimulator.reset(
                newTask=newTask,
                verbose=False,
                augment_number=augment_number,
                seed=seed,
            )
            messengerSimulator.run_trajectory(20, newTask=newTask)
            messengerSimulator.generateLanguage()
            expert_efficiency[str(seed)].append(
                (messengerSimulator.done, messengerSimulator.trainingStep)
            )
            # messengerSimulator.trajectory["nonExpertTime"] = messengerSimulator.nonExpertTimes
            for i in range(augment_number):
                trajectory_name = f"{trajectory_dir}/trajectory_{x * augment_number + i + 1}.pth"
                # Embedding + export is currently disabled:
                # process_data_sentenceBert(messengerSimulator.trajectory, trajectory_name, i)
        except Exception:
            # Keep going with the next run, but show what failed. Catching
            # ``Exception`` (not a bare ``except``) lets Ctrl-C and
            # SystemExit stop the script.
            print("Program Error")
            traceback.print_exc()

    torch.save(expert_efficiency, "expert_trajectory_length.pth")
