import pickle
import torch
from helper_probing import tokenize, tokenize_utterances, forward_ab, f1_score, \
accuracy, precision, recall, cluster, generate_key_file, get_bert_embedding, get_cosine_similarity, \
calculate_fuzzy_score, calculate_iou
from generate_gold_map import generate_gold_map
from probing_prediction import predict_causal_counterpart
import random
from tqdm import tqdm
import os
from models import CrossEncoder
from delitoolkit.delidata import DeliData
import pickle
import json 
import pandas as pd
import numpy as np
from coval.coval.conll.reader import get_coref_infos
from coval.coval.eval.evaluator import evaluate_documents as evaluate
from coval.coval.eval.evaluator import muc, b_cubed, ceafe, lea
from sklearn.metrics.pairwise import cosine_similarity
import sys

from sklearn.metrics.pairwise import cosine_similarity
from transformers import BertModel, BertTokenizer
import pandas as pd
from fuzzywuzzy import fuzz

# Load the BERT tokenizer and encoder once, at import time.
# Neither `tokenizer` nor `cosine_model` is referenced elsewhere in the visible part of this
# file; presumably they are kept for importers/helpers — confirm before removing.
cosine_model_name = 'bert-base-uncased'
tokenizer, cosine_model = (
    BertTokenizer.from_pretrained(cosine_model_name),
    BertModel.from_pretrained(cosine_model_name),
)
import numpy as np

from collections import defaultdict




def generate_pairs_for_train_eval(gold_map, split):
    '''
    Build all within-group intervention pairs for one split, with binary labels.

    Interventions are the keys of ``gold_map`` whose ``'split'`` value equals ``split``.
    They are sorted by id and grouped by ``'group_id'``. Within each group, every
    intervention is paired once with each intervention that precedes it, giving
    ``(later_id, earlier_id)`` tuples without self-pairs. A pair is labelled 1 when
    both members share the same ``'gold_cluster'``, otherwise 0.

    Parameters
    ----------
    gold_map : dict
        intervention id -> dict with at least 'split', 'group_id' and 'gold_cluster'.
    split : str
        Split to select (compared with each entry's 'split' value).

    Returns
    -------
    intervention_pairs : list of tuple
    intervention_labels : list of int
        Aligned with ``intervention_pairs``.
    causal_probing_label : dict
        intervention id -> 'probing' when the id contains an underscore, else 'causal'.
    '''
    split_intervention_ids = sorted(m_id for m_id, m in gold_map.items() if m['split'] == split)

    group_to_interventions = defaultdict(list)
    for intervention_id in split_intervention_ids:
        group_to_interventions[gold_map[intervention_id]['group_id']].append(intervention_id)

    intervention_pairs = []
    intervention_labels = []
    for interventions in group_to_interventions.values():
        for i, first in enumerate(interventions):
            # Pairing only with earlier members emits each unordered pair once and skips self-pairs.
            for second in interventions[:i]:
                intervention_pairs.append((first, second))
                intervention_labels.append(
                    int(gold_map[first]['gold_cluster'] == gold_map[second]['gold_cluster'])
                )

    # Ids containing an underscore are treated as probing interventions.
    causal_probing_label = {
        intervention: 'probing' if len(intervention.split("_")) > 1 else 'causal'
        for intervention in split_intervention_ids
    }
    print(f" {split} final pairs", len(intervention_pairs), len(intervention_labels), len(causal_probing_label))
    return intervention_pairs, intervention_labels, causal_probing_label

def read(key, response):
    '''
    Parse a gold key file and a system response file with coval's ``get_coref_infos``.

    Both arguments are formatted with ``'%s' %`` before being passed on. The trailing
    positional flags are (False, False, True); presumably coval's NP_only, remove_nested
    and keep_singletons — confirm against the vendored coval reader.
    '''
    key_file = '%s' % key
    response_file = '%s' % response
    coval_flags = (False, False, True)
    return get_coref_infos(key_file, response_file, *coval_flags)

def get_gold_map(test_dict, probing_map, document_map):
    '''
    Build a gold-cluster map from a dict keyed by (probing_id, message_id) pairs.

    For every key of ``test_dict``, two entries are written:

    * ``probing_id`` gets its own id as ``'gold_cluster'``.
    * ``message_id`` gets, as ``'gold_cluster'``, the probing id with the smallest integer
      after its last underscore among all probing ids paired with that message. If the
      message has only one probing id, that id is used.

    Both entries store ``'m_id'`` (the message id) and ``'group_id'`` taken from
    ``document_map[message_id]``. ``probing_map`` and the values of ``test_dict`` are unused.

    Raises ValueError when a message has several probing ids and one of them does not end
    in an integer suffix. Raises KeyError when a message id is missing from ``document_map``.
    '''
    def _suffix_number(probing_id):
        # Integer after the last underscore, e.g. 'abc_3' -> 3.
        return int(probing_id.split('_')[-1])

    probes_per_message = defaultdict(list)
    for key in test_dict:
        probes_per_message[key[1]].append(key[0])

    # Only messages paired with more than one probing id need a representative;
    # min() keeps the first of equally small suffixes.
    representative = {
        message_id: min(probing_ids, key=_suffix_number)
        for message_id, probing_ids in probes_per_message.items()
        if len(probing_ids) > 1
    }

    # NOTE(review): probing ids other than the representative keep themselves as their
    # gold cluster, so they do not share a cluster with their message — confirm intent.
    gold_map = {}
    for key in test_dict:
        probing_id, message_id = key[0], key[1]
        group_id = document_map[message_id]['group_id']
        gold_map[probing_id] = {'m_id': message_id, 'gold_cluster': probing_id, 'group_id': group_id}
        gold_map[message_id] = {
            'm_id': message_id,
            'gold_cluster': representative.get(message_id, probing_id),
            'group_id': group_id,
        }
    return gold_map



def resample_split(test_dict, gold_map):
    '''
    Re-derive pair labels from the gold clusters in ``gold_map``.

    For each ``(a, b) -> label`` in ``test_dict``:

    * A label equal to 1 is kept only if ``a`` and ``b`` share a ``'gold_cluster'``;
      otherwise it becomes 0.
    * Any other label is kept only if they do not share a cluster; otherwise it becomes 1.

    For 0/1 labels this amounts to ``int(same gold cluster)``. Returns a new dict and
    leaves ``test_dict`` untouched.
    '''
    relabelled = {}
    for pair, label in test_dict.items():
        is_positive = label == 1
        same_cluster = gold_map[pair[0]]['gold_cluster'] == gold_map[pair[1]]['gold_cluster']
        if is_positive:
            relabelled[pair] = label if same_cluster else 0
        else:
            relabelled[pair] = 1 if same_cluster else label
    return relabelled
            

def generate_split_pairs_labels(dataset, causal_counterpart_map, document_map, gold = False):
    '''
    Build (probingQuestionID, message_id) pairs and binary labels for each split.

    Reads ``{dataset}/positive_samples.csv`` and ``{dataset}/negative_samples.csv``, using
    their ``probingQuestionID`` and ``message_id`` columns. Pairs whose message_id is not a
    key of ``document_map`` are discarded. Each remaining pair goes to the split named by
    ``causal_counterpart_map[probingQuestionID]['set']`` ("Train", "Dev" or "Test"); pairs
    with any other set value are dropped. Positive pairs get label 1, negative pairs 0.

    When ``gold`` is True, negative pairs are kept for the training split only, so dev and
    test contain positive pairs exclusively.

    Returns
    -------
    train_pairs, train_labels, dev_pairs, dev_labels, test_pairs, test_labels

    Raises
    ------
    KeyError
        If a retained probingQuestionID is missing from ``causal_counterpart_map``.
    '''
    pos_pairs = pd.read_csv(f"{dataset}/positive_samples.csv")
    neg_pairs = pd.read_csv(f"{dataset}/negative_samples.csv")

    def _pairs_with_known_message(frame):
        # Drop pairs whose message_id is not in document_map (e.g. '-1' ids).
        return [
            (probing_id, message_id)
            for probing_id, message_id in zip(frame['probingQuestionID'], frame['message_id'])
            if message_id in document_map
        ]

    positive_pairs_cleaned = _pairs_with_known_message(pos_pairs)
    negative_pairs_cleaned = _pairs_with_known_message(neg_pairs)

    all_splits = ('Train', 'Dev', 'Test')
    # In gold mode, negatives are only used for training.
    negative_splits = ('Train',) if gold else all_splits
    # split name -> (pairs, labels)
    buckets = {name: ([], []) for name in all_splits}

    for pairs, label, allowed_splits in (
        (positive_pairs_cleaned, 1, all_splits),
        (negative_pairs_cleaned, 0, negative_splits),
    ):
        for pair in pairs:
            split_name = causal_counterpart_map[pair[0]]['set']
            if split_name in allowed_splits:
                split_pairs, split_labels = buckets[split_name]
                split_pairs.append(pair)
                split_labels.append(label)

    train_pairs, train_labels = buckets['Train']
    dev_pairs, dev_labels = buckets['Dev']
    test_pairs, test_labels = buckets['Test']
    return train_pairs, train_labels, dev_pairs, dev_labels, test_pairs, test_labels


def combine_non_probing_with_probing_map(dataset, wtd = False):
    '''
    Load the generated causal-counterpart map and build a message-level document map.

    Parameters
    ----------
    dataset : str
        Folder containing ``final.pkl``, and also ``final.csv`` when ``wtd`` is True.
    wtd : bool
        When True, rows of ``{dataset}/final.csv`` are merged into the document map,
        overwriting DeliData entries that have the same message_id.

    Returns
    -------
    (generated_causal_counterpart_map, document_map)
        ``document_map`` maps message_id -> message fields. A 'set' key ('Train', 'Dev'
        or 'Test') is added when the message's group_id occurs in that split of the
        causal-counterpart map; Train is checked first, then Dev, then Test.
    '''
    # NOTE: pickle.load can execute arbitrary code; final.pkl must come from a trusted source.
    with open(f'{dataset}/final.pkl', "rb") as f:
        generated_causal_counterpart_map = pickle.load(f)

    split_names = ('Train', 'Dev', 'Test')
    # Group ids present in each split of the causal-counterpart map.
    split_group_ids = {
        split_name: {m['group_id'] for m in generated_causal_counterpart_map.values() if m['set'] == split_name}
        for split_name in split_names
    }

    message_fields = (
        'group_id', 'message_id', 'message_type', 'origin', 'original_text', 'clean_text',
        'annotation_type', 'annotation_target', 'annotation_additional', 'team_performance',
        'performance_change', 'sol_tracker_message', 'sol_tracker_all',
    )
    document_map = {}
    for messages in DeliData().corpus.values():
        for m in messages:
            # Messages whose id is '-1' are skipped.
            if m['message_id'] != '-1':
                document_map[m['message_id']] = {field: m[field] for field in message_fields}

    # Previously `wtd` was unconditionally reset to False here, so this branch never ran.
    if wtd:
        df = pd.read_csv(f'{dataset}/final.csv')
        for _, row in df.iterrows():
            document_map[row['message_id']] = {
                'group_id': row['group_id'],
                'message_id': row['message_id'],
                'origin': row['origin'],
                'original_text': row['original_text'],
                'annotation_type': row['annotation_type'],
            }

    for values in document_map.values():
        for split_name in split_names:
            if values['group_id'] in split_group_ids[split_name]:
                values['set'] = split_name
                break
    return generated_causal_counterpart_map, document_map



def get_probing_causal_counterpart_clusters_non_trainable(split,test_pairs, test_scores_ab, test_scores_ba, gold_map_test, working_folder, threshold = .5):
    '''
    Cluster mentions from pairwise scores and report coreference metrics.

    Parameters
    ----------
    split : str
        Used to name the key files (``probing_gold_{split}.keyfile`` and
        ``probing_system_{split}.keyfile``) and in the printed header.
    test_pairs : sequence of (mention_a, mention_b)
    test_scores_ab, test_scores_ba : sequence of float
        Scores aligned position-wise with ``test_pairs``; a pair's score is their mean.
        Pairs beyond the end of the score sequences receive no score and are left out
        of clustering.
    gold_map_test : dict
        mention id -> dict with a 'gold_cluster' entry.
    working_folder : str
        Directory in which the key files are written.
    threshold : float
        Passed to ``cluster`` as its ``threshold`` argument.

    Returns
    -------
    conf, result_string, final_frame
        CoNLL F1 (mean of MUC, B3 and CEAF-e F1), a LaTeX table row, and the list
        [MUC R, P, F1, B3 R, P, F1, CEAF-e R, P, F1, LEA R, P, F1, CoNLL F1].
    '''
    test_score_map = {}
    for pair, ab, ba in zip(test_pairs, test_scores_ab, test_scores_ba):
        test_score_map[tuple(pair)] = (float(ab), float(ba))

    print("test score map",len(test_score_map) )

    curr_mentions = sorted(gold_map_test.keys())
    curr_gold_cluster_map = [(men, gold_map_test[men]['gold_cluster']) for men in curr_mentions]
    gold_key_file = working_folder + f'/probing_gold_{split}.keyfile'
    generate_key_file(curr_gold_cluster_map, 'evt', working_folder, gold_key_file)

    # Keep pairs and scores aligned: only pairs that actually received a score are clustered.
    scored_pairs = []
    pairwise_scores = []
    for pair in test_pairs:
        key = tuple(pair)
        if key in test_score_map:
            scored_pairs.append(pair)
            pairwise_scores.append(np.mean(test_score_map[key]))

    # Previously this call hard-coded threshold=.5 and ignored the `threshold` argument.
    mid2cluster = cluster(curr_mentions, scored_pairs, pairwise_scores, threshold=threshold)
    system_key_file = working_folder + f'/probing_system_{split}.keyfile'
    generate_key_file(mid2cluster.items(), 'evt', working_folder, system_key_file)
    doc = read(gold_key_file, system_key_file)
    mr, mp, mf = np.round(np.round(evaluate(doc, muc), 3) * 100, 1)
    br, bp, bf = np.round(np.round(evaluate(doc, b_cubed), 3) * 100, 1)
    cr, cp, cf = np.round(np.round(evaluate(doc, ceafe), 3) * 100, 1)
    lr, lp, lf = np.round(np.round(evaluate(doc, lea), 3) * 100, 1)

    conf = np.round((mf + bf + cf) / 3, 1)
    print(working_folder, split)
    final_frame = [mr, mp, mf, br, bp, bf, cr, cp, cf, lr, lp, lf, conf]
    result_string = f'&& {mr}  & {mp} & {mf} && {br} & {bp} & {bf} && {cr} & {cp} & {cf} && {lr} & {lp} & {lf} && {conf} \\'

    print(result_string)
    return conf, result_string, final_frame
def process_clustering_result(dataframe):
    '''
    Turn ``[[epoch, metrics], ...]`` rows into a DataFrame indexed by 'Epoch'.

    ``metrics`` must be a 13-element sequence ordered as MUC R/P/F1, B3 R/P/F1,
    CEAF-e R/P/F1, LEA R/P/F1 and CoNLL F1; each value becomes its own column.
    '''
    metric_columns = [
        'MUC R', 'MUC P', 'MUC F1',
        'B3 R', 'B3 P', 'B3 F1',
        'Ceafe R', 'Ceafe P', 'Ceafe F1',
        'LEA R', 'LEA P', 'LEA F1',
        'CoNLL F1',
    ]
    results = pd.DataFrame(dataframe, columns=['Epoch', 'Metrics'])
    results[metric_columns] = pd.DataFrame(results['Metrics'].tolist(), index=results.index)
    return results.drop(columns=['Metrics']).set_index('Epoch')

def accuracy(predicted_labels, true_labels):
    '''Return the fraction of positions where the prediction equals the true label.

    Pairs are formed with ``zip``; the denominator is ``len(predicted_labels)``.
    Returns 0 when there are no predictions.
    '''
    hits = 0
    for guess, gold in zip(predicted_labels, true_labels):
        if guess == gold:
            hits += 1
    total = len(predicted_labels)
    if total > 0:
        return hits / total
    return 0

def precision(predicted_labels, true_labels):
    '''Return true positives divided by predicted positives, or 0 when nothing is predicted positive.

    A true positive is a position where both prediction and label equal 1; predicted
    positives are ``sum(predicted_labels)``.
    '''
    true_positive_count = 0
    for guess, gold in zip(predicted_labels, true_labels):
        if guess == gold == 1:
            true_positive_count += 1
    predicted_positive_count = sum(predicted_labels)
    if not predicted_positive_count:
        return 0
    return true_positive_count / predicted_positive_count

def recall(predicted_labels, true_labels):
    '''Return true positives divided by actual positives, or 0 when there are no actual positives.

    A true positive is a position where both prediction and label equal 1; actual
    positives are ``sum(true_labels)``.
    '''
    true_positive_count = 0
    for guess, gold in zip(predicted_labels, true_labels):
        if guess == gold == 1:
            true_positive_count += 1
    actual_positive_count = sum(true_labels)
    if not actual_positive_count:
        return 0
    return true_positive_count / actual_positive_count
 
def f1_score(predicted_labels, true_labels):
    '''Return the harmonic mean of ``precision`` and ``recall``; 0 when both are 0.'''
    prec, rec = precision(predicted_labels, true_labels), recall(predicted_labels, true_labels)
    denominator = prec + rec
    # Guard against division by zero when neither metric has any true positives.
    return 0 if denominator == 0 else 2 * (prec * rec) / denominator


def _pair_texts(probing_id, message_id, probing_map, document_map):
    '''Return (probing_text, counterpart_text) for one (probing_id, message_id) pair.

    The probing side is read from ``probing_map[id]['probing_utterance']``, falling back to
    ``document_map[id]['original_text']``. The counterpart side tries ``document_map`` first
    and falls back to ``probing_map``.
    '''
    try:
        probing_text = probing_map[probing_id]['probing_utterance']
    except KeyError:
        probing_text = document_map[probing_id]['original_text']
    try:
        counterpart_text = document_map[message_id]['original_text']
    except KeyError:
        counterpart_text = probing_map[message_id]['probing_utterance']
    return probing_text, counterpart_text


def _score_pairs(pairs, probing_map, document_map, desc):
    '''Compute raw fuzzy-overlap, IoU and BERT-cosine scores for every pair.

    Returns (overlap_scores, iou_scores, cosine_scores, cosine_pairs). Overlap and IoU scores
    are aligned with ``pairs``. A cosine score is produced only when both embeddings are not
    None, so ``cosine_pairs`` lists the pairs that ``cosine_scores`` is aligned with.
    '''
    overlap_scores, iou_scores, cosine_scores, cosine_pairs = [], [], [], []
    for probing_id, message_id in tqdm(pairs, desc=desc):
        probing_text, counterpart_text = _pair_texts(probing_id, message_id, probing_map, document_map)
        probing_embedding = get_bert_embedding(probing_text)
        counterpart_embedding = get_bert_embedding(counterpart_text)
        overlap_scores.append(calculate_fuzzy_score(probing_text, counterpart_text))
        # NOTE(review): wtd=True is hard-coded regardless of the dataset (unchanged) — confirm intent.
        iou_scores.append(calculate_iou(probing_text, counterpart_text, wtd= True))
        if probing_embedding is not None and counterpart_embedding is not None:
            cosine_scores.append(get_cosine_similarity(probing_embedding, counterpart_embedding))
            cosine_pairs.append((probing_id, message_id))
    return overlap_scores, iou_scores, cosine_scores, cosine_pairs


def train_dpos(dataset, model_name=None, trainable = True):
    '''
    Evaluate non-trainable similarity baselines (BERT cosine, fuzzy token overlap, IoU).

    For each baseline, the threshold is the mean score over the dev pairs. Test pairs are
    then scored, thresholded and clustered with
    ``get_probing_causal_counterpart_clusters_non_trainable``. Test pairs and scores are
    also pickled to ``deli_test_*.pkl`` in the current working directory.

    Parameters
    ----------
    dataset : str
        Dataset name, e.g. 'deli_data' or 'wtd_dataset'; 'wtd_dataset' enables the WTD
        document map.
    model_name : unused, kept for interface compatibility.
    trainable : ignored; only the non-trainable baselines are implemented here.

    Returns
    -------
    test_pairs, test_cosine_scores_ab, test_token_overlap_scores_ab, test_iou_overlap_scores_ab,
    test_cosine_predictions, test_token_overlap_predictions, test_iou_predictions,
    dev_cosine_scores_ab, dev_token_overlap_scores_ab, dev_cosine_predictions, dev_overlap_predictions
        Test overlap scores are divided by 100; dev overlap scores are returned raw.
        Cosine scores exist only for pairs whose embeddings were both available.
    '''
    wtd = dataset == 'wtd_dataset'
    print('This is the data', dataset)
    dataset_folder = f'./datasets/{dataset}/'
    probing_map, document_map = combine_non_probing_with_probing_map(dataset = dataset, wtd = wtd)

    gold_map_train, gold_map_dev, gold_map_test = generate_gold_map(dataset = dataset)
    train_pairs, train_labels, _ = generate_pairs_for_train_eval(gold_map_train, split='train')
    dev_pairs, dev_labels, _ = generate_pairs_for_train_eval(gold_map_dev, split = 'dev')
    test_pairs, test_labels, _ = generate_pairs_for_train_eval(gold_map_test, split = 'test')

    print("pairs after gold map creation", len(train_pairs), len(train_labels),len (dev_pairs), len(dev_labels), len(test_pairs), len(test_labels))
    for split_name, labels in (('train', train_labels), ('dev', dev_labels), ('test', test_labels)):
        positives = sum(1 for label in labels if label == 1)
        print(f"positive samples in {split_name}", positives)
        print(f"neg samples in {split_name}", len(labels) - positives)

    # Thresholds for the non-trainable baselines are the mean dev scores.
    (dev_token_overlap_scores_ab, dev_iou_overlap_scores_ab,
     dev_cosine_scores_ab, _) = _score_pairs(dev_pairs, probing_map, document_map, "Generating thresholds for dev")

    overlap_threshold = (sum(dev_token_overlap_scores_ab) / len(dev_token_overlap_scores_ab))/100
    # The threshold is divided by 100 but the raw dev scores are not, so scale them before
    # comparing (previously nearly every dev pair was predicted positive).
    dev_overlap_predictions = [score / 100 > overlap_threshold for score in dev_token_overlap_scores_ab]
    cosine_threhsold = sum(dev_cosine_scores_ab)/ len(dev_cosine_scores_ab)
    dev_cosine_predictions = [score > cosine_threhsold for score in dev_cosine_scores_ab]
    iou_threshold = (sum(dev_iou_overlap_scores_ab) / len(dev_iou_overlap_scores_ab))

    print('This is the cosine threshold ', cosine_threhsold)
    print('\nThis is the overlap threshold ', overlap_threshold)
    print('\nThis is the iou threshold ', iou_threshold)

    (raw_test_overlap_scores, test_iou_overlap_scores_ab,
     test_cosine_scores_ab, test_cosine_pairs) = _score_pairs(test_pairs, probing_map, document_map, "Generating scores for test")
    test_token_overlap_scores_ab = [score / 100 for score in raw_test_overlap_scores]

    test_token_overlap_predictions = [score > overlap_threshold for score in test_token_overlap_scores_ab]
    test_cosine_predictions = [score > cosine_threhsold for score in test_cosine_scores_ab]
    test_iou_predictions = [score > iou_threshold for score in test_iou_overlap_scores_ab]

    # Cosine scores may skip pairs, so cluster them with their own aligned pair list.
    get_probing_causal_counterpart_clusters_non_trainable("test", test_cosine_pairs, test_cosine_scores_ab, test_cosine_scores_ab, gold_map_test, dataset_folder, threshold= cosine_threhsold)
    get_probing_causal_counterpart_clusters_non_trainable("test", test_pairs, test_token_overlap_scores_ab, test_token_overlap_scores_ab, gold_map_test, dataset_folder, threshold= overlap_threshold)
    get_probing_causal_counterpart_clusters_non_trainable("test", test_pairs, test_iou_overlap_scores_ab, test_iou_overlap_scores_ab, gold_map_test, dataset_folder, threshold= iou_threshold)

    print(len(test_pairs))
    for filename, payload in (
        ('deli_test_pairs.pkl', test_pairs),
        ('deli_test_cosine_scores_ab.pkl', test_cosine_scores_ab),
        ('deli_test_token_overlap_scores_ab.pkl', test_token_overlap_scores_ab),
        ('deli_test_iou_overlap_scores_ab.pkl', test_iou_overlap_scores_ab),
    ):
        # Context manager closes each file (previously the handles were leaked).
        with open(filename, 'wb') as handle:
            pickle.dump(payload, handle)

    return test_pairs, test_cosine_scores_ab, test_token_overlap_scores_ab, test_iou_overlap_scores_ab, test_cosine_predictions, test_token_overlap_predictions, test_iou_predictions, dev_cosine_scores_ab, dev_token_overlap_scores_ab, dev_cosine_predictions, dev_overlap_predictions
       

if __name__ == '__main__':
    # The dataset is fixed here; switch to 'wtd_dataset' to run on WTD.
    dataset = 'deli_data'
    trainable = False
    # train_dpos returns the test pairs first; they were previously bound to the
    # misleading name `gold_map_test`.
    (test_pairs, test_cosine_scores_ab, test_token_overlap_scores_ab, test_iou_overlap_scores_ab,
     test_cosine_predictions, test_token_overlap_predictions, test_iou_predictions,
     dev_cosine_scores_ab, dev_token_overlap_scores_ab, dev_cosine_predictions,
     dev_overlap_predictions) = train_dpos(dataset, model_name='allenai/longformer-base-4096', trainable = trainable)
    
   