import pickle
import torch
from helper_probing import tokenize, f1_score, accuracy, precision, recall, cluster, generate_key_file
from helper_testing import tokenize_utterances, forward_ab
from probing_prediction import predict_causal_counterpart
import random
from tqdm import tqdm
import os
from modeling_probing import Causal_Intervention_Scorer
from generate_gold_map import generate_gold_map
from delitoolkit.delidata import DeliData
import pickle
import json 
import pandas as pd
import numpy as np
from coval.coval.conll.reader import get_coref_infos
from coval.coval.eval.evaluator import evaluate_documents as evaluate
from coval.coval.eval.evaluator import muc, b_cubed, ceafe, lea
from collections import defaultdict
from training_utils import *
import heapq

def _load_joint_scorer(dataset_folder, device, device_ids):
    """Build the joint ``Causal_Intervention_Scorer`` from checkpoints in *dataset_folder*.

    Loads the linear head, the probing and causal intervention layers and the
    BERT directory, moves the scorer to *device* and wraps it in
    ``torch.nn.DataParallel`` over *device_ids*.

    Returns:
        The ``DataParallel``-wrapped scorer (the scorer itself is ``.module``).
    """
    linear_weights = torch.load(os.path.join(dataset_folder, "linear.chkpt"))
    # Each checkpoint is loaded into the keyword argument of the same name.
    # NOTE(review): confirm the training script saved the layers under these
    # (misspelled, "intevention") filenames with this probing/causal mapping.
    probing_weights = torch.load(os.path.join(dataset_folder, "probing_intevention_layer.chkpt"))
    causal_weights = torch.load(os.path.join(dataset_folder, "causal_intevention_layer.chkpt"))

    scorer_module = Causal_Intervention_Scorer(
        is_training=False,
        model_name=os.path.join(dataset_folder, "bert"),
        long=True,
        linear_weights=linear_weights,
        joint_model=True,
        probing_weights=probing_weights,
        causal_weights=causal_weights,
    ).to(device)
    parallel_model = torch.nn.DataParallel(scorer_module, device_ids=device_ids)
    parallel_model.module.to(device)
    return parallel_model


def get_joint_model_predictions(dataset='deli_data', window=True, top_k_window=None):
    """Score test-split pairs with the joint probing/causal model and cluster them.

    For every previous-utterance window size in *top_k_window*, the function does
    four things. It generates test pairs, tokenizes them, and scores them with the
    joint scorer loaded from ``./datasets/<dataset>/``. It then thresholds the
    scores at 0.5 and clusters the predictions once without pruning, then once
    per top-k pruning level. Raw scores and pairs are pickled per window into
    ``./datasets/<dataset>/test/``.

    Args:
        dataset: Name of the dataset folder under ``./datasets/``.
        window: Unused; kept so existing callers keep working.
        top_k_window: Iterable of previous-utterance window sizes to evaluate.
            Defaults to ``range(10, 60, 5)``.

    Returns:
        ``(window_results, top_k_results)``, both produced by
        ``process_result_with_top_k_ablations``. The first is built from the
        unpruned per-window clusterings. The second is built from the
        ``(window, k)`` pruned clusterings.
    """
    device = torch.device('cuda:0')
    device_ids = [0]
    split = 'test'
    # Kept with the trailing slash: the downstream helpers receive this string.
    dataset_folder = f'./datasets/{dataset}/'
    score_folder = os.path.join(dataset_folder, split)
    if top_k_window is None:
        top_k_window = range(10, 60, 5)

    parallel_model = _load_joint_scorer(dataset_folder, device, device_ids)
    tokenizer = parallel_model.module.tokenizer

    # NOTE(review): the third value is used as the test-split gold map (the
    # clustering code called it ``gold_map_test``); confirm against
    # generate_gold_map.
    _, _, gold_map, probing_map, document_map = generate_gold_map(dataset=dataset)

    # generate_pairs_for_train_eval, generate_custom_sequence,
    # get_probing_causal_counterpart_clusters(_pruned) and
    # process_result_with_top_k_ablations are not defined in this file;
    # presumably they come from ``from training_utils import *``.
    def get_cluster_scores():
        """Generate, score and cluster test pairs for every window size."""
        utterance_seq_dict_file = os.path.join(dataset_folder, 'deli_with_utterance_2.pkl')
        with open(utterance_seq_dict_file, "rb") as f:
            utterance_sequence_map = pickle.load(f)
        os.makedirs(score_folder, exist_ok=True)

        # Both row lists use one [key, key, frame] shape so that each can be
        # passed to process_result_with_top_k_ablations.
        window_rows = []  # ["probing", window_size, frame]
        top_k_rows = []   # [window_size, k, frame]

        for window_size in top_k_window:
            test_pairs, _, causal_probing_label_test, _ = generate_pairs_for_train_eval(
                gold_map, utterance_sequence_map, split=split, previous_window=window_size)
            test_ab, test_ba = tokenize_utterances(
                tokenizer, test_pairs, probing_map, document_map,
                parallel_model.module.start_id, parallel_model.module.end_id,
                max_sentence_len=512)

            with torch.no_grad():
                # NOTE(review): return order (scores, _, probing, causal, _, _)
                # follows the (ab, ba) call site; confirm in probing_prediction.
                scores, _, probing_scores, causal_scores, _, _ = predict_causal_counterpart(
                    parallel_model, test_ab, test_ba, device, batch_size=300, joint_model=True)

            predictions = torch.squeeze(scores > 0.5)
            # Only the probing scores are passed through a sigmoid; the joint
            # and causal scores are used as returned.
            probing_scores = torch.sigmoid(probing_scores)

            # Pairs and scores depend on the window, so every file is tagged with it.
            to_save = (
                ('scores', scores),
                ('pairs', test_pairs),
                ('causal_scores', causal_scores),
                ('probing_scores', probing_scores),
            )
            for name, obj in to_save:
                path = os.path.join(score_folder, f'{split}_{name}_{window_size}.pkl')
                with open(path, 'wb') as f:
                    pickle.dump(obj, f)

            # Unpruned clustering for this window.
            _, _, final_frame = get_probing_causal_counterpart_clusters(
                split, test_pairs, predictions, predictions, gold_map, dataset_folder)
            window_rows.append(["probing", window_size, final_frame])

            # Pruned clustering: k goes from 500 up to the number of pairs in steps of 10.
            top_k = generate_custom_sequence([(500, len(predictions), 10)])
            for k in top_k:
                _, _, pruned_frame = get_probing_causal_counterpart_clusters_pruned(
                    split, test_pairs, predictions, predictions, causal_scores, probing_scores,
                    gold_map, causal_probing_label_test, dataset_folder, k)
                top_k_rows.append([window_size, k, pruned_frame])

        top_k_results = process_result_with_top_k_ablations(top_k_rows)
        window_results = process_result_with_top_k_ablations(window_rows)
        return window_results, top_k_results

    window_results, top_k_results = get_cluster_scores()
    return window_results, top_k_results


if __name__ == '__main__':
    # Run the window / top-k ablation on the DeliData test split and save both
    # result tables as CSV under ./datasets/<dataset>/test/.
    dataset = 'deli_data'
    split = "test"
    output_folder = os.path.join(f'./datasets/{dataset}/', split)
    # to_csv does not create missing directories, so create the output folder first.
    os.makedirs(output_folder, exist_ok=True)

    window_results, top_k_results = get_joint_model_predictions(dataset=dataset, window=True)
    window_results.to_csv(os.path.join(output_folder, "final_predictions_with_window.csv"))
    top_k_results.to_csv(os.path.join(output_folder, "final_predictions_top_k.csv"))
   
    
    