import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from tqdm import tqdm
import losses
import pandas as pd
from sklearn.model_selection import train_test_split
import wandb
import torch
from transformers import PreTrainedTokenizer, BatchEncoding
from typing import List, Tuple

class SentencePooling(nn.Module):
    """Pool per-token hidden states into one vector per sequence.

    Supported strategies:

    * ``"mean"``   - attention-mask-weighted average over all positions.
    * ``"max"``    - element-wise maximum over the non-padded positions,
      skipping the first position of every sequence.
    * ``"concat"`` - the mean vector followed by the max vector along the
      feature dimension.
    """

    _STRATEGIES = ("mean", "max", "concat")

    def __init__(self, pooling_strategy):
        super().__init__()
        self.pooling_strategy = pooling_strategy

    def forward(self, outputs, attention_mask):
        """Return the pooled embeddings for a batch.

        Args:
            outputs: Object exposing ``last_hidden_state``, indexed here as
                (batch, sequence, hidden).
            attention_mask: (batch, sequence) mask; entries equal to 0 are
                treated as padding.

        Returns:
            A (batch, hidden) tensor for "mean" and "max", or a
            (batch, 2 * hidden) tensor for "concat".

        Raises:
            ValueError: If ``pooling_strategy`` is not a supported value.
        """
        strategy = self.pooling_strategy
        if strategy not in self._STRATEGIES:
            raise ValueError(f"Invalid pooling strategy: {self.pooling_strategy}")

        token_embeddings = outputs.last_hidden_state

        # Only compute what the selected strategy actually needs.
        if strategy == "mean":
            return self._masked_mean(token_embeddings, attention_mask)
        if strategy == "max":
            return self._masked_max(token_embeddings, attention_mask)

        mean_embeddings = self._masked_mean(token_embeddings, attention_mask)
        max_embeddings = self._masked_max(token_embeddings, attention_mask)
        return torch.cat((mean_embeddings, max_embeddings), dim=1)

    @staticmethod
    def _masked_mean(token_embeddings, attention_mask):
        """Average the embeddings over positions, weighted by the mask."""
        weights = attention_mask.unsqueeze(-1)
        summed = torch.sum(token_embeddings * weights, dim=1)
        # A fully masked row would otherwise divide by zero and yield NaN;
        # with the clamp it pools to a zero vector instead.
        token_count = torch.sum(attention_mask, dim=1, keepdim=True).clamp(min=1)
        return summed / token_count

    @staticmethod
    def _masked_max(token_embeddings, attention_mask):
        """Element-wise max over non-padded positions, skipping position 0."""
        candidates = token_embeddings[:, 1:]
        is_padding = attention_mask[:, 1:].unsqueeze(-1) == 0
        # Push padded positions to the lowest representable value so they can
        # never be selected as the maximum.
        lowest = torch.finfo(candidates.dtype).min
        return candidates.masked_fill(is_padding, lowest).max(dim=1).values


def load_data(file_path, max_rows=512):
    """Load sentence pairs and their labels from a comma-separated file.

    The file must contain the columns ``sentence1``, ``sentence2`` and
    ``label``.

    Args:
        file_path: Path of the UTF-8 encoded CSV file.
        max_rows: Keep only the first ``max_rows`` rows. The default of 512
            matches the cap that used to be hard-coded here; pass ``None``
            to load the whole file.

    Returns:
        A tuple ``(sentences1, sentences2, labels)`` of plain Python lists.
    """
    df = pd.read_csv(file_path, sep=",", encoding="utf-8")
    if max_rows is not None:
        # Truncate after parsing (not via ``nrows``) so the column dtypes are
        # inferred from the whole file, exactly as before.
        df = df.head(max_rows)

    return (
        df["sentence1"].tolist(),
        df["sentence2"].tolist(),
        df["label"].tolist(),
    )


def calculate_accuracy(predictions, labels):
    """Return the percentage of predictions that equal their label.

    Args:
        predictions: Tensor of predicted values.
        labels: Tensor of reference values, compared element-wise.

    Returns:
        The share of matching entries as a float on a 0-100 scale.
    """
    matches = torch.eq(predictions, labels)
    hit_count = matches.sum().item()
    return hit_count / len(predictions) * 100


def find_nearest_neighbor(embeddings, labels):
    """Return, for every row, the label of its most similar *other* row.

    Similarity is the dot product between rows of ``embeddings``; a larger
    value means a closer neighbour.

    Args:
        embeddings: 2-D tensor with one embedding per row.
        labels: Tensor indexable by row position, one entry per embedding.

    Returns:
        Tensor holding the label of each row's nearest neighbour.
    """
    # Pairwise dot-product similarities between all rows.
    similarity = torch.mm(embeddings, torch.transpose(embeddings, 0, 1))

    # Exclude each row from its own candidates. Zeroing the diagonal is not
    # enough: when every other similarity is negative, 0 would still be the
    # maximum and the row would be matched with itself.
    if similarity.is_floating_point():
        lowest = float("-inf")
    else:
        lowest = torch.iinfo(similarity.dtype).min
    self_mask = torch.eye(
        similarity.size(0), dtype=torch.bool, device=similarity.device
    )
    masked_similarity = similarity.masked_fill(self_mask, lowest)

    nearest_indices = torch.argmax(masked_similarity, dim=1)
    return labels[nearest_indices]




def _encode_sentences(sentences, tokenizer, max_length, padding, truncation):
    """Tokenize one list of sentences into (input_ids, attention_mask)."""
    # Calling the tokenizer directly replaces the deprecated
    # ``batch_encode_plus`` entry point.
    encoded = tokenizer(
        sentences,
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        return_tensors="pt",
    )
    return encoded["input_ids"], encoded["attention_mask"]


def tokenize_and_encode(
    sentences1: List[str],
    sentences2: List[str],
    labels: List[int],
    tokenizer: PreTrainedTokenizer,
    max_length: int = 512,
    padding: str = "max_length",
    truncation: bool = True
) -> Tuple[
    Tuple[torch.Tensor, torch.Tensor],
    Tuple[torch.Tensor, torch.Tensor],
    torch.Tensor,
]:
    """Tokenize both sides of a sentence-pair dataset and tensorize the labels.

    Args:
        sentences1: First sentence of every pair.
        sentences2: Second sentence of every pair.
        labels: One label per pair.
        tokenizer: Tokenizer used to encode both sentence lists.
        max_length: Maximum sequence length passed to the tokenizer.
        padding: Padding strategy passed to the tokenizer.
        truncation: Truncation setting passed to the tokenizer.

    Returns:
        ``((input_ids1, attention_masks1), (input_ids2, attention_masks2),
        labels)`` where the first two items hold the encoded tensors of
        ``sentences1`` and ``sentences2`` and ``labels`` is a tensor built
        from the input labels.
    """
    features1 = _encode_sentences(
        sentences1, tokenizer, max_length, padding, truncation
    )
    features2 = _encode_sentences(
        sentences2, tokenizer, max_length, padding, truncation
    )

    return features1, features2, torch.tensor(labels)


#### Run as script
if __name__ == "__main__":

    ###init wandb log
    # wandb.init("econabhishek","experiment_sbert_hf")

    # Experiment configuration.
    model_name = "distilroberta-base"
    pooling_strategy = "mean"  # Options: "mean", "max", "concat"
    margin = 0.5
    batch_size = 64
    num_epochs = 10
    learning_rate = 2e-5

    file_path = "/mnt/data01/github_repos/end-to-end-pipeline/train_set.csv"

    sentences1, sentences2, labels = load_data(file_path)

    # Load the tokenizer and the encoder once, both from `model_name`.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    # Train-Test-Eval Split: hold out 20% as the test set, then 20% of the
    # remainder as the eval set. The fixed random_state keeps the split
    # reproducible, so it only needs to be computed once.
    train_sentences1, test_sentences1, train_sentences2, test_sentences2, train_labels, test_labels = train_test_split(
        sentences1, sentences2, labels, test_size=0.2, random_state=42
    )
    train_sentences1, eval_sentences1, train_sentences2, eval_sentences2, train_labels, eval_labels = train_test_split(
        train_sentences1, train_sentences2, train_labels, test_size=0.2, random_state=42
    )

    # Tokenization and Encoding: each `*_features` value is an
    # (input_ids, attention_mask) pair.
    train_features1, train_features2, train_labels = tokenize_and_encode(
        train_sentences1, train_sentences2, train_labels, tokenizer
    )
    eval_features1, eval_features2, eval_labels = tokenize_and_encode(
        eval_sentences1, eval_sentences2, eval_labels, tokenizer
    )
    test_features1, test_features2, test_labels = tokenize_and_encode(
        test_sentences1, test_sentences2, test_labels, tokenizer
    )

    # TODO: wrap the encoded train features and labels in a TensorDataset and
    # a shuffled DataLoader (batch_size=batch_size) for the training loop.

