# generate the pairs and start the training: for joint modelling 

import pickle
import torch
from helper_probing import tokenize, f1_score, accuracy, precision, recall, cluster, generate_key_file
from helper_deving import tokenize_utterances, forward, tokenize_wtd_utterances
from probing_prediction import predict_causal_counterpart
import random
from tqdm import tqdm
import os
from modeling_probing import Causal_Intervention_Scorer
from generate_gold_map_wtd import generate_gold_map
from delitoolkit.delidata import DeliData
import pickle
import json 
import pandas as pd
import numpy as np
import heapq
from coval.coval.conll.reader import get_coref_infos
from coval.coval.eval.evaluator import evaluate_documents as evaluate
from coval.coval.eval.evaluator import muc, b_cubed, ceafe, lea
from training_utils import *
import heapq
from collections import defaultdict


### Training script for the Joint Model. ####


def train_deliberation_model(dataset, model_name=None, joint_model = None, previous_window = None):
    """Build train/dev pairs for ``dataset``, train the scorer and return dev results.

    Args:
        dataset: Name of the dataset; data is read from ``./datasets/{dataset}/``.
        model_name: Forwarded to ``Causal_Intervention_Scorer``.
        joint_model: Forwarded to ``Causal_Intervention_Scorer`` and ``train``.
        previous_window: Falsy -> pairs are generated without an utterance
            window. Truthy -> the utterance-sequence pickle is loaded and pairs
            are generated with a window. An integer is used as the window
            size; ``True`` keeps the historical size of 18.

    Returns:
        The DataFrame obtained by concatenating ``process_result`` applied to
        the dev output that ``train`` returns in fourth position.
    """
    final_exp_results_dev = []
    dataset_folder = f'./datasets/{dataset}/'
    device = torch.device('cuda:0')
    device_ids = list(range(1))

    gold_map_train, gold_map_dev, _, probing_map, document_map = generate_gold_map(dataset = dataset)

    if previous_window:
        # Previously the window was hard-coded to 18 regardless of the argument;
        # honour an explicit integer and keep 18 for the boolean-flag usage.
        window = 18 if previous_window is True else previous_window

        utterance_seq_dict_file = f'{dataset_folder}/wtd_with_utterance_new.pkl'
        # NOTE: pickle.load executes arbitrary code; only load trusted files.
        with open(utterance_seq_dict_file, "rb") as f:
            utterance_sequence_map = pickle.load(f)

        train_pairs, train_labels, _, _ = generate_pairs_for_train_eval(
            gold_map_train, utterance_sequence_map, split='train', previous_window=window)
        dev_pairs, dev_labels, causal_probing_label_dev, _ = generate_pairs_for_train_eval(
            gold_map_dev, utterance_sequence_map, split='dev', previous_window=window)
    else:
        train_pairs, train_labels, _ = generate_pairs_for_train_eval(gold_map_train, split='train')
        dev_pairs, dev_labels, causal_probing_label_dev = generate_pairs_for_train_eval(gold_map_dev, split='dev')

    scorer_module = Causal_Intervention_Scorer(
        is_training=True, long=True, model_name=model_name, joint_model=joint_model).to(device)

    parallel_model = torch.nn.DataParallel(scorer_module, device_ids=device_ids)
    parallel_model.module.to(device)

    # Call the training loop. Only the fourth returned value (dev output) is
    # consumed here; the remaining five are ignored.
    _, _, _, conf_dev, _, _ = train(
        train_pairs, train_labels, dev_pairs, dev_labels,
        parallel_model, probing_map, document_map, gold_map_dev, dataset_folder, device,
        causal_probing_label_dev,
        batch_size=24, n_iters=16, lr_lm=0.000001, lr_class=0.0001, lr_p_c=0.00001,
        joint_model=joint_model)

    final_frame_df_dev = process_result(conf_dev)
    final_exp_results_dev.append(final_frame_df_dev)

    results_dev = pd.concat(final_exp_results_dev, ignore_index=True)
    return results_dev

def compute_joint_modeling_loss(scores_tuple, batch_labels, joint_model = True,
                                probing_weight = 0.01, causal_weight = 0.01):
    """Return the training loss for one batch.

    Args:
        scores_tuple: ``(pairwise_scores, probing_scores, causal_scores)``.
            ``pairwise_scores`` are fed to ``BCELoss`` unchanged, so they must
            already lie in [0, 1]; the probing and causal scores are passed
            through a sigmoid first.
        batch_labels: Float target tensor. In joint mode it must have the same
            shape as the scores; in baseline mode it is viewed as ``(-1, 1)``.
        joint_model: When truthy, add the weighted probing and causal BCE
            terms to the pairwise BCE; otherwise return the pairwise BCE only.
        probing_weight: Weight of the probing BCE term (default 0.01).
        causal_weight: Weight of the causal BCE term (default 0.01).

    Returns:
        A scalar loss tensor.
    """
    bce_loss = torch.nn.BCELoss()
    pairwise_scores, probing_scores, causal_scores = scores_tuple

    if not joint_model:
        # Baseline: pairwise objective only. The original branch referenced an
        # undefined name and computed the same loss twice.
        return bce_loss(pairwise_scores, batch_labels.view(-1, 1))

    pairwise_loss = bce_loss(pairwise_scores, batch_labels)
    # Sigmoided for the loss, unlike the raw scores used for pruning.
    probing_loss = bce_loss(torch.sigmoid(probing_scores), batch_labels)
    causal_loss = bce_loss(torch.sigmoid(causal_scores), batch_labels)

    # Return the combined objective: previously only the pairwise term was
    # returned, so the auxiliary terms never contributed any gradient.
    return pairwise_loss + probing_weight * probing_loss + causal_weight * causal_loss

            

def _save_scorer_checkpoint(parallel_model, checkpoint_folder):
    """Save the scorer heads, encoder and tokenizer under ``checkpoint_folder``."""
    os.makedirs(checkpoint_folder, exist_ok=True)
    scorer = parallel_model.module
    torch.save(scorer.linear.state_dict(), checkpoint_folder + '/linear.chkpt')
    torch.save(scorer.probing_intevention_layer.state_dict(),
               checkpoint_folder + '/probing_intevention_layer.chkpt')
    torch.save(scorer.causal_intevention_layer.state_dict(),
               checkpoint_folder + '/causal_intevention_layer.chkpt')
    scorer.model.save_pretrained(checkpoint_folder + '/bert')
    scorer.tokenizer.save_pretrained(checkpoint_folder + '/bert')


def train(train_pairs,
          train_labels,
          dev_pairs,
          dev_labels,
          parallel_model,
          probing_map,
          document_map,
          gold_map_dev,
          working_folder,
          device,
          causal_probing_label_dev,
          batch_size=24,
          n_iters=16,
          lr_lm=0.00001,
          lr_class=0.0001,
          lr_p_c = 0.00001, joint_model = None):
    """Train the scorer and evaluate it on the dev pairs after every iteration.

    Args:
        train_pairs, train_labels: Training pairs and their labels; the labels
            are converted to a float tensor.
        dev_pairs: Dev pairs that are tokenized, scored and clustered.
        dev_labels: Accepted for interface compatibility; not used here.
        parallel_model: ``DataParallel`` wrapper whose ``.module`` exposes
            ``model``, ``linear``, ``probing_intevention_layer``,
            ``causal_intevention_layer``, ``tokenizer``, ``start_id`` and
            ``end_id``.
        probing_map, document_map: Forwarded to ``tokenize_wtd_utterances``.
        gold_map_dev, causal_probing_label_dev: Forwarded to the clustering
            helpers.
        working_folder: Folder under which scores and checkpoints are written.
        device: Device that batch labels are moved to.
        batch_size: Training batch size.
        n_iters: Number of passes over the training pairs.
        lr_lm, lr_class, lr_p_c: Learning rates for the encoder, the linear
            head and the two intervention layers respectively.
        joint_model: Truthy -> joint loss via ``compute_joint_modeling_loss``;
            falsy -> BCE on the mean of the forward and reversed-pair scores.

    Returns:
        ``(final_conf, final_coref_results, final_coref_frame,
        final_conf_dev, final_coref_results_dev, final_coref_frame_dev)``.
        The first three hold one ``[n, k, value]`` entry per iteration and
        pruning size; the last three hold one ``[n, value]`` entry per
        iteration. Six values are returned because that is what the caller in
        this file unpacks.
    """
    bce_loss = torch.nn.BCELoss()
    scorer = parallel_model.module

    # The bidirectional/CDLM baseline only needs the first two groups.
    optimizer = torch.optim.AdamW([
        {'params': scorer.model.parameters(), 'lr': lr_lm},
        {'params': scorer.linear.parameters(), 'lr': lr_class},
        {'params': scorer.probing_intevention_layer.parameters(), 'lr': lr_p_c},
        {'params': scorer.causal_intevention_layer.parameters(), 'lr': lr_p_c},
    ])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    tokenizer = scorer.tokenizer

    # prepare data
    train_tokenized, train_tokenized_bi = tokenize_wtd_utterances(
        tokenizer, train_pairs, probing_map, document_map,
        scorer.start_id, scorer.end_id, max_sentence_len=512)
    dev_tokenized, _ = tokenize_wtd_utterances(
        tokenizer, dev_pairs, probing_map, document_map,
        scorer.start_id, scorer.end_id, max_sentence_len=512)

    train_labels = torch.FloatTensor(train_labels)

    # Per-(iteration, k) results of the pruned clustering.
    final_conf = []
    final_coref_results = []
    final_coref_frame = []
    # Per-iteration results of the unpruned clustering.
    final_conf_dev = []
    final_coref_results_dev = []
    final_coref_frame_dev = []

    split = "dev"
    # Bound before the loop so the final checkpoint works even if n_iters == 0.
    dev_folder = working_folder + f'./'
    os.makedirs(dev_folder, exist_ok=True)

    for n in range(n_iters):
        train_indices = list(range(len(train_pairs)))
        random.shuffle(train_indices)
        iteration_loss = 0.

        for i in tqdm(range(0, len(train_indices), batch_size), desc='Training'):
            optimizer.zero_grad()
            batch_indices = train_indices[i: i + batch_size]
            batch_labels = train_labels[batch_indices].reshape((-1, 1)).to(device)

            if joint_model:
                train_scores, train_probing, train_causal = forward(
                    parallel_model, train_tokenized, device, batch_indices, joint_model = joint_model)
                loss = compute_joint_modeling_loss(
                    (train_scores, train_probing, train_causal), batch_labels, joint_model = True)
            else:  # bidirectional baseline: average both pair orders
                scores = forward(parallel_model, train_tokenized, device, batch_indices)
                scores_bi = forward(parallel_model, train_tokenized_bi, device, batch_indices)
                loss = bce_loss((scores + scores_bi) / 2, batch_labels)

            loss.backward()
            optimizer.step()
            iteration_loss += loss.item()

        scheduler.step()

        with torch.no_grad():
            print(f'Iteration {n} Loss', iteration_loss / len(train_pairs, ))

            # NOTE(review): the original passed an undefined name `dev` here;
            # the tokenized dev pairs are assumed to be the intended input.
            dev_scores, _, dev_probing_scores, _, dev_causal_scores, _ = predict_causal_counterpart(
                parallel_model, dev_tokenized, device, batch_size = 256, joint_model=True)
            dev_predictions = torch.squeeze(dev_scores > 0.5)
            probing_scores = torch.sigmoid(dev_probing_scores)
            # NOTE(review): the original read a stale training-batch variable
            # here. The sigmoid mirrors the probing scores; confirm that the
            # pruning helper expects probabilities rather than raw scores.
            causal_scores = torch.sigmoid(dev_causal_scores)

            # NOTE(review): the original passed an undefined `total_scores`;
            # `dev_predictions` is used by analogy with the pruned call below.
            conf_dev, final_scores_dev, final_frame_dev = get_probing_causal_counterpart_clusters(
                split, dev_pairs, dev_predictions, dev_predictions, gold_map_dev, working_folder)

            final_coref_results_dev.append([n, final_scores_dev])
            final_coref_frame_dev.append([n, conf_dev])
            final_conf_dev.append([n, final_frame_dev])

            top_range = len(gold_map_dev)*5 # use 13 for WTD
            top_k_range_steps = [
                (50, 200, 5),
                (200, 300, 10),
                (300, 600, 10),
                (600, top_range, 50),
            ]
            top_k = generate_custom_sequence(top_k_range_steps)

            for k in top_k:
                conf, final_scores, frame_k = get_probing_causal_counterpart_clusters_pruned(
                    split, dev_pairs, dev_predictions, dev_predictions, causal_scores,
                    probing_scores, gold_map_dev, causal_probing_label_dev, working_folder, k)
                final_coref_results.append([n, k, final_scores])
                final_coref_frame.append([n, k, conf])
                final_conf.append([n, k, frame_k])

            # The original file names contained an empty `{}` placeholder,
            # which is a SyntaxError in an f-string; the slot is dropped.
            dumps = (
                (f'/{split}_scores_{n}.pkl', dev_scores),
                (f'/{split}_pairs.pkl', dev_pairs),
                (f'/{split}_causal_scores_{n}.pkl', causal_scores),
                (f'/{split}_probing_scores_{n}.pkl', probing_scores),
            )
            for file_name, payload in dumps:
                with open(dev_folder + file_name, 'wb') as handle:
                    pickle.dump(payload, handle)

        if n % 4 == 0:
            _save_scorer_checkpoint(parallel_model, dev_folder + f'/probing_scorer/chk_{n}')
            print(f'saved model at {n}')

    _save_scorer_checkpoint(parallel_model, dev_folder + '/probing_scorer/')
    return (final_conf, final_coref_results, final_coref_frame,
            final_conf_dev, final_coref_results_dev, final_coref_frame_dev)


def main(dataset):
    """Train the joint model on ``dataset`` and write the dev results to CSV.

    Args:
        dataset: Dataset name; results are written to
            ``./datasets/{dataset}//final_run//{dataset}_scores.csv``.
    """
    # use previous_window = 18 for DeliData and 9 for WTD
    results = train_deliberation_model(
        dataset, model_name='allenai/longformer-base-4096', joint_model = True, previous_window = 18)

    dataset_folder = f'./datasets/{dataset}/'
    output_folder = dataset_folder + f'/final_run/'
    # to_csv does not create missing directories.
    os.makedirs(output_folder, exist_ok=True)

    results.to_csv(output_folder + f"/{dataset}_scores.csv")

    


# Run training only when the file is executed as a script, not when imported.
if __name__ == '__main__':
    main(dataset = 'wtd_dataset')
  

