import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import Dataset ,DataLoader
from model import VLDT
from lightning.pytorch.loggers import WandbLogger
import wandb
from torch.nn.functional import pad
from data_process import getDataset
import numpy as np
from observation_process import observationProcessor
from evaluate import evaluate_episode_rtg
import argparse
import torch.nn.functional as F
import random
from sentence_transformers import SentenceTransformer
import os
class MultimodalTransformer(pl.LightningModule):
    """Lightning wrapper that trains a ``VLDT`` decision transformer as an action classifier.

    The wrapped model receives the manual text, states, actions, rewards,
    returns-to-go, timesteps, per-step language embeddings and a validity mask.
    Training minimises cross-entropy between the predicted action logits and the
    recorded discrete actions on the valid (unpadded) steps. At the end of selected
    epochs the module writes a checkpoint and, for the ``messenger`` environment,
    runs evaluation rollouts through ``evaluate_episode_rtg``.

    Args:
        state_mean: state normalisation mean, forwarded to ``VLDT`` and the evaluator.
        state_std: state normalisation std, forwarded to ``VLDT`` and the evaluator.
        variant: dict of hyper-parameters and run settings (see the argparse options).
    """

    def __init__(self, state_mean, state_std, variant):
        super().__init__()
        self.state_mean = state_mean
        self.state_std = state_std
        self.language_model = SentenceTransformer("sentence-transformers/paraphrase-TinyBERT-L6-v2")
        self.stateProcessor = observationProcessor()
        with torch.no_grad():
            # Sentence embedding of "" used as the "no language" placeholder; presumably
            # (1, embed_dim) since encode() of a single string returns one vector — TODO confirm.
            self.encoded_empty_language = torch.tensor(self.language_model.encode("")).unsqueeze(dim=0).to(device='cuda')
        self.variant = variant
        # Decision Transformer with trajectory input.
        self.decision_transformer = VLDT(
            empty_language_embedding=self.encoded_empty_language,
            state_std=state_std,
            state_mean=state_mean,
            state_dim=self.variant["state_dim"],
            act_dim=self.variant["act_dim"],
            hidden_size=self.variant["embed_dim"],
            max_length=self.variant["K"],
            n_layer=self.variant["n_layer"],
            n_head=self.variant["n_head"],
            n_inner=4 * self.variant["embed_dim"],
            activation_function=self.variant["activation_function"],
            n_positions=1024,
            resid_pdrop=self.variant["dropout"],
            attn_pdrop=self.variant["dropout"],
            segment_length=self.variant["K"],
            roberta_encoder_len=self.variant["roberta_encoder_len"],
            category=self.variant["action_category"],
            env_name=self.variant["env_name"])

    def create_cnn(self, cnn_channels):
        """Build a simple Conv-ReLU-MaxPool stack for consecutive pairs of ``cnn_channels``.

        Not used by this module's forward pass; kept as an optional building block.
        """
        layers = []
        for in_channels, out_channels in zip(cnn_channels[:-1], cnn_channels[1:]):
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU())
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*layers)

    def forward(self, text_input, action, reward, state, language, return_to_go, timesteps, mask):
        """Run the wrapped decision transformer and return its output unchanged.

        ``training_step`` unpacks the output as three values, the first being action logits.
        """
        return self.decision_transformer(
            text_input, state, action, reward, return_to_go, timesteps, language, mask,
            self.variant["withLanguage"],
        )

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy action loss averaged over the valid steps of ``batch``."""
        # Defensive: make sure dropout etc. are active even if a hook switched modes.
        self.train()
        text_input = batch['text']
        action = batch["action"]
        reward = batch["reward"]
        state = batch["state"]
        language = batch["language"]
        return_to_go = batch["return_to_go"]
        timesteps = batch["time_stamp"]
        mask = batch["mask"]
        action_pred, _, _ = self.forward(text_input, action, reward, state, language, return_to_go, timesteps, mask)

        logits = action_pred.reshape(-1, self.variant["action_category"])
        targets = action.long().reshape(-1)
        per_step = F.cross_entropy(logits, targets, reduction="none")
        # Segments are left-padded with placeholder action 0 (a real class); only steps
        # whose mask is True contribute, otherwise the model learns to predict padding.
        weights = mask.reshape(-1).to(per_step.dtype)
        loss = (per_step * weights).sum() / weights.sum().clamp(min=1.0)

        self.log("training_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_train_epoch_end(self):
        """Write a checkpoint and, on evaluation epochs, run messenger rollouts.

        A ``state_dict`` is saved every ``log_artifact_interval`` epochs. Every
        ``eval_interval`` epochs (messenger only) the model is evaluated once per
        language probability. The module's previous train/eval mode is restored
        afterwards.
        """
        epoch = self.current_epoch
        if (epoch + 1) % self.variant["log_artifact_interval"] == 0:
            checkpoint_root = self.variant.get("checkpoint_root", "/data/checkpoint/messenger")
            trajectory_dir = f'{checkpoint_root}/{self.variant["artifact_name"]}'
            os.makedirs(trajectory_dir, exist_ok=True)
            torch.save(self.state_dict(), f'{trajectory_dir}/model_{epoch}.pth')

        if (epoch + 1) % self.variant["eval_interval"] != 0:
            return
        if self.variant["env_name"] != "messenger":
            return

        was_training = self.training
        self.eval()
        try:
            # Previously explored: [0, 33, 66, 100].
            for probability in self.variant.get("eval_language_probabilities", [100]):
                self._evaluate_messenger(probability)
        finally:
            self.train(was_training)

    def _evaluate_messenger(self, probability):
        """Run ``num_eval_episodes`` seeded rollouts and print/log return, subgoal and success rates."""
        num_episodes = self.variant["num_eval_episodes"]
        if num_episodes <= 0:
            return
        returns = []
        subgoal_completes = 0
        goal_completes = 0
        with torch.no_grad():
            for seed in range(num_episodes):
                episode_return, _, subgoal_complete, goal_complete, _ = evaluate_episode_rtg(
                    model=self.decision_transformer,
                    language_model=self.language_model,
                    max_ep_len=self.variant["max_ep_len"],
                    scale=self.variant["scale"],
                    state_mean=self.state_mean,
                    state_std=self.state_std,
                    device='cuda',
                    target_return=self.variant["target_return"],
                    mode='normal',
                    newTask=self.variant["newTask"],
                    hasLanguage=self.variant["withLanguage"],
                    empty_language=self.encoded_empty_language,
                    realLang=False,
                    eval_mode="validation",
                    eval_type="yes",
                    probability_threshold=probability,
                    seed=seed,
                )
                subgoal_completes += subgoal_complete
                goal_completes += goal_complete
                returns.append(episode_return)

        ave_return = sum(returns) / num_episodes
        subgoal_rate = subgoal_completes / num_episodes
        goal_rate = goal_completes / num_episodes
        print(f"Language Probability {probability}, ave return: ", ave_return, " subgoal rate: ", subgoal_rate, " goal rate: ", goal_rate)
        self.log(f"Evaluation_ave_return with language probability {probability}", ave_return, logger=True)
        self.log(f"Subgoal rate with language probability {probability}", subgoal_rate, logger=True)
        self.log(f"Success rate with language probability {probability}", goal_rate, logger=True)

    def configure_optimizers(self):
        """Adam over trainable parameters with per-epoch cosine annealing.

        ``T_max`` is the number of epochs of one cosine half-period and ``eta_min``
        the learning-rate floor; both come from ``variant``.
        """
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=self.variant["learning_rate"],
            weight_decay=self.variant["weight_decay"],
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.variant["T_max"],
            eta_min=self.variant["eta_min"],
        )
        # Lightning scheduler configuration: step the scheduler once per epoch.
        scheduler_config = {
            'scheduler': scheduler,
            'interval': 'epoch',
            'frequency': 1,
        }
        return [optimizer], [scheduler_config]
        
class MultimodalDataset(Dataset):
    """Serve random, fixed-length, front-padded segments of trajectories.

    Each item is a window of at most ``variant["K"]`` consecutive steps from one
    trajectory (the final step of a trajectory is never included). Shorter windows
    are left-padded with zeros for numeric fields and with the empty-string sentence
    embedding for language. Real steps therefore always sit at the end of the
    segment, and ``mask`` is True exactly on those real steps.

    Args:
        data: indexable collection of trajectory dicts.
        state_mean: subtracted from states.
        state_std: states are divided by it after centring.
        variant: run settings; reads ``K``, ``scale``, ``withLanguage``,
            ``LanguageType``, ``env_name``, ``diversity_threshold`` and, optionally,
            ``language_mask_ratio``.
    """

    def __init__(self, data, state_mean, state_std, variant):
        self.trajectories = data
        self.segmentLength = variant["K"]
        self.state_mean = state_mean
        self.state_std = state_std
        self.variant = variant
        # Device on which every returned tensor is created.
        self.device = 'cuda'
        self.language_model = SentenceTransformer("sentence-transformers/paraphrase-TinyBERT-L6-v2")
        with torch.no_grad():
            # Empty-string embedding with two leading singleton axes so it can be
            # repeated along time as (T, 1, embed_dim).
            self.encoded_empty_language = torch.tensor(self.language_model.encode("")).unsqueeze(dim=0).unsqueeze(dim=0).to(device=self.device)

    def __len__(self):
        return len(self.trajectories)

    def __getitem__(self, idx):
        """Return one training segment as a dict of tensors on ``self.device``.

        Keys:
            ``text``: the trajectory's ``encoded_manual`` with an axis inserted at dim 1.
            ``state``: normalised states, (K, state_dim).
            ``reward``: (K, 1), divided by ``variant["scale"]``.
            ``return_to_go``: (K, 1), divided by ``variant["scale"]``.
            ``action``: (K, 1) float.
            ``language``: per-step embeddings, presumably (K, embed_dim).
            ``time_stamp``: (K,) long.
            ``mask``: (K,) bool, True on real (unpadded) steps.
        """
        trajectory = self.trajectories[idx]
        device = self.device
        seg_cap = self.segmentLength

        # Only the length is needed; do not copy the whole state sequence to the device.
        max_length = len(trajectory["state"])
        if max_length <= seg_cap:
            starting_point = 0
        else:
            # Same draw (same device/generator) as before, converted to a plain int so
            # slicing and padding arithmetic do not go through 0-dim device tensors.
            starting_point = int(torch.randint(0, max_length - 2, (1,), device=device)[0])
        # The last step of the trajectory is always excluded.
        ending_point = min(starting_point + seg_cap, max_length - 1)
        pad_length = seg_cap - (ending_point - starting_point)
        scale = self.variant["scale"]

        text = trajectory["encoded_manual"]
        rewards = self._padded_column(trajectory["reward"], starting_point, ending_point, pad_length) / scale
        actions = self._padded_column(trajectory["action"], starting_point, ending_point, pad_length)
        return_to_go = self._padded_column(trajectory["return_to_go"], starting_point, ending_point, pad_length) / scale

        # Padding is prepended, so the real steps are the last (K - pad_length) slots.
        mask = torch.zeros(seg_cap, device=device, dtype=torch.bool)
        mask[pad_length:] = True

        states = torch.tensor(trajectory["state"][starting_point:ending_point], dtype=torch.float32).to(device=device)
        padding = torch.zeros(pad_length, states.shape[1], device=device, dtype=torch.float32)
        states = torch.cat((padding, states), dim=0)
        states = (states - self.state_mean) / self.state_std

        language = self._language_segment(trajectory, starting_point, ending_point, pad_length)

        time_stamps = torch.arange(starting_point, ending_point, device=device).long()
        time_stamps = pad(time_stamps, (pad_length, 0), "constant", 0)
        return {
            'text': text.to(device).unsqueeze(1),
            'state': states,
            'reward': rewards,
            'action': actions,
            'return_to_go': return_to_go,
            'language': language.squeeze(1).to(device),
            'time_stamp': time_stamps,
            'mask': mask,
        }

    def _padded_column(self, values, start, end, pad_length):
        """Return ``values[start:end]`` as a float (T, 1) tensor left-padded with ``pad_length`` zero rows."""
        column = torch.tensor(values[start:end], device=self.device).float().unsqueeze(1)
        if pad_length > 0:
            column = pad(column, (0, 0, pad_length, 0), "constant", 0)
        return column

    def _language_key(self):
        """Return the trajectory key holding per-step language embeddings, or None if language is unused.

        Raises:
            ValueError: if ``LanguageType`` contains "r" but neither "h" nor "f".
        """
        if not self.variant["withLanguage"]:
            return None
        language_type = self.variant["LanguageType"]
        has_h = "h" in language_type
        has_f = "f" in language_type
        if "r" in language_type:
            if has_h and has_f:
                if self.variant["env_name"] == "metaworld" or self.variant["diversity_threshold"] == 1000:
                    return "rhf_embedding"
                return "hf_embedding_" + str(self.variant["diversity_threshold"])
            if has_h:
                return "rh_embedding"
            if has_f:
                return "rf_embedding"
            raise ValueError(f"LanguageType {language_type!r} contains 'r' but neither 'h' nor 'f'")
        if has_h and has_f:
            return "hf_embedding"
        if has_f:
            return "f_embedding_100"
        if has_h:
            return "h_embedding_100"
        return None

    def _language_segment(self, trajectory, start, end, pad_length):
        """Return language embeddings for steps ``[start, end)``, left-padded to K with the empty embedding.

        When language is disabled, K copies of the empty embedding are returned instead.
        """
        empty = self.encoded_empty_language
        key = self._language_key()
        if key is None:
            return empty.repeat(self.segmentLength, 1, 1).to(self.device)

        language = trajectory[key][start:end].to(self.device)
        total = language.size(0)
        # Fraction of real steps whose language is replaced by the empty embedding
        # (0 disables the replacement; 2/3 was tried previously).
        ratio = self.variant.get("language_mask_ratio", 0)
        num_to_mask = int(total * ratio)
        mask_indices = torch.randperm(total)[:num_to_mask].to(self.device)
        drop = torch.zeros(total, dtype=torch.bool).to(self.device)
        drop[mask_indices] = True
        drop = drop.unsqueeze(-1).unsqueeze(-1).expand_as(language)
        language = torch.where(drop, empty.repeat(total, 1, 1), language)
        if pad_length > 0:
            language = torch.cat((empty.repeat(pad_length, 1, 1), language), dim=0)
        return language
def experiment(variant):
    """Build the dataset, model and Lightning trainer described by ``variant`` and train.

    Args:
        variant: dict of run settings (typically ``vars(args)`` from the CLI). For the
            ``messenger`` environment it is updated in place with ``state_dim``,
            ``act_dim``, ``scale`` and ``target_return``.
    """
    # set_start_method raises RuntimeError if a start method was already set in this
    # process; force=True makes repeated calls safe.
    torch.multiprocessing.set_start_method('spawn', force=True)
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # Only effective if CUDA has not been initialised yet in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = variant["gpu_id"]
    torch.cuda.empty_cache()

    if variant["log"]:
        wandb.login()
        wandb_logger = WandbLogger(project="Messenger-MetaWorld-Environment", name=variant["project_name"], log_model=True)
    else:
        wandb_logger = None

    if variant["env_name"] == "messenger":
        variant["state_dim"] = 18
        variant["act_dim"] = 1
        variant["scale"] = 100
        variant["target_return"] = 200

    # Load data.
    data = getDataset(variant["data_size"], trajectory_dir=variant["trajectory_dir"], env=variant["env_name"])
    # NOTE(review): the Messenger normalisation statistics are loaded regardless of env_name.
    state_mean = torch.load("./messenger_state_mean.pth")
    state_std = torch.load("./messenger_state_std.pth")
    model = MultimodalTransformer(state_mean=state_mean, state_std=state_std, variant=variant).to("cuda")

    dataset = MultimodalDataset(data, state_mean, state_std, variant)
    train_loader = DataLoader(dataset, batch_size=variant.get("batch_size"), shuffle=True)

    # Optional warm start from an earlier stage, e.g.:
    # model.load_state_dict(torch.load("/data/checkpoint/messenger/original_harder_clear_stage2_rhlang/model_9.pth"))
    # model.load_state_dict(torch.load("/data/checkpoint/messenger/original_harder_clear_stage2_rflang/model_4.pth"))
    # model.load_state_dict(torch.load("/data/checkpoint/messenger/original_harder_clear_stage1_nolang/model_219.pth"))
    # model.load_state_dict(torch.load("/data/checkpoint/messenger/original_harder_clear_stage2_rhflang/model_9.pth"))
    trainer = pl.Trainer(
        max_epochs=variant.get("max_epoch"),
        accelerator="gpu",
        devices=1,
        logger=wandb_logger,
        log_every_n_steps=variant.get("log_every_n_steps"),
    )

    try:
        trainer.fit(model, train_loader)
    finally:
        # Close the W&B run even if training fails (a no-op when no run is active).
        wandb.finish()

def set_seed(seed=42):
    """Seed every random number generator used by training, for reproducibility.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU, the
    current CUDA device, and all CUDA devices for multi-GPU runs).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Current GPU first, then every GPU (relevant when using multiple GPUs).
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)


def str2bool(value):
    """Parse a command-line boolean such as "true", "False", "1" or "0".

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False") as True, so boolean flags could never be switched off from the CLI.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    set_seed(42)
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_name", type=str, default='messenger')  # Must check
    parser.add_argument("--gpu_id", type=str, default='0')
    parser.add_argument("--log", type=str2bool, default=True)
    # NOTE(review): type=list splits a single string argument into characters.
    parser.add_argument("--train_tasks", type=list, default=[])
    parser.add_argument("--realLang", type=str2bool, default=False)
    parser.add_argument("--trajectory_dir", type=str, default='/data/messenger_final_dataset')  # folder
    parser.add_argument("--project_name", type=str, default='msgr hf masking original task newLang')
    parser.add_argument("--local", type=str, default='msgr_hf_MDT.pth')
    parser.add_argument("--artifact_name", type=str, default='original_harder_clear_stage1_f_lang_rerun')  # folder
    parser.add_argument("--max_epoch", type=int, default=500)
    parser.add_argument("--learning_rate", "-lr", type=float, default=1e-3)
    parser.add_argument("--T_max", type=float, default=500)
    parser.add_argument("--eta_min", type=float, default=1e-5)
    parser.add_argument("--withLanguage", type=str2bool, default=True)
    parser.add_argument("--LanguageType", type=str, default="f")  # language
    parser.add_argument("--diversity_threshold", type=int, default=1000)  # diversity
    parser.add_argument("--newTask", type=str2bool, default=False)
    parser.add_argument("--eval_interval", type=int, default=5)
    parser.add_argument("--num_eval_episodes", type=int, default=50)
    parser.add_argument("--log_artifact_interval", type=int, default=1)
    parser.add_argument("--data_size", type=int, default=20000)
    parser.add_argument("--log_every_n_steps", type=int, default=25)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--roberta_encoder_len", type=int, default=768)
    parser.add_argument("--embed_dim", type=int, default=256)
    parser.add_argument("--n_layer", type=int, default=5)
    parser.add_argument("--n_head", type=int, default=2)
    parser.add_argument("--activation_function", type=str, default="relu")
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--K", type=int, default=6)  # max length of trajectory sent to transformer
    parser.add_argument("--max_ep_len", type=int, default=50)  # max length of each episode during evaluation
    parser.add_argument("--state_dim", type=int, default=10)
    parser.add_argument("--act_dim", type=int, default=4)
    parser.add_argument("--weight_decay", "-wd", type=float, default=1e-5)
    parser.add_argument("--log_to_wandb", "-w", type=str2bool, default=True)
    parser.add_argument("--env", type=str, default="msgr-train-v3")
    # "normal" for the standard setting, "delayed" for sparse rewards.
    parser.add_argument("--mode", type=str, default="normal")
    parser.add_argument("--action_category", type=int, default=5)
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()
    experiment(variant=vars(args))
    
    