import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from tqdm import tqdm
from nlp_utils.pytorch_bienc.losses import OnlineContrastiveLoss
import pandas as pd
from sklearn.model_selection import train_test_split
import wandb
import math
import faiss
from typing import List, Tuple
import bitsandbytes as bnb
import numpy as np
from sentence_transformers import SentenceTransformer
import os
import pickle
from  datasets  import Dataset

class SentencePooling(nn.Module):
    """Collapse per-token embeddings into a single vector per sequence.

    Only the ``"mean"`` strategy is implemented: the attention-mask-weighted
    average of the token embeddings. Any other strategy name is rejected when
    the module is called.
    """

    def __init__(self, pooling_strategy):
        super().__init__()
        self.pooling_strategy = pooling_strategy

    def forward(self, outputs, attention_mask):
        """Pool ``outputs[0]`` over dimension 1 using ``attention_mask``.

        Args:
            outputs: Indexable model output whose first element holds the
                token embeddings.
            attention_mask: Mask that is unsqueezed on the last axis and
                expanded to the shape of the token embeddings.

        Returns:
            The masked mean of the token embeddings.

        Raises:
            ValueError: If the configured pooling strategy is not ``"mean"``.
        """
        hidden_states = outputs[0]

        # Broadcast the mask over the embedding axis so masked-out positions
        # contribute nothing to the sum.
        mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        masked_total = (hidden_states * mask).sum(dim=1)

        # Clamp so a fully masked sequence does not divide by zero.
        token_counts = mask.sum(dim=1).clamp(min=1e-9)

        if self.pooling_strategy != "mean":
            raise ValueError(f"Invalid pooling strategy: {self.pooling_strategy}")
        return masked_total / token_counts




def tokenize_and_encode(
    sentences1: List[str],
    sentences2: List[str],
    labels: List[int],
    tokenizer: "AutoTokenizer",
    max_length: int = 128,
    padding: str = "max_length",
    truncation: bool = True

) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Tokenize both sides of a batch of sentence pairs.

    Args:
        sentences1: First sentence of every pair.
        sentences2: Second sentence of every pair.
        labels: One label per pair; a list or an existing tensor.
        tokenizer: Object exposing ``batch_encode_plus`` that returns
            ``"input_ids"`` and ``"attention_mask"`` entries.
        max_length: Forwarded to the tokenizer.
        padding: Forwarded to the tokenizer.
        truncation: Forwarded to the tokenizer.

    Returns:
        ``((input_ids1, attention_masks1), (input_ids2, attention_masks2),
        labels)`` where ``labels`` is always a ``torch.Tensor``.
    """

    def _encode(sentences):
        # Both sides must be encoded with identical settings.
        encoded = tokenizer.batch_encode_plus(
            sentences,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors="pt"
        )
        return encoded["input_ids"], encoded["attention_mask"]

    features1 = _encode(sentences1)
    features2 = _encode(sentences2)

    # Callers move the labels with ``labels.to(device)``, so hand back a
    # tensor as the return annotation promises. ``as_tensor`` returns an
    # existing tensor unchanged instead of copying it.
    labels = torch.as_tensor(labels)

    return features1, features2, labels




def _best_accuracy_threshold(scores, labels):
    """Sweep distance thresholds in ascending order; return (accuracy, threshold)."""
    order = np.argsort(scores, kind="stable")
    sorted_scores = scores[order].tolist()
    sorted_labels = labels[order].tolist()
    total = len(sorted_labels)

    max_acc = 0
    best_threshold = -1

    positive_so_far = 0
    remaining_negatives = int(np.sum(labels == 0))

    # Candidate thresholds sit halfway between consecutive scores, so the
    # last row has no successor and is never a cut point.
    for i in range(total - 1):
        if sorted_labels[i] == 1:
            positive_so_far += 1
        else:
            remaining_negatives -= 1

        acc = (positive_so_far + remaining_negatives) / total
        if acc > max_acc:
            max_acc = acc
            best_threshold = (sorted_scores[i] + sorted_scores[i + 1]) / 2

    return max_acc, best_threshold


def _best_f1_threshold(scores, labels):
    """Sweep distance thresholds in ascending order; return (f1, precision, recall, threshold)."""
    order = np.argsort(scores, kind="stable")
    sorted_scores = scores[order].tolist()
    sorted_labels = labels[order].tolist()

    best_f1 = best_precision = best_recall = 0
    threshold = 0
    nextract = 0
    ncorrect = 0
    total_num_duplicates = int(np.sum(labels))

    for i in range(len(sorted_labels) - 1):
        nextract += 1
        if sorted_labels[i] == 1:
            ncorrect += 1

        # Precision/recall are only defined once a positive has been seen,
        # which also guarantees total_num_duplicates is non-zero here.
        if ncorrect > 0:
            precision = ncorrect / nextract
            recall = ncorrect / total_num_duplicates
            f1 = 2 * precision * recall / (precision + recall)
            if f1 > best_f1:
                best_f1 = f1
                best_precision = precision
                best_recall = recall
                threshold = (sorted_scores[i] + sorted_scores[i + 1]) / 2

    return best_f1, best_precision, best_recall, threshold


def BinaryClassificationEvaluate(model,tokenizer,dataloader,pooling_module,device,eval_metric="f1",max_seq_len=128,padding="max_length",truncation=True,embedding_dim=768,batch_size=128,sample_eval_loader_steps=None):
    """Embed every pair in ``dataloader`` and find the best cosine-distance threshold.

    Pairs are scored with ``1 - cosine_similarity`` (smaller means more
    similar). The pairs are sorted by that distance and every midpoint
    between consecutive distances is tried as a decision threshold.

    Args:
        model, tokenizer, dataloader, pooling_module, device: Forwarded to
            ``embed_data_loader``.
        eval_metric: ``"acc"`` or ``"f1"``; selects what is optimised.
        max_seq_len, padding, truncation, embedding_dim, batch_size,
        sample_eval_loader_steps: Forwarded to ``embed_data_loader``.

    Returns:
        For ``"acc"``: ``{"acc", "threshold"}``.
        For ``"f1"``: ``{"f1", "precision", "recall", "threshold"}``.
        Values are plain Python numbers.

    Raises:
        ValueError: If ``eval_metric`` is neither ``"acc"`` nor ``"f1"``.
    """
    # Reject a bad metric name before paying for the whole embedding pass.
    if eval_metric not in ("acc", "f1"):
        raise ValueError(f"Invalid eval metric: {eval_metric}, only acc and f1 are supported")

    embeddings1,embeddings2,labels=embed_data_loader(model,tokenizer,dataloader,pooling_module,device,max_seq_len=max_seq_len,padding=padding,truncation=truncation,embedding_dim=embedding_dim,batch_size=batch_size,sample_eval_loader_steps=sample_eval_loader_steps)

    assert len(embeddings1) == len(embeddings2) == len(labels)

    # Cosine distance; cosine_similarity normalises internally.
    distances = 1 - nn.functional.cosine_similarity(embeddings1, embeddings2, dim=1)
    del embeddings1, embeddings2

    # Work on NumPy / plain floats: sorting Python tuples of 0-d tensors is
    # slow and would leak tensor scalars into the returned metrics.
    scores = distances.detach().cpu().numpy()
    labels = np.asarray(labels)
    assert len(scores) == len(labels)

    if eval_metric == "acc":
        max_acc, best_threshold = _best_accuracy_threshold(scores, labels)
        print("Max Accuracy: {:.2f}%".format(max_acc * 100))
        print("Best Threshold: {:.4f}".format(best_threshold))
        return {"acc": max_acc, "threshold": best_threshold}

    best_f1, best_precision, best_recall, threshold = _best_f1_threshold(scores, labels)
    return {"f1": best_f1, "precision": best_precision, "recall": best_recall, "threshold": threshold}





def normalize_tensor_inplace(tensor,p=2,dim=1,keepdim=True):
    """Scale ``tensor`` in place so its ``p``-norm along ``dim`` is 1.

    Args:
        tensor: Floating-point tensor, modified in place.
        p: Order of the norm.
        dim: Dimension along which the norm is taken.
        keepdim: Accepted for backward compatibility only. The norm is always
            computed with ``keepdim=True`` because the division needs the
            reduced axis to broadcast against ``tensor``.

    Returns:
        None. The input tensor is updated in place.
    """
    norms = tensor.norm(p=p, dim=dim, keepdim=True)
    # Clamp so an all-zero slice stays zero instead of turning into NaN.
    tensor.div_(norms.clamp_min(1e-12))



def embed_data_loader(model,tokenizer,dataloader,pooling_module,device,max_seq_len=128,padding="max_length",truncation=True,embedding_dim=768,batch_size=128,sample_eval_loader_steps=None):
    """Embed both sentences of every pair yielded by ``dataloader``.

    Each batch is expected to be a mapping with ``"sentence_1"``,
    ``"sentence_2"`` and ``"labels"`` entries. Embeddings are accumulated in
    pre-allocated CPU buffers and L2-normalised before being returned.

    Args:
        model: Encoder called with ``input_ids`` and ``attention_mask``.
        tokenizer: Forwarded to ``tokenize_and_encode``.
        dataloader: Iterable of batches; ``dataloader.dataset`` must support
            ``len``.
        pooling_module: Callable mapping ``(model_output, attention_mask)``
            to one embedding per sentence.
        device: Device the token tensors are moved to.
        max_seq_len, padding, truncation: Tokenizer settings.
        embedding_dim: Width of the pooled embeddings.
        batch_size: Only used to size the buffers when
            ``sample_eval_loader_steps`` is set.
        sample_eval_loader_steps: If not None, stop after this many batches.

    Returns:
        ``(embeddings_1, embeddings_2, labels)``: two CPU float tensors and a
        NumPy integer array, all with one row per embedded pair.
    """
    dataset_len = len(dataloader.dataset)
    if sample_eval_loader_steps is not None:
        len_data = min(dataset_len, sample_eval_loader_steps * batch_size)
    else:
        len_data = dataset_len
    print("len_data: {}".format(len_data))

    accum_embedding_1 = torch.zeros(len_data, embedding_dim)
    accum_embedding_2 = torch.zeros(len_data, embedding_dim)
    labels_tensor = torch.zeros(len_data, dtype=torch.long)
    filled = 0

    # Remember the current mode: this function is also called in the middle
    # of a training epoch, and leaving the model in eval mode would silently
    # disable dropout for the rest of that epoch.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for step, batch in enumerate(tqdm(dataloader, desc="Evaluating"), start=1):
                current_batch_size = len(batch["sentence_1"])

                features1, features2, labels = tokenize_and_encode(batch["sentence_1"],batch["sentence_2"],batch["labels"],tokenizer,max_length=max_seq_len,padding=padding,truncation=truncation)
                input_ids1, attention_masks1 = features1
                input_ids2, attention_masks2 = features2

                input_ids1 = input_ids1.to(device)
                attention_masks1 = attention_masks1.to(device)
                input_ids2 = input_ids2.to(device)
                attention_masks2 = attention_masks2.to(device)

                outputs1 = model(input_ids=input_ids1, attention_mask=attention_masks1)
                outputs2 = model(input_ids=input_ids2, attention_mask=attention_masks2)
                embeddings1 = pooling_module(outputs1, attention_masks1)
                embeddings2 = pooling_module(outputs2, attention_masks2)

                stop = filled + current_batch_size
                accum_embedding_1[filled:stop].copy_(embeddings1)
                accum_embedding_2[filled:stop].copy_(embeddings2)
                # Labels are only stored on the CPU, so they never need to
                # visit the device.
                labels_tensor[filled:stop].copy_(torch.as_tensor(labels))
                filled = stop

                if sample_eval_loader_steps is not None and step == sample_eval_loader_steps:
                    break
    finally:
        model.train(was_training)

    # Drop rows that were never written (for example when the loader yields
    # fewer samples than the buffers were sized for); all-zero rows would
    # otherwise be scored as if they were real pairs.
    accum_embedding_1 = accum_embedding_1[:filled]
    accum_embedding_2 = accum_embedding_2[:filled]
    labels_tensor = labels_tensor[:filled]

    normalize_tensor_inplace(accum_embedding_1,p=2,dim=1)
    normalize_tensor_inplace(accum_embedding_2,p=2,dim=1)
    labels = labels_tensor.cpu().numpy()

    return accum_embedding_1,accum_embedding_2,labels

def save_model_and_tokenizer(model, tokenizer, output_dir):
    """Write the model and the tokenizer into ``output_dir``.

    The directory is created when it does not exist yet. A model wrapped in
    ``nn.DataParallel`` is unwrapped first so that the underlying module is
    the one being saved.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # DataParallel keeps the real model behind ``.module``.
    unwrapped = model.module if isinstance(model, nn.DataParallel) else model
    destination = os.path.join(output_dir)

    unwrapped.save_pretrained(destination)
    tokenizer.save_pretrained(destination)
    print("Model and tokenizer saved in {}".format(output_dir))

def save_model_config(model, output_dir):
    """Write the model's config into ``output_dir``.

    The directory is created when it does not exist yet. A model wrapped in
    ``nn.DataParallel`` is unwrapped first to reach the config.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # DataParallel keeps the real model (and its config) behind ``.module``.
    unwrapped = model.module if isinstance(model, nn.DataParallel) else model
    unwrapped.config.save_pretrained(os.path.join(output_dir))
    print("Model config saved in {}".format(output_dir))


def add_special_tokens(model,tokenizer,special_tokens):
    """Register extra special tokens and grow the embedding matrix to match.

    Args:
        model: Object exposing ``resize_token_embeddings``.
        tokenizer: Object exposing ``add_tokens`` and ``__len__``.
        special_tokens: Mapping whose values are the token strings to add.

    Returns:
        The same ``(model, tokenizer)`` pair, updated in place.
    """
    new_tokens = list(special_tokens.values())
    tokenizer.add_tokens(new_tokens, special_tokens=True)

    # The embedding table must cover every id the tokenizer can now emit.
    vocab_size = len(tokenizer)
    model.resize_token_embeddings(vocab_size)
    return model, tokenizer

class LazyLoadingDataset(torch.utils.data.Dataset):
    """Map-style dataset over parallel lists of sentence pairs and labels.

    The base class is PyTorch's ``Dataset`` (spelled out in full because the
    bare name ``Dataset`` in this module refers to the Hugging Face class).
    The Hugging Face class is not initialised by this constructor, and its
    batched ``__getitems__`` hook would be picked up by ``DataLoader`` and
    call ``__getitem__`` with a list of indices, which this class does not
    support.
    """

    def __init__(self, data_dict):
        """Store the ``sentence_1``, ``sentence_2`` and ``labels`` columns of ``data_dict``."""
        self.sentences_1 = data_dict['sentence_1']
        self.sentences_2 = data_dict['sentence_2']
        self.labels = data_dict['labels']

    def __len__(self):
        """Return the number of sentence pairs."""
        return len(self.sentences_1)

    def __getitem__(self, idx):
        """Return the pair and label at ``idx`` as a dict."""
        # No per-sample printing here: it would flood stdout and slow down
        # every pass over the data.
        return {
            'sentence_1': self.sentences_1[idx],
            'sentence_2': self.sentences_2[idx],
            'labels': self.labels[idx]
        }

def custom_collate_fn(batch):
    """Stack the per-sample tensors of ``batch`` into one tensor per field.

    Args:
        batch: List of sample dicts holding ``sentence_1``, ``sentence_2``
            and ``labels`` entries that ``torch.stack`` can combine.

    Returns:
        Dict with the same three keys, each mapped to the stacked tensor.
    """
    fields = ('sentence_1', 'sentence_2', 'labels')
    return {
        field: torch.stack([sample[field] for sample in batch])
        for field in fields
    }

def train_biencoder_custom(
    train_data: dict = None,
    eval_data: dict = None,
    test_data: dict = None,
    model_name: str = "sentence-transformers/all-mpnet-base-v2",
    pooling_strategy: str = "mean",
    shuffle_train: bool = True,
    loss_function=OnlineContrastiveLoss,
    margin: float = 0.5,
    batch_size: int = 32,
    num_epochs: int = 10,
    learning_rate: float = 2e-5,
    warmup_steps: int = 1000,
    warmup_perc: float = 0.1,
    model_save_path: str = "output",
    wandb_log: bool = False,
    wandb_project: str = None,
    wandb_run_name: str = None,
    inter_eval_steps: int = 10,
    max_seq_len: int= 128,
    eval_metric="f1",
    special_tokens=None,
    sample_eval_loader_steps=None,
    log_epoch_step=True,
    start_epoch_step=0,
):
    """Fine-tune a bi-encoder on labelled sentence pairs.

    Args:
        train_data, eval_data: Required column dicts with ``sentence_1``,
            ``sentence_2`` and ``labels`` lists.
        test_data: Optional dict of the same layout; evaluated once at the end.
        model_name: Checkpoint used for both the encoder and its tokenizer.
        pooling_strategy: Passed to ``SentencePooling``.
        shuffle_train: Whether the training loader shuffles.
        loss_function: Loss class, instantiated as ``loss_function(margin=margin)``.
        margin: Margin given to the loss.
        batch_size, num_epochs, learning_rate: Usual training settings.
        warmup_steps: Warm-up length, used only when ``warmup_perc`` is None.
        warmup_perc: Fraction of all training steps used for warm-up.
        model_save_path: Directory for the best checkpoint and the step log;
            per-epoch checkpoints go to ``<model_save_path>_epoch_<n>``.
        wandb_log: When True a W&B run is started and metrics are logged;
            when False nothing is sent to W&B.
        wandb_project, wandb_run_name: Passed to ``wandb.init``.
        inter_eval_steps: Evaluate every this many steps within an epoch
            (disabled when not positive).
        max_seq_len: Tokenizer max length.
        eval_metric: ``"acc"`` or ``"f1"``; also the model-selection score.
        special_tokens: Optional mapping of extra tokens to add.
        sample_eval_loader_steps: If set, evaluations only use this many
            batches (and the eval loader is shuffled).
        log_epoch_step: Write the current step to ``epoch_step.txt``.
        start_epoch_step: Number of leading batches to skip.

    Returns:
        The test score when ``test_data`` is given, otherwise None.

    Raises:
        ValueError: If ``train_data`` or ``eval_data`` is missing.
    """
    if train_data is None or eval_data is None:
        raise ValueError("train_data and eval_data are both required")

    def _log(metrics):
        # A W&B run only exists when wandb_log is set (see wandb.init below);
        # logging without a run raises.
        if wandb_log:
            wandb.log(metrics)

    # Wrap the column dicts as Hugging Face datasets.
    print("Extracting train data")
    train_dataset = Dataset.from_dict(train_data)
    del train_data  # drops only this function's reference

    eval_dataset = Dataset.from_dict(eval_data)
    del eval_data

    has_test = test_data is not None
    if has_test:
        test_dataset = Dataset.from_dict(test_data)

    print("Making dataloaders")
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_train)
    # A sub-sampled evaluation is shuffled so it is not always the same head
    # of the eval set.
    eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=sample_eval_loader_steps is not None)
    if has_test:
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # The tokenizer must come from the same checkpoint as the encoder.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModel.from_pretrained(model_name)

    if special_tokens is not None:
        print("Adding Special Tokens")
        model, tokenizer = add_special_tokens(model, tokenizer, special_tokens)

    # Mean pooling keeps the encoder's hidden width, so size the evaluation
    # buffers from the config (read before any DataParallel wrapping) and
    # fall back to the previously hard-coded 768.
    embedding_dim = getattr(model.config, "hidden_size", 768)

    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        print("Using {} GPUs".format(num_gpus))
        model = nn.DataParallel(model)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
    total_steps = len(train_dataloader) * num_epochs
    if warmup_perc is None:
        num_warmup_steps = warmup_steps
    else:
        num_warmup_steps = math.ceil(total_steps * warmup_perc)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps)
    pooling_module = SentencePooling(pooling_strategy)
    loss_function = loss_function(margin=margin)

    if wandb_log:
        wandb.init(project=wandb_project, name=wandb_run_name)

    def _evaluate(dataloader):
        # Returns (score for eval_metric, decision threshold).
        eval_output = BinaryClassificationEvaluate(model, tokenizer, dataloader, pooling_module, device,eval_metric=eval_metric,max_seq_len=max_seq_len,padding="max_length",truncation=True,embedding_dim=embedding_dim,batch_size=batch_size,sample_eval_loader_steps=sample_eval_loader_steps)
        return eval_output[eval_metric], eval_output["threshold"]

    # Zero-shot score before any training. Scores are fractions in [0, 1],
    # so they are scaled before being printed with a percent sign.
    eval_accuracy, zs_thresh = _evaluate(eval_dataloader)
    print("Zero Shot Evaluation Accuracy = {:.2f}%".format(100 * eval_accuracy))
    _log({"zs_thresh": zs_thresh})
    _log({"zs_accuracy": eval_accuracy})

    # NOTE(review): best-score tracking starts at 0 rather than at the
    # zero-shot score, so the first non-zero evaluation always saves a
    # checkpoint. Kept as found; confirm this is intended.
    best_accuracy = 0.0

    os.makedirs(model_save_path, exist_ok=True)
    step_log_path = os.path.join(model_save_path, "epoch_step.txt")

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch + 1, num_epochs))
        model.train()
        total_loss = 0.0
        total_samples = 0

        for epoch_step, batch in enumerate(tqdm(train_dataloader, desc="Epoch {}".format(epoch + 1)), start=1):
            # NOTE(review): the skip applies to every epoch, not only the
            # first one after a resume. Kept as found; confirm intent.
            if epoch_step <= start_epoch_step:
                continue

            if log_epoch_step:
                with open(step_log_path, "w") as f:
                    f.write(str(epoch_step))

            if inter_eval_steps > 0 and epoch_step % inter_eval_steps == 0:
                eval_accuracy, zs_thresh = _evaluate(eval_dataloader)
                # Evaluation switches the model to eval mode; switch back so
                # dropout stays active for the remaining training steps.
                model.train()
                print(f"Using eval metric: {eval_metric}")
                print("Interemdiate Eval = {:.2f}%".format(100 * eval_accuracy))
                _log({"inter_eval/threshold": zs_thresh})
                _log({"inter_eval/accuracy": eval_accuracy})
                if eval_accuracy > best_accuracy:
                    best_accuracy = eval_accuracy
                    _log({"inter_eval/best_accuracy_eval": best_accuracy})
                    save_model_and_tokenizer(model, tokenizer, model_save_path)
                    save_model_config(model, model_save_path)

            features1, features2, labels = tokenize_and_encode(batch["sentence_1"],batch["sentence_2"],batch["labels"],tokenizer,max_length=max_seq_len,padding="max_length",truncation=True)
            input_ids1, attention_masks1 = features1
            input_ids2, attention_masks2 = features2
            batch_len = input_ids1.size(0)

            optimizer.zero_grad()

            # Each side is moved to the device only when needed and its
            # inputs are released right after pooling.
            input_ids1 = input_ids1.to(device)
            attention_masks1 = attention_masks1.to(device)
            outputs1 = model(input_ids=input_ids1, attention_mask=attention_masks1)
            embeddings1 = pooling_module(outputs1, attention_masks1)
            del outputs1, attention_masks1, input_ids1

            input_ids2 = input_ids2.to(device)
            attention_masks2 = attention_masks2.to(device)
            outputs2 = model(input_ids=input_ids2, attention_mask=attention_masks2)
            embeddings2 = pooling_module(outputs2, attention_masks2)
            del outputs2, attention_masks2, input_ids2

            labels = torch.as_tensor(labels).to(device)

            loss = loss_function(embeddings1, embeddings2, labels)
            loss_value = loss.item()
            _log({"loss": loss_value})
            loss.backward()
            optimizer.step()
            scheduler.step()

            total_loss += loss_value * batch_len
            total_samples += batch_len

            del labels, embeddings1, embeddings2
            torch.cuda.empty_cache()

        # Every batch may have been skipped (or the loader may be empty).
        average_loss = total_loss / total_samples if total_samples else float("nan")
        print("Epoch {}: Loss = {:.4f}".format(epoch + 1, average_loss))

        eval_accuracy, zs_thresh = _evaluate(eval_dataloader)
        print("Epoch {}: Evaluation Accuracy = {:.2f}%".format(epoch + 1, 100 * eval_accuracy))

        if eval_accuracy > best_accuracy:
            best_accuracy = eval_accuracy
            save_model_and_tokenizer(model, tokenizer, model_save_path)
            save_model_config(model, model_save_path)

        _log({"epoch": epoch + 1, "train_loss": average_loss, "eval_accuracy": eval_accuracy})
        _log({"eval_thresh": zs_thresh})

        # A checkpoint is kept for every epoch regardless of its score.
        print("Saving model... for epoch {}".format(epoch+1))
        save_model_and_tokenizer(model, tokenizer, model_save_path + f"_epoch_{epoch+1}")

    if not has_test:
        return None

    # The final model (not the best checkpoint) is scored on the test set.
    test_accuracy, zs_thresh = _evaluate(test_dataloader)
    print("Testing Accuracy = {:.2f}%".format(100 * test_accuracy))
    _log({"test_accuracy": test_accuracy})
    return test_accuracy


def sentence_encode(sentence,model,pooling_module,tokenizer,device,max_length=400):
    """Embed a single sentence and L2-normalise the result.

    Args:
        sentence: Text passed to ``tokenizer.encode_plus``.
        model: Encoder called with ``input_ids`` and ``attention_mask``;
            switched to eval mode.
        pooling_module: Callable mapping ``(model_output, attention_mask)``
            to sentence embeddings.
        tokenizer: Object exposing ``encode_plus``.
        device: Device the token tensors are moved to.
        max_length: Length the input is padded/truncated to. Defaults to the
            previously hard-coded 400.

    Returns:
        The embedding tensor, normalised along dimension 1.
    """
    model.eval()

    encoded = tokenizer.encode_plus(sentence,return_tensors="pt",padding="max_length",truncation=True,max_length=max_length)
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Inference only: neither the encoder nor the pooling needs a graph.
    with torch.no_grad():
        outputs = model(input_ids=input_ids,attention_mask=attention_mask)
        embeddings = pooling_module(outputs,attention_mask)

    return nn.functional.normalize(embeddings,p=2,dim=1)

def batch_encode(batch,model,pooling_module,tokenizer,device,max_length=None):
    """Embed a batch of sentences and L2-normalise the result.

    Args:
        batch: Sentences passed to ``tokenizer.batch_encode_plus``.
        model: Encoder called with ``input_ids`` and ``attention_mask``;
            switched to eval mode.
        pooling_module: Callable mapping ``(model_output, attention_mask)``
            to sentence embeddings.
        tokenizer: Object exposing ``batch_encode_plus``.
        device: Device the token tensors are moved to.
        max_length: Optional length forwarded to the tokenizer. The default
            ``None`` leaves the choice to the tokenizer, as before.

    Returns:
        The embedding tensor, normalised along dimension 1.
    """
    model.eval()

    encoded = tokenizer.batch_encode_plus(batch,return_tensors="pt",truncation=True,padding="max_length",max_length=max_length)
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Inference only: neither the encoder nor the pooling needs a graph.
    with torch.no_grad():
        outputs = model(input_ids=input_ids,attention_mask=attention_mask)
        embeddings = pooling_module(outputs,attention_mask)

    return nn.functional.normalize(embeddings,p=2,dim=1)


#### Run as script
if __name__ == "__main__":

    # Tab-separated files holding the training and development pairs.
    data_path="/mnt/data01/same_story/all_sides_data/wire_clusters/w_headlines/train_set.csv"
    data_path_eval="/mnt/data01/same_story/all_sides_data/wire_clusters/w_headlines/dev_set.csv"

    pair_columns = ["sentence_1", "sentence_2", "labels"]

    def _load_pairs(csv_path):
        # Keep only the three columns training needs, drop incomplete rows,
        # and return plain column lists.
        frame = pd.read_csv(csv_path, sep="\t", encoding='utf-8')
        return frame[pair_columns].dropna().to_dict(orient="list")

    df_train = _load_pairs(data_path)
    df_eval = _load_pairs(data_path_eval)

    # No test split is passed, so the returned value is None.
    test_acc = train_biencoder_custom(
        train_data=df_train,
        eval_data=df_eval,
        wandb_project="biencoder_pt",
        wandb_run_name="biencoder_pt_1",
        model_name="sentence-transformers/all-mpnet-base-v2",
        wandb_log=True,
        warmup_perc=0.5,
        batch_size=128,
        num_epochs=20,
        margin=0.5,
    )

