import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
import wandb
import math
import faiss
from typing import List, Tuple
# import bitsandbytes as bnb
##Import adamw
from torch.optim import AdamW
from multiprocessing import Pool
from functools import partial
import numpy as np
import os
import pickle
from  datasets  import Dataset
from torch.utils.data import Dataset, DataLoader, Sampler
import random
from itertools import combinations, product
import random
import datasets


##Set path to nlp_utils
import sys
sys.path.append("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/github_repos/end-to-end-pipeline/")
from nlp_utils.pytorch_bienc.losses import OnlineContrastiveLoss




class HNDataset(Dataset):
    """Hard-negative pair dataset with one item per anchor group.

    Each item corresponds to one unique ``group_id_x`` of ``cluster_df``.
    For that anchor group, every row of ``cluster_df`` supplies a neighbour
    group (``group_id_y``); ``m`` headline pairs are sampled per neighbour.
    A pair is labelled 1 when the neighbour group equals the anchor group
    and 0 otherwise.

    Args:
        cluster_df: DataFrame with columns ``group_id_x`` and ``group_id_y``.
        headlines_df: DataFrame with columns ``group_id`` and ``headline``.
        m: Number of pairs sampled per (anchor, neighbour) combination.
    """

    def __init__(self, cluster_df, headlines_df, m=2):
        self.cluster_df = cluster_df
        self.headlines_df = headlines_df
        self.m = m

        # Unique anchor group ids; each one is a dataset item.
        self.classes = np.unique(self.cluster_df['group_id_x'].values)

        # ``GroupBy.indices`` maps each group to *positional* row numbers.
        # The previous ``.groups`` mapping held index *labels*, which were
        # then passed to ``.iloc``; that only worked when both frames had a
        # pristine 0..n-1 index.
        self.group_id_indices_clusters = dict(self.cluster_df.groupby('group_id_x').indices)
        self.group_id_indices_examples = dict(self.headlines_df.groupby('group_id').indices)

    def __len__(self):
        # Number of anchors.
        return len(self.classes)

    def __getitem__(self, index):
        """Return ``(list_1, list_2, labels)`` for the ``index``-th anchor.

        ``list_1`` holds headlines sampled from the anchor group, ``list_2``
        holds headlines sampled from each neighbour (``m`` consecutive
        entries per neighbour, in neighbour order) and ``labels`` holds the
        matching 0/1 targets. All three have ``m * len(neighbors)`` entries.

        Raises:
            KeyError: If the anchor or one of its neighbours has no rows in
                ``headlines_df``.
        """
        class_index = self.classes[index]

        # Neighbour group ids of this anchor, in cluster_df row order.
        cluster_rows = self.group_id_indices_clusters[class_index]
        neighbors = self.cluster_df['group_id_y'].values[cluster_rows]
        n_pairs = self.m * len(neighbors)

        # Left side: anchor headlines, sampled with replacement.
        sampled_indices_left = np.random.choice(
            self.group_id_indices_examples[class_index], size=n_pairs, replace=True
        )
        # Right side: m headlines for EACH neighbour, sampled with replacement.
        sampled_indices_right = []
        for neighbor in neighbors:
            sampled_indices_right.extend(
                np.random.choice(self.group_id_indices_examples[neighbor], size=self.m, replace=True)
            )

        headlines = self.headlines_df['headline']
        list_1 = headlines.iloc[sampled_indices_left].tolist()
        list_2 = headlines.iloc[sampled_indices_right].tolist()

        # Each neighbour is repeated m times (11 22 33 ...) to line up with
        # the order in which the right-hand headlines were sampled.
        labels = [
            1 if neighbor == class_index else 0
            for neighbor in neighbors
            for _ in range(self.m)
        ]

        return list_1, list_2, labels


def HNCollate(batch):
    """Flatten a list of ``(list_1, list_2, labels)`` samples into one batch.

    Returns a dict with the concatenated left sentences (``sentence_1``),
    the concatenated right sentences (``sentence_2``) and the concatenated
    labels converted to a tensor (``labels``).
    """
    merged = {"sentence_1": [], "sentence_2": [], "labels": []}
    for left_texts, right_texts, pair_labels in batch:
        merged["sentence_1"] += left_texts
        merged["sentence_2"] += right_texts
        merged["labels"] += pair_labels
    # Labels are the only field the model consumes as a tensor.
    merged["labels"] = torch.tensor(merged["labels"])
    return merged

class HNSampler(Sampler):
    """Sequential sampler: yields dataset indices 0..len(dataset)-1 in order."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __iter__(self):
        index_range = range(len(self))
        return iter(index_range)

def _build_pair_dataloader(headlines_df, batch_size, neg_mul):
    """Build a shuffled DataLoader of labelled headline pairs.

    Positives (label 1) are all 2-combinations of headlines inside a group.
    Negatives (label 0) are produced per group by repeatedly drawing a random
    *other* group and adding the full cross product of headlines, until at
    least ``neg_mul`` times the number of positives of that group have been
    added (the last cross product may overshoot that target).

    Args:
        headlines_df: DataFrame with ``group_id`` and ``headline`` columns;
            rows containing NA are dropped.
        batch_size: Batch size of the returned DataLoader.
        neg_mul: Negative-to-positive ratio targeted per group.
    """
    headlines_df = headlines_df.dropna()
    dict_data = {"sentence_1": [], "sentence_2": [], "labels": []}

    # Group once up front instead of scanning the whole frame with a boolean
    # mask for every anchor group and every sampled negative group.
    headlines_by_group = {
        group_id: group.values
        for group_id, group in headlines_df.groupby('group_id')['headline']
    }
    group_ids = np.unique(headlines_df['group_id'].values)

    # With a single group there is nothing to draw negatives from; without
    # this guard the rejection loop below would never terminate.
    can_sample_negatives = len(group_ids) > 1

    for group_id in tqdm(group_ids):
        headlines_group = headlines_by_group[group_id]

        # Positives: every unordered pair inside the group.
        combos_within = list(combinations(headlines_group, 2))
        dict_data["sentence_1"].extend(first for first, _ in combos_within)
        dict_data["sentence_2"].extend(second for _, second in combos_within)
        dict_data["labels"].extend([1] * len(combos_within))

        if not can_sample_negatives:
            continue

        # Negatives: cross products with randomly drawn other groups.
        target = neg_mul * len(combos_within)
        count = 0
        while count < target:
            group_id2 = np.random.choice(group_ids)
            if group_id2 == group_id:
                continue
            combos_between = list(product(headlines_group, headlines_by_group[group_id2]))
            dict_data["sentence_1"].extend(first for first, _ in combos_between)
            dict_data["sentence_2"].extend(second for _, second in combos_between)
            dict_data["labels"].extend([0] * len(combos_between))
            count += len(combos_between)

    dataset = datasets.Dataset.from_dict(dict_data)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def year_to_dataloader_nohn(output_folder,year,batch_size,neg_mul):
    """Return a pair DataLoader (no hard negatives) for a single year.

    Reads that year's headlines CSV and delegates to
    ``_build_pair_dataloader``. ``output_folder`` is accepted for signature
    compatibility but is not used.
    """
    years_headlines = pd.read_csv(f"/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/headlines/headlines_df/dffull_{year}.csv")
    return _build_pair_dataloader(years_headlines, batch_size, neg_mul)


def years_to_dataloader_nohn(output_folder,years,batch_size,neg_mul):
    """Return a pair DataLoader (no hard negatives) over several years.

    Reads and concatenates the headlines CSV of every year in ``years`` and
    delegates to ``_build_pair_dataloader``. ``output_folder`` is accepted
    for signature compatibility but is not used.
    """
    print("Getting dataloader for years",years)
    years_headlines = pd.concat([pd.read_csv(f"/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/headlines/headlines_df/dffull_{year}.csv") for year in years])
    return _build_pair_dataloader(years_headlines, batch_size, neg_mul)

        
    



def year_to_dataloader(output_folder,year,m,batch_size,k):
    """Return a hard-negative DataLoader for a single year.

    Reads ``{output_folder}/df_hn_{year}.csv`` (cluster/neighbour table) and
    that year's headlines CSV, drops NA rows, resets both indices and wraps
    them in an ``HNDataset`` served sequentially through ``HNSampler`` and
    collated with ``HNCollate``.

    The loader batches ``batch_size // m // k`` anchors at a time.
    NOTE(review): presumably ``k`` is the number of neighbours per anchor so
    that a collated batch holds about ``batch_size`` pairs -- confirm.
    """
    hard_negatives = pd.read_csv(f"{output_folder}/df_hn_{year}.csv")
    headlines = pd.read_csv(f"/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/headlines/headlines_df/dffull_{year}.csv")

    pair_dataset = HNDataset(
        hard_negatives.dropna().reset_index(drop=True),
        headlines.dropna().reset_index(drop=True),
        m=m,
    )
    anchors_per_batch = batch_size // m // k
    return DataLoader(
        pair_dataset,
        batch_size=anchors_per_batch,
        sampler=HNSampler(pair_dataset),
        collate_fn=HNCollate,
    )

def years_to_dataloader(output_folder,list_of_years,m,batch_size,k):
    """Return a hard-negative DataLoader covering several years.

    For every year the cluster/neighbour table
    ``{output_folder}/df_hn_{year}.csv`` and the headlines CSV are read; the
    per-year frames are concatenated, NA rows dropped and indices reset
    before building an ``HNDataset``. Row counts are printed per year and in
    total. The loader batches ``batch_size // m // k`` anchors at a time.
    """
    hn_frames, headline_frames = [], []
    for year in list_of_years:
        hn_frame = pd.read_csv(f"{output_folder}/df_hn_{year}.csv")
        hn_frames.append(hn_frame)
        headline_frames.append(pd.read_csv(f"/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/headlines/headlines_df/dffull_{year}.csv"))
        print(f"number of rows in year {year} is {len(hn_frame)}")

    hard_negatives = pd.concat(hn_frames).dropna()
    headlines = pd.concat(headline_frames).dropna()
    print(f"number of rows in total is {len(hard_negatives)}")

    # Concatenation leaves duplicate index labels; reset before indexing.
    pair_dataset = HNDataset(
        hard_negatives.reset_index(drop=True),
        headlines.reset_index(drop=True),
        m=m,
    )
    return DataLoader(
        pair_dataset,
        batch_size=batch_size // m // k,
        sampler=HNSampler(pair_dataset),
        collate_fn=HNCollate,
    )

class SentencePooling(nn.Module):
    """Pool token embeddings into one sentence embedding.

    Only the ``"mean"`` strategy is implemented: the attention-mask-weighted
    average of the token embeddings found in ``outputs[0]``.
    """

    def __init__(self, pooling_strategy):
        super().__init__()
        self.pooling_strategy = pooling_strategy

    def forward(self, outputs, attention_mask):
        """Return pooled embeddings for ``outputs`` under ``attention_mask``.

        Raises:
            ValueError: If the configured pooling strategy is not ``"mean"``.
        """
        hidden_states = outputs[0]

        # Broadcast the mask over the embedding dimension so that padded
        # positions contribute nothing to the sum.
        mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        masked_sum = (hidden_states * mask).sum(1)
        # Clamp the token count to avoid dividing by zero on all-padding rows.
        token_counts = mask.sum(1).clamp(min=1e-9)

        if self.pooling_strategy != "mean":
            raise ValueError(f"Invalid pooling strategy: {self.pooling_strategy}")
        return masked_sum / token_counts




def tokenize_and_encode(
    sentences1: List[str],
    sentences2: List[str],
    labels: List[int],
    tokenizer: AutoTokenizer,
    max_length: int = 128,
    padding: str = "max_length",
    truncation: bool = True

) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Tokenize both sides of a batch of sentence pairs.

    Each side is passed to ``tokenizer.batch_encode_plus`` with the same
    padding / truncation / max_length settings and ``return_tensors="pt"``.

    Returns:
        ``((input_ids1, attention_mask1), (input_ids2, attention_mask2),
        labels)`` where ``labels`` is returned unchanged.
    """
    shared_kwargs = dict(
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        return_tensors="pt",
    )
    left = tokenizer.batch_encode_plus(sentences1, **shared_kwargs)
    right = tokenizer.batch_encode_plus(sentences2, **shared_kwargs)

    left_features = (left["input_ids"], left["attention_mask"])
    right_features = (right["input_ids"], right["attention_mask"])
    return left_features, right_features, labels




def _best_accuracy_and_threshold(scores, labels):
    """Return ``(max_acc, threshold)`` for distance ``scores`` (lower = more similar).

    Pairs are sorted by ascending distance; for every cut between two
    consecutive pairs, everything before the cut is predicted positive and
    everything after negative. The threshold is the midpoint of the two
    distances around the best cut (-1 when no cut was evaluated).
    """
    rows = sorted(zip(scores, labels), key=lambda row: row[0])

    max_acc = 0
    best_threshold = -1
    positive_so_far = 0
    remaining_negatives = sum(1 for label in labels if label == 0)

    for (score, label), (next_score, _) in zip(rows, rows[1:]):
        if label == 1:
            positive_so_far += 1
        else:
            remaining_negatives -= 1

        acc = (positive_so_far + remaining_negatives) / len(labels)
        if acc > max_acc:
            max_acc = acc
            best_threshold = (score + next_score) / 2

    return max_acc, best_threshold


def _best_f1_and_threshold(scores, labels):
    """Return ``(f1, precision, recall, threshold)`` maximising F1.

    Same sweep as ``_best_accuracy_and_threshold``: pairs are sorted by
    ascending distance and the first ``n`` are treated as extracted
    positives. All four values are 0 when no positive pair is ever seen.
    """
    rows = sorted(zip(scores, labels), key=lambda row: row[0])

    best_f1 = best_precision = best_recall = 0
    threshold = 0
    ncorrect = 0
    total_num_duplicates = sum(labels)

    for nextract, ((score, label), (next_score, _)) in enumerate(zip(rows, rows[1:]), start=1):
        if label == 1:
            ncorrect += 1

        if ncorrect > 0:
            precision = ncorrect / nextract
            recall = ncorrect / total_num_duplicates
            f1 = 2 * precision * recall / (precision + recall)
            if f1 > best_f1:
                best_f1 = f1
                best_precision = precision
                best_recall = recall
                threshold = (score + next_score) / 2

    return best_f1, best_precision, best_recall, threshold


def BinaryClassificationEvaluate(model,tokenizer,dataloader,pooling_module,device,eval_metric="f1",max_seq_len=128,padding="max_length",truncation=True,embedding_dim=768,batch_size=128,sample_eval_loader_steps=None,hn=True):
    """Evaluate a bi-encoder on labelled sentence pairs.

    Embeds both sides of every pair with ``embed_data_loader``, scores each
    pair by cosine *distance* (``1 - cosine_similarity``) and searches for
    the distance threshold that maximises the requested metric.

    Args:
        eval_metric: ``"acc"``, or one of ``"f1"``/``"precision"``/``"recall"``
            (the latter three all run the F1 search).
        Remaining arguments are forwarded to ``embed_data_loader``.

    Returns:
        ``{"acc", "threshold"}`` for ``"acc"``; otherwise
        ``{"f1", "precision", "recall", "threshold"}``. Values are plain
        Python numbers.

    Raises:
        ValueError: If ``eval_metric`` is not one of the supported names.
    """
    # Fail fast: reject a bad metric before paying for the embedding pass.
    if eval_metric not in ("acc", "f1", "precision", "recall"):
        raise ValueError(f"Invalid eval_metric: {eval_metric}")

    embeddings1,embeddings2,labels=embed_data_loader(model,tokenizer,dataloader,pooling_module,device,max_seq_len=max_seq_len,padding=padding,truncation=truncation,embedding_dim=embedding_dim,batch_size=batch_size,sample_eval_loader_steps=sample_eval_loader_steps,hn=hn)

    assert len(embeddings1) == len(embeddings2) == len(labels)

    # Convert to plain Python numbers once: sorting 0-d tensors compares
    # them one by one through torch, and thresholds would otherwise be
    # returned (and logged) as tensors.
    scores = (1 - nn.functional.cosine_similarity(embeddings1, embeddings2, dim=1)).tolist()
    labels = [int(label) for label in labels]
    del embeddings1, embeddings2

    assert len(scores) == len(labels)

    if eval_metric == "acc":
        max_acc, best_threshold = _best_accuracy_and_threshold(scores, labels)
        print("Max Accuracy: {:.2f}%".format(max_acc * 100))
        print("Best Threshold: {:.4f}".format(best_threshold))
        return {"acc": max_acc, "threshold": best_threshold}

    best_f1, best_precision, best_recall, threshold = _best_f1_and_threshold(scores, labels)
    return {"f1": best_f1, "precision": best_precision, "recall": best_recall, "threshold": threshold}





def normalize_tensor_inplace(tensor,p=2,dim=1,keepdim=True,eps=1e-12):
    """Scale ``tensor`` in place so slices along ``dim`` have unit ``p``-norm.

    Args:
        tensor: Floating-point tensor, modified in place.
        p: Order of the norm.
        dim: Dimension along which the norm is computed.
        keepdim: Forwarded to ``torch.norm``; the default (True) keeps the
            reduced dimension so the norms broadcast against ``tensor``.
        eps: Lower bound applied to the norms before dividing, so that
            all-zero slices stay zero instead of turning into NaN.

    Returns:
        None; the result is written into ``tensor``.
    """
    norms = torch.norm(tensor, p=p, dim=dim, keepdim=keepdim)
    tensor.div_(norms.clamp_min(eps))



def embed_data_loader(model,tokenizer,dataloader,pooling_module,device,max_seq_len=128,padding="max_length",truncation=True,embedding_dim=768,batch_size=128,m=2,sample_eval_loader_steps=None,hn=True):
    """Embed both sides of every sentence pair served by ``dataloader``.

    Each batch must be a dict with ``sentence_1``, ``sentence_2`` and
    ``labels``. Batches are tokenized with ``tokenize_and_encode``, run
    through ``model`` and ``pooling_module`` under ``torch.no_grad()``, and
    the results are gathered on the CPU.

    Args:
        embedding_dim: Only used for the shape of the (empty) result when
            the dataloader yields no batch.
        sample_eval_loader_steps: If not None, stop after this many batches.
        batch_size, m, hn: Unused; kept for backward compatibility. They
            used to size a preallocated buffer, but that size was only an
            estimate (a hard-negative item yields ``m * len(neighbors)``
            pairs, not ``m``), leading to shape errors or trailing zero rows.

    Returns:
        ``(embeddings_1, embeddings_2, labels)``: two L2-normalized float32
        CPU tensors with one row per pair and a NumPy integer label array.

    Note:
        The model is switched to eval mode and is left in eval mode.
    """
    embedding_chunks_1, embedding_chunks_2, label_chunks = [], [], []

    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader, desc="Evaluating"), start=1):
            features1, features2, labels = tokenize_and_encode(batch["sentence_1"],batch["sentence_2"],batch["labels"],tokenizer,max_length=max_seq_len,padding=padding,truncation=truncation)
            input_ids1, attention_masks1 = features1
            input_ids2, attention_masks2 = features2

            # Only the model inputs need to live on the device.
            input_ids1 = input_ids1.to(device)
            attention_masks1 = attention_masks1.to(device)
            input_ids2 = input_ids2.to(device)
            attention_masks2 = attention_masks2.to(device)

            outputs1 = model(input_ids=input_ids1, attention_mask=attention_masks1)
            outputs2 = model(input_ids=input_ids2, attention_mask=attention_masks2)
            embeddings1 = pooling_module(outputs1, attention_masks1)
            embeddings2 = pooling_module(outputs2, attention_masks2)

            # Move results off the device right away so device memory stays
            # bounded by a single batch.
            embedding_chunks_1.append(embeddings1.to(device="cpu", dtype=torch.float32))
            embedding_chunks_2.append(embeddings2.to(device="cpu", dtype=torch.float32))
            label_chunks.append(torch.as_tensor(labels).to(device="cpu", dtype=torch.long))

            if sample_eval_loader_steps is not None and step == sample_eval_loader_steps:
                break

    if embedding_chunks_1:
        accum_embedding_1 = torch.cat(embedding_chunks_1)
        accum_embedding_2 = torch.cat(embedding_chunks_2)
        labels_tensor = torch.cat(label_chunks)
    else:
        accum_embedding_1 = torch.zeros(0, embedding_dim)
        accum_embedding_2 = torch.zeros(0, embedding_dim)
        labels_tensor = torch.zeros(0, dtype=torch.long)

    # Normalize in place so downstream cosine scores are plain dot products.
    normalize_tensor_inplace(accum_embedding_1,p=2,dim=1)
    normalize_tensor_inplace(accum_embedding_2,p=2,dim=1)
    labels = labels_tensor.numpy()

    return accum_embedding_1,accum_embedding_2,labels

def save_model_and_tokenizer(model, tokenizer, output_dir):
    """Write the model and tokenizer into ``output_dir`` via ``save_pretrained``.

    The directory is created when it does not exist. A model wrapped in
    ``nn.DataParallel`` is unwrapped first so the underlying module is saved.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # DataParallel hides the real model behind ``.module``.
    target_model = model.module if isinstance(model, nn.DataParallel) else model
    save_dir = os.path.join(output_dir)

    target_model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
    print("Model and tokenizer saved in {}".format(output_dir))

def save_model_config(model, output_dir):
    """Write ``model.config`` into ``output_dir`` via ``save_pretrained``.

    The directory is created when it does not exist. A model wrapped in
    ``nn.DataParallel`` is unwrapped first.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # DataParallel hides the real model (and its config) behind ``.module``.
    target_model = model.module if isinstance(model, nn.DataParallel) else model
    target_model.config.save_pretrained(os.path.join(output_dir))
    print("Model config saved in {}".format(output_dir))


def add_special_tokens(model,tokenizer,special_tokens):
    """Register the values of ``special_tokens`` with the tokenizer.

    The tokens are added as special tokens and the model's token embedding
    matrix is resized to the new tokenizer length. Returns
    ``(model, tokenizer)``.
    """
    new_tokens = list(special_tokens.values())
    tokenizer.add_tokens(new_tokens, special_tokens=True)

    # The embedding table must grow to cover the newly added token ids.
    vocab_size = len(tokenizer)
    model.resize_token_embeddings(vocab_size)
    return model, tokenizer


def train_biencoder_custom(
    cluster_df_dir: str,
    text_df_dir: str,
    train_years: List[int] = None,
    val_years: List[int] = None,
    test_years: List[int] = None,
    train_hn:bool = True,
    val_hn:bool = False,
    test_hn:bool = False,
    model_name: str = "sentence-transformers/all-mpnet-base-v2",
    pooling_strategy: str = "mean",
    loss_function=OnlineContrastiveLoss,
    margin: float = 0.5,
    batch_size: int = 32,
    num_epochs: int = 10,
    learning_rate: float = 2e-5,
    warmup_steps: int = 1000,
    warmup_perc: float = 0.1,
    model_save_path: str = "output",
    wandb_log: bool = False,
    wandb_project: str = None,
    wandb_run_name: str = None,
    inter_eval_steps: int = 10,
    max_seq_len: int= 128,
    eval_metric="f1",
    special_tokens=None,
    sample_eval_loader_steps=None
):
    """Fine-tune a bi-encoder on headline pairs, one training year at a time.

    Workflow: build validation (and optional test) dataloaders, load the
    tokenizer and model, run a zero-shot evaluation, then for every epoch
    iterate over the training years in random order, building that year's
    dataloader on the fly. The model is evaluated every ``inter_eval_steps``
    training steps and after each epoch; whenever the metric beats the best
    value so far, model, tokenizer and config are saved to
    ``model_save_path``. Checkpoints are also written after every year and
    every epoch.

    Args:
        cluster_df_dir, text_df_dir: Unused; data locations are hard-coded.
        train_years, val_years, test_years: Years to train / validate / test
            on. ``test_years=None`` skips the test evaluation.
        train_hn, val_hn, test_hn: Use hard-negative dataloaders for the
            respective split instead of the plain pair dataloaders.
        loss_function: Loss class, instantiated as ``loss_function(margin=margin)``.
        warmup_perc: Unused.
        wandb_log: Initialise a Weights & Biases run. Step-level metrics are
            logged whenever a W&B run is active; epoch and test metrics only
            when ``wandb_log`` is True.
        inter_eval_steps: Evaluate every this many steps (<= 0 disables).
        eval_metric: Key of the evaluation output used for model selection.
        sample_eval_loader_steps: Cap on evaluation batches (None = all).

    Returns:
        The test metric when ``test_years`` is given, otherwise None.
    """
    hn_data_dir = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/headlines/hn_data"
    nohn_data_dir = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/headlines/data"

    def build_eval_loader(years, use_hn):
        # Evaluation loaders span several years at once.
        if use_hn:
            return years_to_dataloader(hn_data_dir, years, 2, batch_size, 4)
        return years_to_dataloader_nohn(nohn_data_dir, years, batch_size, 1)

    def evaluate(dataloader, use_hn):
        # Single place for the evaluation settings shared by all call sites.
        return BinaryClassificationEvaluate(model, tokenizer, dataloader, pooling_module, device,eval_metric=eval_metric,max_seq_len=max_seq_len,padding="max_length",truncation=True,embedding_dim=768,batch_size=batch_size,sample_eval_loader_steps=sample_eval_loader_steps,hn=use_hn)

    def log_if_run_active(metrics):
        # wandb.log raises when no run has been initialised, so only log
        # when a run exists (started here or by the caller).
        if wandb.run is not None:
            wandb.log(metrics)

    ##Make dataloaders
    print("Making dataloaders")
    eval_dataloader = build_eval_loader(val_years, val_hn)
    test_dataloader = build_eval_loader(test_years, test_hn) if test_years is not None else None

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name,use_fast=True)

    # Model
    model = AutoModel.from_pretrained(model_name)

    # If special tokens are provided, add them to the tokenizer and resize the token embeddings
    if special_tokens is not None:
        print("Adding Special Tokens")
        model,tokenizer=add_special_tokens(model,tokenizer,special_tokens)

    # Use DataParallel when more than one GPU is visible.
    num_gpus=torch.cuda.device_count()
    if num_gpus>1:
        print("Using {} GPUs".format(num_gpus))
        model = nn.DataParallel(model)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)
    # NOTE(review): the schedule length is hard-coded rather than derived
    # from the amount of training data; with warmup_steps > total_steps the
    # linear schedule does not decay as intended -- confirm this is wanted.
    total_steps = 10000
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

    pooling_module = SentencePooling(pooling_strategy)
    criterion = loss_function(margin=margin)

    # Initialize WandB logging
    if wandb_log:
        wandb.init(project=wandb_project, name=wandb_run_name)

    # Check zero-shot evaluation accuracy
    eval_output = evaluate(eval_dataloader, val_hn)
    eval_accuracy,zs_thresh = eval_output[eval_metric],eval_output["threshold"]

    print("Zero Shot Evaluation Accuracy = {:.2f}%".format(eval_accuracy))
    log_if_run_active({"zs_thresh": zs_thresh})
    log_if_run_active({"zs_accuracy": eval_accuracy})

    # NOTE(review): the best score starts at 0, not at the zero-shot score,
    # so the first evaluation with a positive metric always saves a model.
    # This mirrors the previous behaviour -- confirm it is intentional.
    best_accuracy = 0

    # Shuffle a private copy so the caller's list is left untouched.
    year_order = list(train_years)

    # Training loop
    for epoch in range(num_epochs):
        random.shuffle(year_order)

        print("Epoch {}/{}".format(epoch + 1, num_epochs))
        model.train()
        total_loss = 0.0
        total_samples = 0

        epoch_step=0
        for year in year_order:
            print("Training on year {}".format(year))
            if train_hn:
                train_dataloader = year_to_dataloader(hn_data_dir,year,m=2,batch_size=batch_size,k=4)
            else:
                train_dataloader = year_to_dataloader_nohn(nohn_data_dir,year,batch_size,1)

            for batch in tqdm(train_dataloader, desc="Epoch {}".format(epoch + 1)):
                epoch_step+=1

                # Periodic evaluation inside the epoch.
                if inter_eval_steps>0 and epoch_step % inter_eval_steps == 0:
                    eval_output = evaluate(eval_dataloader, val_hn)
                    eval_accuracy,zs_thresh = eval_output[eval_metric],eval_output["threshold"]
                    print(f"Using eval metric: {eval_metric}")
                    print("Interemdiate Eval = {:.2f}%".format(eval_accuracy))
                    log_if_run_active({"inter_eval/threshold": zs_thresh})
                    log_if_run_active({"inter_eval/accuracy": eval_accuracy})
                    if eval_accuracy > best_accuracy:
                        best_accuracy = eval_accuracy
                        log_if_run_active({"inter_eval/best_accuracy_eval": best_accuracy})
                        save_model_and_tokenizer(model, tokenizer, model_save_path)
                        save_model_config(model, model_save_path)
                    # Evaluation switches the model to eval mode; switch back
                    # so dropout stays active for the rest of the epoch.
                    model.train()

                features1, features2, labels = tokenize_and_encode(batch["sentence_1"],batch["sentence_2"],batch["labels"],tokenizer,max_length=max_seq_len,padding="max_length",truncation=True)
                input_ids1, attention_masks1 = features1
                input_ids2, attention_masks2 = features2

                optimizer.zero_grad()

                # Encode the two sides one after the other, moving each to
                # the device only when needed and dropping its inputs and raw
                # outputs before the next side is processed.
                input_ids1=input_ids1.to(device)
                attention_masks1=attention_masks1.to(device)
                size_input_ids1=input_ids1.size(0)
                outputs1 = model(input_ids=input_ids1, attention_mask=attention_masks1)
                embeddings1 = pooling_module(outputs1, attention_masks1)
                del outputs1, attention_masks1, input_ids1

                input_ids2=input_ids2.to(device)
                attention_masks2=attention_masks2.to(device)
                outputs2 = model(input_ids=input_ids2, attention_mask=attention_masks2)
                embeddings2 = pooling_module(outputs2,attention_masks2)
                del outputs2, attention_masks2, input_ids2

                labels=labels.to(device)
                loss = criterion(embeddings1, embeddings2, labels)
                log_if_run_active({"loss": loss.item()})
                loss.backward()
                optimizer.step()
                scheduler.step()

                total_loss += loss.item() * size_input_ids1
                total_samples += size_input_ids1

                # Clear intermediate variables
                del labels, embeddings1, embeddings2

                torch.cuda.empty_cache()

            # Save a checkpoint after each year.
            save_model_and_tokenizer(model, tokenizer, f"{model_save_path}/year_{year}_model_epoch_{epoch}.pt")

        # An epoch without any batch has no meaningful average loss.
        average_loss = total_loss / total_samples if total_samples else float("nan")
        print("Epoch {}: Loss = {:.4f}".format(epoch + 1, average_loss))

        # Evaluate on the validation set
        eval_output = evaluate(eval_dataloader, val_hn)
        eval_accuracy,zs_thresh = eval_output[eval_metric],eval_output["threshold"]
        print("Epoch {}: Evaluation Accuracy = {:.2f}%".format(epoch + 1, eval_accuracy))

        # Keep the model whenever it beats the best score so far.
        if eval_accuracy > best_accuracy:
            best_accuracy = eval_accuracy
            save_model_and_tokenizer(model, tokenizer, model_save_path)
            save_model_config(model, model_save_path)

        # Log metrics to WandB
        if wandb_log:
            wandb.log({"epoch": epoch + 1, "train_loss": average_loss, "eval_accuracy": eval_accuracy})
            wandb.log({"eval_thresh": zs_thresh})

        # Save the end-of-epoch model.
        print("Saving model... for epoch {}".format(epoch+1))
        save_model_and_tokenizer(model, tokenizer, os.path.join(model_save_path+f"_epoch_{epoch+1}"))

    if test_years is None:
        return None

    # Evaluate on the test set
    eval_output = evaluate(test_dataloader, test_hn)
    test_accuracy,zs_thresh = eval_output[eval_metric],eval_output["threshold"]
    print("Testing Accuracy = {:.2f}%".format(test_accuracy))

    # Log final metrics to WandB
    if wandb_log:
        wandb.log({"test_accuracy": test_accuracy})

    return test_accuracy


def sentence_encode(sentence,model,pooling_module,tokenizer,device):
    """Return the L2-normalized pooled embedding of a single sentence.

    The sentence is tokenized with ``encode_plus`` (padded / truncated to
    400 tokens), run through ``model`` in eval mode without gradients,
    pooled with ``pooling_module`` and normalized along dim 1.
    """
    model.eval()

    tokens = tokenizer.encode_plus(
        sentence,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=400,
    )
    ids_on_device = tokens["input_ids"].to(device)
    mask_on_device = tokens["attention_mask"].to(device)

    # Only the forward pass runs without gradient tracking.
    with torch.no_grad():
        model_out = model(input_ids=ids_on_device, attention_mask=mask_on_device)

    pooled = pooling_module(model_out, mask_on_device)
    return nn.functional.normalize(pooled, p=2, dim=1)

def batch_encode(batch,model,pooling_module,tokenizer,device):
    """Return L2-normalized pooled embeddings for a batch of sentences.

    The batch is tokenized with ``batch_encode_plus`` (truncation on,
    ``max_length`` padding), run through ``model`` in eval mode without
    gradients, pooled with ``pooling_module`` and normalized along dim 1.
    """
    model.eval()

    tokens = tokenizer.batch_encode_plus(
        batch,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
    )
    ids_on_device = tokens["input_ids"].to(device)
    mask_on_device = tokens["attention_mask"].to(device)

    # Only the forward pass runs without gradient tracking.
    with torch.no_grad():
        model_out = model(input_ids=ids_on_device, attention_mask=mask_on_device)

    pooled = pooling_module(model_out, mask_on_device)
    return nn.functional.normalize(pooled, p=2, dim=1)


#### Run as script
if __name__ == "__main__":

    # Candidate years: range(1929, 1990), i.e. 1929 through 1989 inclusive.
    all_years=list(range(1929,1990))
    print("All years: {}".format(all_years))

    # 80% train; the held-out 20% is split evenly into test and validation.
    train_years,held_out_years = train_test_split(all_years, test_size=0.2, shuffle=True, random_state=42)
    test_years,val_years = train_test_split(held_out_years, test_size=0.5, shuffle=True, random_state=42)

    # Persist each split as a text file with one year per line.
    split_files = (
        ("train_years.txt", train_years),
        ("test_years.txt", test_years),
        ("val_years.txt", val_years),
    )
    for split_file, split_years in split_files:
        with open(split_file,"w") as f:
            f.writelines(str(split_year)+"\n" for split_year in split_years)

    # Start training using the custom train function.
    test_acc=train_biencoder_custom(
        "",
        "",
        train_years=train_years,
        val_years=val_years,
        test_years=test_years,
        train_hn=False,
        test_hn=False,
        val_hn=False,
        wandb_project="headlines",
        wandb_run_name="all-years-splits-headlines_na_highwarmuplr",
        model_name="sentence-transformers/all-mpnet-base-v2",
        wandb_log=True,
        warmup_steps=30000,
        batch_size=704,
        model_save_path="headlines_trained_models/all-years-splits-headlines_na_highwarmuplr",
        inter_eval_steps=256,
        sample_eval_loader_steps=40,
        num_epochs=2,
        margin=0.2,
        learning_rate=2e-6,
        eval_metric="f1",
    )

    # test_acc=train_biencoder_custom("","",train_years=[1942],val_years=[1943],test_years=[1944],
    #                                 train_hn=False,test_hn=False,val_hn=False,
    #                                 wandb_project="headlines",
    #                                 wandb_run_name="all-years-splits-headlines_na_highwarmuplr",
    #                                 model_name="sentence-transformers/all-mpnet-base-v2",wandb_log=True, 
    #                                 warmup_steps=30000,batch_size=704,model_save_path="headlines_trained_models/all-years-splits-headlines_na_highwarmuplr",
    #                                         inter_eval_steps=256,
    #                                         sample_eval_loader_steps=40,
    #                                 num_epochs=2,margin=0.2,learning_rate=2e-6,eval_metric="f1")    
