import math 
import sys
import os 
from typing import Dict, Iterable, List
import numpy as np
import json
import pickle

from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances
from sklearn.metrics import average_precision_score

from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader

import sentence_transformers
from sentence_transformers import LoggingHandler, SentenceTransformer, evaluation
from sentence_transformers.losses import SiameseDistanceMetric
from sentence_transformers.readers import InputExample

import wandb
import logging
from transformers import logging as lg

# currentdir = os.path.dirname(os.path.realpath(__file__))
# parentdir = os.path.dirname(currentdir)
# grandparentdir = os.path.dirname(parentdir)
# sys.path.append(parentdir)
# sys.path.append(grandparentdir)

# from nlp_utils.modified_sbert import losses, evaluation
from data_fns import prep_wikipedia_data


# Restrict the transformers library's own logger to ERROR-level messages.
lg.set_verbosity_error()

# Root logger: INFO and above, timestamped to the second, emitted through
# sentence-transformers' LoggingHandler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)
logger = logging.getLogger(__name__)


class OnlineContrastiveLoss_wandb(sentence_transformers.losses.OnlineContrastiveLoss):
    """Online contrastive loss that additionally reports every batch loss to Weights & Biases.

    The loss value is logged under the key ``"Loss"`` whenever a wandb run is
    active; without an active run the loss is computed and returned as usual.
    """

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor, size_average=False):
        """Compute the contrastive loss over the hard pairs of one batch.

        Args:
            sentence_features: Tokenised inputs; the first two entries are embedded
                with ``self.model`` and compared pairwise.
            labels: Per-pair labels; ``1`` marks a positive pair, ``0`` a negative pair.
            size_average: Not used by this implementation; kept so the signature
                stays call-compatible.

        Returns:
            Scalar tensor: sum of squared distances of the selected positives plus
            the sum of squared margin violations of the selected negatives.
        """
        embeddings = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]

        # Distance between the two embeddings of each pair in the batch.
        distances = self.distance_metric(embeddings[0], embeddings[1])
        negs = distances[labels == 0]
        poss = distances[labels == 1]

        # Select hard pairs: negatives closer than the farthest positive, and
        # positives farther than the closest negative. When the opposite class has
        # at most one member, fall back to the mean distance of the own class.
        negative_pairs = negs[negs < (poss.max() if len(poss) > 1 else negs.mean())]
        positive_pairs = poss[poss > (negs.min() if len(negs) > 1 else poss.mean())]

        positive_loss = positive_pairs.pow(2).sum()
        negative_loss = F.relu(self.margin - negative_pairs).pow(2).sum()
        loss = positive_loss + negative_loss

        # wandb.log raises if no run has been initialised, so only log when one is
        # active. Log a plain float rather than the tensor carrying the autograd graph.
        if wandb.run is not None:
            wandb.log({"Loss": loss.item()})

        return loss


class BinaryClassificationEvaluator_asymmetric(evaluation.BinaryClassificationEvaluator):
    """
    Update of BinaryClassificationEvaluator to cope with asymmetric input.

    Each entry of ``self.sentences1`` is read as a dict with a ``'CTX'`` text and
    each entry of ``self.sentences2`` as a dict with an ``'FP'`` text, so the two
    sides can be routed to different encoders of an asymmetric model.
    Per-metric results are also logged to Weights & Biases when a run is active.
    """

    def _encode_unique(self, model, examples, key):
        """Embed ``example[key]`` for every example, encoding each distinct text only once.

        Returns a 2-D array whose i-th row is the embedding of ``examples[i][key]``.
        """
        texts = [example[key] for example in examples]
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # batches handed to the encoder are reproducible between runs.
        unique_texts = list(dict.fromkeys(texts))
        logger.info("Encoding %d unique '%s' texts for %d pairs", len(unique_texts), key, len(texts))

        unique_embeddings = model.encode(
            [{key: text} for text in unique_texts],
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            convert_to_numpy=True,
        )
        embedding_by_text = dict(zip(unique_texts, unique_embeddings))
        return np.asarray([embedding_by_text[text] for text in texts])

    def compute_metrices(self, model):
        """Score every (CTX, FP) pair and find the best thresholds per similarity measure.

        Args:
            model: Encoder exposing ``encode``; called with lists of ``{'FP': text}``
                and ``{'CTX': text}`` dicts.

        Returns:
            Dict keyed by ``'cossim'``, ``'manhattan'``, ``'euclidean'`` and ``'dot'``;
            each value holds ``accuracy``, ``accuracy_threshold``, ``f1``,
            ``f1_threshold``, ``precision``, ``recall`` and ``ap``.
        """
        # First-paragraph side is encoded before the context side.
        embeddings2 = self._encode_unique(model, self.sentences2, 'FP')
        embeddings1 = self._encode_unique(model, self.sentences1, 'CTX')

        cosine_scores = 1 - paired_cosine_distances(embeddings1, embeddings2)
        manhattan_distances = paired_manhattan_distances(embeddings1, embeddings2)
        euclidean_distances = paired_euclidean_distances(embeddings1, embeddings2)
        # Kept as an ndarray (not a list) so arithmetic on the scores below is
        # element-wise, like for the other three measures.
        dot_scores = np.asarray([np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)])

        labels = np.asarray(self.labels)

        # (short name, display name, scores, whether a higher score means "more similar")
        # Note: newer versions of sbert have updated the spelling on manhatten to manhattan
        measures = [
            ('cossim', 'Cosine-Similarity', cosine_scores, True),
            ('manhattan', 'Manhattan-Distance', manhattan_distances, False),
            ('euclidean', 'Euclidean-Distance', euclidean_distances, False),
            ('dot', 'Dot-Product', dot_scores, True),
        ]

        output_scores = {}
        for short_name, name, scores, reverse in measures:
            acc, acc_threshold = self.find_best_acc_and_threshold(scores, labels, reverse)
            f1, precision, recall, f1_threshold = self.find_best_f1_and_threshold(scores, labels, reverse)
            # Average precision expects "higher = positive", so distances are negated.
            ap = average_precision_score(labels, scores if reverse else -scores)

            logger.info("Accuracy with {}:           {:.2f}\t(Threshold: {:.4f})".format(name, acc * 100, acc_threshold))
            logger.info("F1 with {}:                 {:.2f}\t(Threshold: {:.4f})".format(name, f1 * 100, f1_threshold))
            logger.info("Precision with {}:          {:.2f}".format(name, precision * 100))
            logger.info("Recall with {}:             {:.2f}".format(name, recall * 100))
            logger.info("Average Precision with {}:  {:.2f}\n".format(name, ap * 100))

            output_scores[short_name] = {
                'accuracy': acc,
                'accuracy_threshold': acc_threshold,
                'f1': f1,
                'f1_threshold': f1_threshold,
                'precision': precision,
                'recall': recall,
                'ap': ap
            }

            # wandb.log raises when no run has been initialised; evaluation must
            # still work without experiment tracking.
            if wandb.run is not None:
                wandb.log({
                    f"Classification Accuracy {name}": acc,
                    f"Classification Accuracy threshold {name}": acc_threshold,
                    f"Classification F1 {name}": f1,
                    f"Classification F1 threshold {name}": f1_threshold,
                    f"Classification Precision {name}": precision,
                    f"Classification Recall {name}": recall,
                    f"Classification Average precision {name}": ap
                })

        return output_scores


def train_asymmetric(
        train_data: list = None,
        dev_data: list = None,
        pretrained_model=None,
        context_base_model=None,
        fp_base_model=None,
        train_batch_size=64,
        num_epochs=10,
        warm_up_perc=0.1,
        loss_params=None,
        model_save_path="output",
        wandb_names=None
):
    """Train a bi-encoder with online contrastive loss and evaluate it on ``dev_data``.

    Either ``pretrained_model`` is loaded as-is, or an asymmetric model is built
    from ``context_base_model`` (frozen, handles ``'CTX'`` inputs) and
    ``fp_base_model`` (trainable, handles ``'FP'`` inputs). ``pretrained_model``
    takes precedence when both forms are supplied.

    Args:
        train_data: Training examples, wrapped in a shuffling ``DataLoader``.
        dev_data: Examples handed to ``BinaryClassificationEvaluator_asymmetric.from_input_examples``.
        pretrained_model: Name or path passed to ``SentenceTransformer``.
        context_base_model: Name or path of the context-side encoder.
        fp_base_model: Name or path of the first-paragraph-side encoder.
        train_batch_size: Batch size of the training ``DataLoader``.
        num_epochs: Number of training epochs.
        warm_up_perc: Fraction of all training steps used as warm-up steps.
        loss_params: Dict with ``'distance_metric'`` and ``'margin'`` for the loss.
        model_save_path: Directory for the best model and for checkpoints.
        wandb_names: Optional dict with ``'project'``, ``'id'`` (used as the wandb
            entity) and optionally ``'run'`` (run name). When given, a wandb run
            is initialised here; otherwise no run is started by this function.

    Raises:
        ValueError: If neither ``pretrained_model`` nor both base models are given.
        TypeError: If ``loss_params`` is ``None``.
    """
    # Fail fast on bad arguments, before touching the filesystem or loading models.
    if not pretrained_model and not (context_base_model and fp_base_model):
        raise ValueError(
            "Provide either `pretrained_model` or both `context_base_model` and `fp_base_model`."
        )
    if loss_params is None:
        raise TypeError("`loss_params` is required: a dict with 'distance_metric' and 'margin'.")

    os.makedirs(model_save_path, exist_ok=True)

    # Logging
    if wandb_names:
        init_kwargs = {
            'project': wandb_names['project'],
            'entity': wandb_names['id'],
            'reinit': True,
            # Passed through init so the values are recorded on the run itself;
            # assigning to `wandb.config` would only rebind the module attribute.
            'config': {
                "epochs": num_epochs,
                "batch_size": train_batch_size,
                "warm_up": warm_up_perc,
            },
        }
        if 'run' in wandb_names:
            init_kwargs['name'] = wandb_names['run']
        wandb.init(**init_kwargs)

    if pretrained_model:
        model = SentenceTransformer(pretrained_model)

    else:
        # Pretrained models for each side
        context_model = SentenceTransformer(context_base_model)
        first_para_model = SentenceTransformer(fp_base_model)

        # Freeze context model
        for param in context_model.parameters():
            param.requires_grad = False

        # Set up asymmetric biencoder - chilled version
        asym_model = sentence_transformers.models.Asym({'CTX': [context_model], 'FP': [first_para_model]})
        model = SentenceTransformer(modules=[asym_model])

        # Check that the context side is still frozen inside the combined model.
        # One summary line instead of one printed line per parameter tensor.
        all_params = list(model.parameters())
        n_trainable = sum(param.requires_grad for param in all_params)
        logger.info(
            "Asymmetric model: %d of %d parameter tensors are trainable",
            n_trainable, len(all_params)
        )

    # # Set up asymmetric biencoder - fully specified version
    # word_embedding_model = sentence_transformers.models.Transformer("sentence-transformers/all-MiniLM-L12-v2")

    # pooling_model = sentence_transformers.models.Pooling(word_embedding_model.get_word_embedding_dimension())

    # context_model = sentence_transformers.models.Dense(in_features=word_embedding_model.get_word_embedding_dimension(), out_features=256, bias=False, activation_function=nn.Identity())
    # first_para_model = sentence_transformers.models.Dense(in_features=word_embedding_model.get_word_embedding_dimension(), out_features=256, bias=False, activation_function=nn.Identity())

    # asym_model = sentence_transformers.models.Asym({'CTX': [context_model], 'FP': [first_para_model]})
    # model = SentenceTransformer(modules=[word_embedding_model, pooling_model, asym_model])

    # Loss
    logger.info("Contrastive loss margin: %s", loss_params['margin'])
    train_loss = OnlineContrastiveLoss_wandb(
        model=model,
        distance_metric=loss_params['distance_metric'],
        margin=loss_params['margin']
    )

    # Load data
    train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
    steps_per_epoch = len(train_dataloader)

    # Evaluators
    evaluators = [
        BinaryClassificationEvaluator_asymmetric.from_input_examples(dev_data),
        # evaluation.ClusterEvaluator_wandb.from_input_examples(dev_data, cluster_type="agglomerative")
    ]

    # The score of the last evaluator in the list decides which model is "best".
    seq_evaluator = sentence_transformers.evaluation.SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[-1])

    logger.info("Evaluate model without training")
    seq_evaluator(model, epoch=0, steps=0, output_path=model_save_path)

    # Evaluate and checkpoint ten times per epoch.
    steps_per_tenth_epoch = math.ceil(steps_per_epoch / 10)

    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        evaluator=seq_evaluator,
        epochs=num_epochs,
        warmup_steps=math.ceil(steps_per_epoch * num_epochs * warm_up_perc),
        output_path=model_save_path,
        evaluation_steps=steps_per_tenth_epoch,
        checkpoint_save_steps=steps_per_tenth_epoch,
        checkpoint_path=model_save_path,
        save_best_model=True,
        checkpoint_save_total_limit=10
    )


def main():
    """Run one training job whose hyper-parameters come from ``wandb.config``.

    Reads ``fp_base_model`` (``'newspapers'`` or ``'wikipedia'``), ``loss_margin``,
    ``batch_size``, ``epochs`` and ``warm_up_perc`` from the active wandb config,
    prepares the Wikipedia data and trains an asymmetric model.

    Raises:
        ValueError: If ``wandb.config.fp_base_model`` is not a known choice.
    """
    wandb.init()
    config = wandb.config

    # Newspaper-trained encoder: used for featurisation, as the (frozen) context
    # model, and optionally as the starting point of the first-paragraph model.
    newspaper_model = '/mnt/data01/entity/trained_models/dates_newspaper_0.6249793562878543_128_5_0.381834161829256/'

    fp_base_models = {
        'newspapers': newspaper_model,  # Newspaper version, same as context
        'wikipedia': '/mnt/data01/entity/trained_models/full_sentence-transformers/all-mpnet-base-v2_prepend_128_128_2_0.1196286633156809',  # Wikipedia version
    }

    # Resolve the model choice before the expensive data preparation so that an
    # invalid configuration fails immediately and with a clear message.
    fp_choice = config.fp_base_model
    if fp_choice not in fp_base_models:
        raise ValueError(
            f"Unknown fp_base_model {fp_choice!r}; expected one of {sorted(fp_base_models)}"
        )
    fp_base_model = fp_base_models[fp_choice]

    # Extract and featurise data.
    train_data, dev_data, test_data = prep_wikipedia_data(
        dataset_path='/mnt/data01/wiki_data/train_splits/',
        model=SentenceTransformer(newspaper_model),
        special_tokens={'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"},
        featurisation='prepend',
        disamb_or_coref='disamb',
        batch_type='contrastive_batchhard',
        samples_per_label=4,
        batch_size=16,
        small=True
    )

    # Train
    name = f'wikipedia_{config.fp_base_model}_{config.loss_margin}_{config.batch_size}_{config.epochs}_{config.warm_up_perc}'

    train_asymmetric(
        train_data,
        dev_data,
        context_base_model=newspaper_model,  # OLD VERSION
        fp_base_model=fp_base_model,
        train_batch_size=config.batch_size,
        num_epochs=config.epochs,
        warm_up_perc=config.warm_up_perc,
        loss_params={'distance_metric': SiameseDistanceMetric.COSINE_DISTANCE, 'margin': config.loss_margin},
        model_save_path=f'/mnt/data01/entity/trained_models/asymmetric/{name}'
    )


if __name__ == '__main__':

    # Hyper-parameter sweep, currently disabled. When enabled, wandb.agent calls
    # main() once per sampled configuration.
    #
    # sweep_configuration = {
    #     'method': 'bayes',
    #     'name': 'sweep',
    #     'metric': {'goal': 'maximize', 'name': "Classification F1 Cosine-Similarity"},
    #     'early_terminate': {'type': 'hyperband', 'min_iter': 10},
    #     'parameters':
    #         {
    #             'fp_base_model': {'values': ['newspapers', 'wikipedia']},
    #             'loss_margin': {'min': 0.1, 'max': 0.9},                            # Needs specifying for contrastive and triplet loss
    #             'batch_size': {'values': [16, 32, 48, 64, 128]},                         # Too big for some models, but should just fail and move on
    #             'epochs': {'min': 2, 'max': 5},
    #             'warm_up_perc': {'min': 0.0, 'max': 1.0}
    #         }
    # }

    # sweep_id = wandb.sweep(sweep=sweep_configuration, project='ent_dis_asym',  entity="emilys")

    # wandb.agent(sweep_id, project='ent_dis_asym',  entity="emilys", function=main, count=200)

    # Single training run with fixed hyper-parameters.
    # The same model is used to featurise the data and as the context encoder.
    context_model_path = '/mnt/data01/entity/best_models/full_100_newspapers_dates_newspaper_0.7089191069104632_128_4_0.11864373727448708/'
    mention_tokens = {'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}

    # Extract and featurise data.
    train_data, dev_data, test_data = prep_wikipedia_data(
        dataset_path='/mnt/data01/wiki_data/train_splits/',
        model=SentenceTransformer(context_model_path),
        special_tokens=mention_tokens,
        featurisation='prepend',
        disamb_or_coref='disamb',
        batch_type='contrastive_batchhard',
        samples_per_label=4,
        batch_size=16,
        small=True,
    )

    contrastive_loss_params = {
        'distance_metric': SiameseDistanceMetric.COSINE_DISTANCE,
        'margin': 0.4978,
    }
    tracking = {'project': "ent_dis_asym", "id": "econabhishek", "run": 'disambiguation_full_100'}

    train_asymmetric(
        train_data,
        dev_data,
        context_base_model=context_model_path,
        fp_base_model='/mnt/data01/entity/best_models/all-mpnet-base-v2_prepend_128_128_2_0.1196286633156809/',
        train_batch_size=128,
        num_epochs=5,
        warm_up_perc=0.4555,
        loss_params=contrastive_loss_params,
        model_save_path='/mnt/data01/entity/best_models/asymmetric/disambiguation_full_100_test',
        wandb_names=tracking,
    )
