
from collections import defaultdict
import pickle
import torch
from helper_probing import tokenize, f1_score, accuracy, precision, recall, cluster, generate_key_file
from helper_deving import tokenize_utterances, forward_ab, tokenize_wtd_utterances
from probing_prediction import predict_causal_counterpart
import random
from tqdm import tqdm
import os
from modeling_probing import Causal_Intervention_Scorer
from generate_gold_map_wtd import generate_gold_map
from delitoolkit.delidata import DeliData
import pickle
import json 
import pandas as pd
import numpy as np
import heapq
from coval.coval.conll.reader import get_coref_infos
from coval.coval.eval.evaluator import evaluate_documents as evaluate
from coval.coval.eval.evaluator import muc, b_cubed, ceafe, lea
 

from collections import defaultdict

def read(key, response):
    '''
    Load the gold (key) and system (response) coreference files via coval's
    ``get_coref_infos`` and return whatever it returns.

    Both arguments are rendered with ``'%s' %`` before being handed over
    (presumably file paths -- confirm with callers).

    The three trailing positional flags are False, False, True. Per coval's
    ``get_coref_infos`` signature these are presumably NP_only, remove_nested
    and keep_singletons -- verify against the installed coval version.
    '''
    flags = (False, False, True)
    key_path = '%s' % key
    response_path = '%s' % response
    return get_coref_infos(key_path, response_path, *flags)


def generate_pairs_for_train_eval_relaxed(gold_map, split):
    '''
    Generate every within-group antecedent pair of causal and probing
    interventions for ``split`` (no utterance window W is applied), with
    binary same-cluster labels taken from the gold labels in ``gold_map``.

    Parameters
    ----------
    gold_map : dict
        Maps intervention id -> dict with at least 'split', 'group_id' and
        'gold_cluster' entries.
    split : str
        Only interventions whose 'split' equals this value are used.

    Returns
    -------
    intervention_pairs : list of (later_id, earlier_id)
        For every group, each intervention paired with every intervention
        that precedes it in sorted-id order.
    intervention_labels : list of int
        1 when both members share the same 'gold_cluster', else 0; aligned
        with ``intervention_pairs``.
    causal_probing_label : dict
        Maps each id of the split to 'probing' if it contains an underscore,
        otherwise 'causal'.
    '''
    split_intervention_ids = sorted(
        m_id for m_id, m in gold_map.items() if m['split'] == split
    )

    group_to_interventions = defaultdict(list)
    for intervention_id in split_intervention_ids:
        group_to_interventions[gold_map[intervention_id]['group_id']].append(intervention_id)

    intervention_pairs = []
    intervention_labels = []
    for interventions in group_to_interventions.values():
        for i, current in enumerate(interventions):
            # Pair with every earlier intervention of the same group (j < i).
            for antecedent in interventions[:i]:
                intervention_pairs.append((current, antecedent))
                intervention_labels.append(
                    int(gold_map[current]['gold_cluster'] == gold_map[antecedent]['gold_cluster'])
                )

    # Ids containing "_" are probing interventions; all others are causal.
    causal_probing_label = {
        intervention: 'probing' if '_' in intervention else 'causal'
        for intervention in split_intervention_ids
    }
    print(f" {split} final pairs", len(intervention_pairs), len(intervention_labels), len(causal_probing_label))
    return intervention_pairs, intervention_labels, causal_probing_label
 

def generate_pairs_for_train_eval(gold_map, utterance_sequence_map, split, previous_window = None):
    '''
    Generate antecedent pairs of causal and probing interventions whose
    utterances lie within a window W of each other, with binary same-cluster
    labels taken from the gold labels in ``gold_map``.

    Parameters
    ----------
    gold_map : dict
        Maps intervention id -> dict with at least 'split', 'group_id' and
        'gold_cluster' entries.
    utterance_sequence_map : dict
        Maps intervention id -> dict whose 'utterance_id' is convertible
        with ``int``.
    split : str
        Only interventions whose 'split' equals this value are used.
    previous_window : int or None
        A pair is kept when the absolute difference of the two utterance ids
        is <= ``previous_window``. ``None`` means no limit: every
        within-group pair is kept. Before this fix, ``None`` raised TypeError.

    Returns
    -------
    intervention_pairs : list of (later_id, earlier_id)
        Within-group pairs (earlier = preceding in sorted-id order) that fall
        inside the window.
    intervention_labels : list of int
        1 when both members share the same 'gold_cluster', else 0.
    causal_probing_label : dict
        Maps each id of the split to 'probing' if it contains an underscore,
        otherwise 'causal'.
    zero_list : list of (later_id, earlier_id, label)
        Kept pairs whose two utterance ids are equal. Only collected when the
        window is positive (or None).
    '''
    split_intervention_ids = sorted(
        m_id for m_id, m in gold_map.items() if m['split'] == split
    )

    group_to_interventions = defaultdict(list)
    for intervention_id in split_intervention_ids:
        group_to_interventions[gold_map[intervention_id]['group_id']].append(intervention_id)

    intervention_pairs = []
    intervention_labels = []
    zero_list = []
    for interventions in group_to_interventions.values():
        for i, current in enumerate(interventions):
            for antecedent in interventions[:i]:
                d1 = int(utterance_sequence_map[current]['utterance_id'])
                d2 = int(utterance_sequence_map[antecedent]['utterance_id'])
                diff = abs(d2 - d1)
                if previous_window is not None and diff > previous_window:
                    continue
                label = int(gold_map[current]['gold_cluster'] == gold_map[antecedent]['gold_cluster'])
                # Same-utterance pairs. This mirrors `diff < previous_window and diff == 0`,
                # so nothing is collected for a window of 0.
                if diff == 0 and (previous_window is None or previous_window > 0):
                    zero_list.append((current, antecedent, label))
                intervention_pairs.append((current, antecedent))
                intervention_labels.append(label)

    # Ids containing "_" are probing interventions; all others are causal.
    causal_probing_label = {
        intervention: 'probing' if '_' in intervention else 'causal'
        for intervention in split_intervention_ids
    }
    print(f" {split} final pairs", len(intervention_pairs), len(intervention_labels), len(causal_probing_label))
    return intervention_pairs, intervention_labels, causal_probing_label, zero_list
 

def remove_cc_and_pp_pairs(pairs, labels, causal_probing_label_map):
    '''
    Print statistics for causal-causal (CC) and probing-probing (PP) pairs and
    drop the negative (label 0) CC and PP pairs, for our second baseline.

    Positive CC/PP pairs and mixed causal-probing pairs are always kept.

    Parameters
    ----------
    pairs : list of (id_a, id_b) tuples
    labels : list of int
        Binary label per pair (1 = same gold cluster), aligned with ``pairs``.
    causal_probing_label_map : dict
        Maps every intervention id to 'causal' or 'probing'.

    Returns
    -------
    pairs, labels : list, list
        Remaining pairs and their labels. Duplicate pairs are collapsed, with
        the last label winning, exactly as ``dict(zip(pairs, labels))`` does.
    '''
    causal = 'causal'
    probing = 'probing'
    causal_only_pos, causal_only_neg = [], []
    probing_only_pos, probing_only_neg = [], []
    for (x, y), z in zip(pairs, labels):
        x_kind = causal_probing_label_map[x]
        y_kind = causal_probing_label_map[y]
        if x_kind == causal and y_kind == causal:
            (causal_only_neg if z == 0 else causal_only_pos).append((x, y, z))
        elif x_kind == probing and y_kind == probing:
            (probing_only_neg if z == 0 else probing_only_pos).append((x, y, z))
    print("POS CC and PP", len(causal_only_pos), len(probing_only_pos))
    print("NEG CC and PP", len(causal_only_neg), len(probing_only_neg))

    # Test membership against both negative sets explicitly. The former
    # `x not in a.keys() and b.keys()` parsed as `(x not in a) and bool(b)`,
    # which never removed PP negatives and removed everything when b was empty.
    negatives = {(x, y) for x, y, _ in causal_only_neg + probing_only_neg}
    original_pair_dict = {
        pair: label for pair, label in zip(pairs, labels) if pair not in negatives
    }

    n_pos = sum(v for v in original_pair_dict.values() if v == 1)
    print("Remaining POS and NEG pairs", n_pos, len(original_pair_dict) - n_pos)

    return list(original_pair_dict.keys()), list(original_pair_dict.values())




def get_probing_causal_counterpart_clusters_pruned(split, dev_pairs, dev_scores_ab, dev_scores_ba, causal_scores,
        probing_scores, gold_map_dev, causal_probing_label_map, working_folder, k):
    '''
    Form clusters (deliberation chains) from pruned pairwise scores and
    evaluate them against the gold clusters.

    Pruning keeps only pairs that contain at least one causal intervention,
    then the ``k`` pairs with the largest ``causal_scores``. The surviving pairs
    are passed straight to ``cluster``, so no additional pairing or pairwise
    scoring is needed.

    Parameters
    ----------
    split : str
        Split name, used in the key-file names and the printed header.
    dev_pairs : list of (id_a, id_b)
    dev_scores_ab : list of float
        Pairwise scores aligned with ``dev_pairs``.
    dev_scores_ba : list of float
        Not used after pruning (see the note in the body).
    causal_scores : list of float
        Ranking scores handed to ``top_k_paired_elements``.
    probing_scores :
        Unused.
    gold_map_dev : dict
        Maps intervention id -> dict with a 'gold_cluster' entry.
    causal_probing_label_map : dict
        Maps intervention id -> 'causal' / 'probing'.
    working_folder : str
        Key files are written under ``<working_folder>/probing_scores/``.
    k : int
        Number of pairs kept after pruning.

    Returns
    -------
    conf : float
        Mean of the MUC, B-cubed and CEAF-e F1 values (already scaled by 100).
    result_string : str
        LaTeX-style table row with all metrics.
    final_frame : list
        [MUC R, P, F1, B3 R, P, F1, CEAF-e R, P, F1, LEA R, P, F1, CoNLL F1].
    '''
    dev_pairs, dev_scores_ab = retrieve_causal_containing_pairs(dev_pairs, dev_scores_ab, causal_probing_label_map)
    # NOTE(review): causal_scores is zipped against the *filtered* pairs, so it is
    # assumed to already be aligned with the causal-containing pairs -- confirm with caller.
    _, dev_pairs, dev_scores_ab = top_k_paired_elements(dev_pairs, causal_scores, dev_scores_ab, k=k, mode='max')
    # dev_scores_ba was neither filtered nor pruned above, so it no longer lines
    # up with dev_pairs; the pruned ab scores are reused for both directions.
    dev_scores_ba = dev_scores_ab

    dev_score_map = {}
    for b, ab, ba in zip(dev_pairs, dev_scores_ab, dev_scores_ba):
        dev_score_map[tuple(b)] = (float(ab), float(ba))

    curr_mentions = sorted(gold_map_dev.keys())
    curr_gold_cluster_map = [(men, gold_map_dev[men]['gold_cluster']) for men in curr_mentions]
    gold_key_file = working_folder + f'/probing_scores/probing_gold_{split}.keyfile'
    generate_key_file(curr_gold_cluster_map, 'evt', working_folder, gold_key_file)

    # Look scores up with the same tuple key used to build the map.
    pairwise_scores = [np.mean(dev_score_map[tuple(p)]) for p in dev_pairs if tuple(p) in dev_score_map]

    mid2cluster = cluster(curr_mentions, dev_pairs, pairwise_scores, threshold=0.5) #checking if theoretically a perfect clustering can be achieved
    system_key_file = working_folder + f'/probing_scores/probing_system_{split}.keyfile'
    generate_key_file(mid2cluster.items(), 'evt', working_folder, system_key_file)
    doc = read(gold_key_file, system_key_file)
    mr, mp, mf = np.round(np.round(evaluate(doc, muc), 3) * 100, 1)
    br, bp, bf = np.round(np.round(evaluate(doc, b_cubed), 3) * 100, 1)
    cr, cp, cf = np.round(np.round(evaluate(doc, ceafe), 3) * 100, 1)
    lr, lp, lf = np.round(np.round(evaluate(doc, lea), 3) * 100, 1)

    # CoNLL F1: average of the MUC, B-cubed and CEAF-e F1 values.
    conf = np.round((mf + bf + cf) / 3, 1)
    print(working_folder, split)
    final_frame = [mr, mp, mf, br, bp, bf, cr, cp, cf, lr, lp, lf, conf]
    result_string = f'&& {mr}  & {mp} & {mf} && {br} & {bp} & {bf} && {cr} & {cp} & {cf} && {lr} & {lp} & {lf} && {conf} \\'

    print(result_string)

    return conf, result_string, final_frame



def retrieve_causal_containing_pairs(pairs, pairwise_scores, causal_probing_label_map ):
    '''
    Return the pairs, and their scores, that contain at least one causal
    intervention.

    Parameters
    ----------
    pairs : list of hashable (id_a, id_b) pairs
    pairwise_scores : list
        Score per pair, aligned with ``pairs``.
    causal_probing_label_map : dict
        Maps intervention id -> 'causal' / 'probing'.

    Returns
    -------
    pairs, scores : list, list
        The surviving pairs in first-occurrence order. Duplicate pairs are
        collapsed, keeping the last score, as ``dict(zip(...))`` does.
    '''
    label = 'causal'
    score_by_pair = dict(zip(pairs, pairwise_scores))
    kept = {
        pair: score
        for pair, score in score_by_pair.items()
        if causal_probing_label_map[pair[0]] == label or causal_probing_label_map[pair[1]] == label
    }
    return list(kept.keys()), list(kept.values())



def top_k_paired_elements(pairs, causal_scores, pairwise_scores, k, mode='max'):
    """
    Select the k pairs with the highest (mode='max') or lowest (mode='min')
    causal score, together with their pairwise scores.

    The three inputs are zipped position by position, so the shortest one
    bounds the result. Ties keep their input order (``heapq.nlargest`` /
    ``heapq.nsmallest`` semantics).

    Returns (selected causal scores, selected pairs, selected pairwise scores),
    each a list ordered best-first. Raises ValueError for any other mode.
    """
    triples = list(zip(causal_scores, pairs, pairwise_scores))

    if mode == 'max':
        select = heapq.nlargest
    elif mode == 'min':
        select = heapq.nsmallest
    else:
        raise ValueError("Mode should be either 'max' or 'min'")

    chosen = select(k, triples, key=lambda triple: triple[0])

    selected_causal = [causal for causal, _, _ in chosen]
    selected_pairs = [pair for _, pair, _ in chosen]
    selected_pairwise = [pairwise for _, _, pairwise in chosen]
    return selected_causal, selected_pairs, selected_pairwise


def get_probing_causal_counterpart_clusters(split,dev_pairs, dev_scores_ab, dev_scores_ba, gold_map_dev, working_folder ):
    '''
    Form clusters from pairwise scores and evaluate them against the gold
    clusters.

    Each pair's score is the mean of its ab and ba scores. Gold and system key
    files are written under ``<working_folder>/probing_scores/`` (``_3``
    suffix). Clustering uses ``cluster(..., threshold=0.5)`` and the result is
    scored with MUC, B-cubed, CEAF-e and LEA.

    Parameters
    ----------
    split : str
        Split name, used in the key-file names and the printed header.
    dev_pairs : list of (id_a, id_b)
        Pairs may be tuples or lists; lookups use ``tuple(pair)``.
    dev_scores_ab, dev_scores_ba : list of float
        Directional scores aligned with ``dev_pairs``.
    gold_map_dev : dict
        Maps intervention id -> dict with a 'gold_cluster' entry.
    working_folder : str

    Returns
    -------
    conf : float
        Mean of the MUC, B-cubed and CEAF-e F1 values (already scaled by 100).
    result_string : str
        LaTeX-style table row with all metrics.
    final_frame : list
        [MUC R, P, F1, B3 R, P, F1, CEAF-e R, P, F1, LEA R, P, F1, CoNLL F1].
    '''
    dev_score_map = {}
    for b, ab, ba in zip(dev_pairs, dev_scores_ab, dev_scores_ba):
        dev_score_map[tuple(b)] = (float(ab), float(ba))

    curr_mentions = sorted(gold_map_dev.keys())
    curr_gold_cluster_map = [(men, gold_map_dev[men]['gold_cluster']) for men in curr_mentions]
    gold_key_file = working_folder + f'/probing_scores/probing_gold_{split}_3.keyfile'
    generate_key_file(curr_gold_cluster_map, 'evt', working_folder, gold_key_file)

    # Use the same tuple key for the lookup as for the membership test. Indexing
    # with a raw list/array pair raised "unhashable type".
    pairwise_scores = [np.mean(dev_score_map[tuple(p)]) for p in dev_pairs if tuple(p) in dev_score_map]

    mid2cluster = cluster(curr_mentions, dev_pairs, pairwise_scores, threshold=0.5) #checking if theoretically a perfect clustering can be achieved
    system_key_file = working_folder + f'/probing_scores/probing_system_{split}_3.keyfile'
    generate_key_file(mid2cluster.items(), 'evt', working_folder, system_key_file)
    doc = read(gold_key_file, system_key_file)
    mr, mp, mf = np.round(np.round(evaluate(doc, muc), 3) * 100, 1)
    br, bp, bf = np.round(np.round(evaluate(doc, b_cubed), 3) * 100, 1)
    cr, cp, cf = np.round(np.round(evaluate(doc, ceafe), 3) * 100, 1)
    lr, lp, lf = np.round(np.round(evaluate(doc, lea), 3) * 100, 1)

    # CoNLL F1: average of the MUC, B-cubed and CEAF-e F1 values.
    conf = np.round((mf + bf + cf) / 3, 1)
    print(working_folder, split)
    final_frame = [mr, mp, mf, br, bp, bf, cr, cp, cf, lr, lp, lf, conf]
    result_string = f'&& {mr}  & {mp} & {mf} && {br} & {bp} & {bf} && {cr} & {cp} & {cf} && {lr} & {lp} & {lf} && {conf} \\'

    print(result_string)
    return conf, result_string, final_frame



def generate_custom_sequence(ranges_steps):
    """
    Build a sorted, duplicate-free list of integers from several
    ``(start, stop, step)`` segments.

    Each segment contributes ``range(start, stop + 1, step)``, so ``stop``
    is inclusive whenever the step lands on it.
    """
    unique_numbers = {
        number
        for start, stop, step in ranges_steps
        for number in range(start, stop + 1, step)
    }
    return sorted(unique_numbers)


def find_max_conf_details(data_list):
    """
    Return details of the row whose 'conf' value (position 4) is largest.

    Rows are assumed to be laid out as ``[split, sampling, n, ..., conf, ...]``.
    On ties the first such row wins. An empty ``data_list`` raises ValueError
    (from ``max``).

    Returns a dict with keys 'split', 'sampling', 'n', 'max_conf' and 'index'.
    """
    conf_column = [row[4] for row in data_list]
    best_conf = max(conf_column)
    best_pos = conf_column.index(best_conf)
    best_row = data_list[best_pos]
    return dict(
        split=best_row[0],
        sampling=best_row[1],
        n=best_row[2],
        max_conf=best_conf,
        index=best_pos,
    )

def get_flat_list(data):
    """
    Flatten ``data`` by up to two levels of ``list`` nesting.

    Top-level non-list items are kept as-is. A top-level list contributes its
    items, and any of those items that are themselves lists are expanded one
    more level. Deeper lists, and non-list containers such as tuples, are
    left intact.
    """
    def _leaves(element):
        # Yield what one top-level element contributes to the flat result.
        if not isinstance(element, list):
            yield element
            return
        for sub in element:
            if isinstance(sub, list):
                yield from sub
            else:
                yield sub

    return [leaf for element in data for leaf in _leaves(element)]

 


def process_clustering_result(dataframe):
    '''
    Turn ``[epoch, metrics]`` rows into a per-epoch metrics DataFrame.

    Parameters
    ----------
    dataframe : list of [epoch, metrics]
        ``metrics`` is a 13-value sequence: MUC R/P/F1, B3 R/P/F1,
        CEAF-e R/P/F1, LEA R/P/F1 and CoNLL F1.

    Returns
    -------
    pandas.DataFrame
        Indexed by 'Epoch', with one column per metric.
    '''
    metric_columns = ['MUC R', 'MUC P', 'MUC F1', 'B3 R', 'B3 P', 'B3 F1', 'Ceafe R', 'Ceafe P', 'Ceafe F1', 'LEA R', 'LEA P', 'LEA F1', 'CoNLL F1']

    df = pd.DataFrame(dataframe, columns=['Epoch', 'Metrics'])

    # Expand the per-row metric lists into one column per metric.
    metrics = pd.DataFrame(df['Metrics'].tolist(), index=df.index, columns=metric_columns)
    df = pd.concat([df.drop(columns=['Metrics']), metrics], axis=1)

    # The frame has no 'Split' column, so the former set_index('Split') always
    # raised KeyError; index on 'Epoch' as intended.
    df.set_index('Epoch', inplace=True)

    return df


def process_result_with_top_k(nested_result):
    '''
    Build a results DataFrame from rows shaped like
    ``[split, epoch, top_k, metrics]``.

    ``metrics`` is normally the 13-value list [MUC R, P, F1, B3 R, P, F1,
    CEAF-e R, P, F1, LEA R, P, F1, CoNLL F1]. A list or tuple is used
    element-wise; any other value is treated as a one-element list. Short
    metric lists are padded with None and long ones are truncated. The
    caller's list is never modified.

    Returns
    -------
    pandas.DataFrame
        Columns: Split, Epoch, Top_K followed by the 13 metric columns.
    '''
    id_columns = ["Split", "Epoch", "Top_K"]
    metric_columns = ['MUC R', 'MUC P', 'MUC F1', 'B3 R', 'B3 P', 'B3 F1', 'Ceafe R', 'Ceafe P', 'Ceafe F1', 'LEA R', 'LEA P', 'LEA F1', 'CoNLL F1']
    n_ids = len(id_columns)
    # Derive the metric count from the column list. The former hard-coded
    # `len(columns) - 4` dropped one metric and broke the DataFrame shape.
    n_metrics = len(metric_columns)

    data = []
    for item in nested_result:
        identifiers = list(item[:n_ids])
        metrics = item[n_ids]
        # Copy (or wrap) so that padding never mutates the caller's list.
        metrics = list(metrics) if isinstance(metrics, (list, tuple)) else [metrics]
        metrics = (metrics + [None] * n_metrics)[:n_metrics]
        data.append(identifiers + metrics)

    return pd.DataFrame(data, columns=id_columns + metric_columns)

def process_result(nested_result):
    '''
    Build a results DataFrame from rows shaped like
    ``[id_0, id_1, id_2, metrics]``.

    Each output row is the first three elements of the item followed by the
    elements of ``item[3]``. Rows and metric sequences may be lists or tuples.

    NOTE(review): each row carries three identifier values (e.g. 'dev', 2, 0),
    but only two identifier columns ("Split", "Epoch") are declared. The
    13-value metric lists built by the clustering helpers in this module
    therefore give 16 values for 15 columns; only 12-value metric lists fit.
    Confirm the intended name of the third identifier column.

    Raises
    ------
    ValueError
        If a row has more values than there are columns.
    '''
    columns = ["Split","Epoch", 'MUC R', 'MUC P', 'MUC F1', 'B3 R', 'B3 P', 'B3 F1', 'Ceafe R', 'Ceafe P', 'Ceafe F1', 'LEA R', 'LEA P', 'LEA F1', 'CoNLL F1']

    data = []
    for item in nested_result:
        # list() on both parts so that tuple rows / tuple metrics concatenate cleanly.
        full_row = list(item[:3]) + list(item[3])
        if len(full_row) > len(columns):
            raise ValueError(
                f"process_result: row has {len(full_row)} values but only "
                f"{len(columns)} columns are defined: {columns}"
            )
        data.append(full_row)

    return pd.DataFrame(data, columns=columns)

def process_result_with_top_k_ablations(nested_result):
    '''
    Build an ablation results DataFrame from rows shaped like
    ``[window, top_k, metrics]``.

    ``metrics`` is normally the 13-value list [MUC R, P, F1, B3 R, P, F1,
    CEAF-e R, P, F1, LEA R, P, F1, CoNLL F1]. A list or tuple is used
    element-wise; any other value is treated as a one-element list. Short
    metric lists are padded with None and long ones are truncated. The
    caller's list is never modified.

    Returns
    -------
    pandas.DataFrame
        Columns: Window, Top_K followed by the 13 metric columns.
    '''
    id_columns = ["Window", "Top_K"]
    metric_columns = ['MUC R', 'MUC P', 'MUC F1', 'B3 R', 'B3 P', 'B3 F1', 'Ceafe R', 'Ceafe P', 'Ceafe F1', 'LEA R', 'LEA P', 'LEA F1', 'CoNLL F1']
    n_ids = len(id_columns)
    n_metrics = len(metric_columns)

    data = []
    for item in nested_result:
        identifiers = list(item[:n_ids])
        metrics = item[n_ids]
        # Copy (or wrap) so that padding never mutates the caller's list;
        # the former `metrics += ...` extended it in place.
        metrics = list(metrics) if isinstance(metrics, (list, tuple)) else [metrics]
        metrics = (metrics + [None] * n_metrics)[:n_metrics]
        data.append(identifiers + metrics)

    return pd.DataFrame(data, columns=id_columns + metric_columns)