import os
import json
from tqdm import tqdm
from datetime import datetime
from collections import OrderedDict
from typing import List, Dict, Tuple, Iterable, Type, Union, Optional, Set

import random
import numpy as np
import pandas as pd
from statistics import mean

import torch
from torch import nn, Tensor as T
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import transformers
from transformers import AdamW
from sentence_transformers import util
transformers.logging.set_verbosity_error()

from utils.data_utils import BSARDataset
from models.trainable_dense_models import BiEncoder



class BiEncoderTrainer(object):
    """Train a bi-encoder on (query, document) pairs and evaluate it after every epoch.

    The trainer reads the queries and documents CSV files, splits the queries into a
    training and a validation set, builds the training dataloader, the validation
    evaluator, the optimizer and the learning-rate scheduler. Training uses the
    in-batch labelling visible in `fit`: query i of a batch is expected to obtain its
    highest score with document i of the same batch.

    The given `model` must expose the attributes used here: `collate_batch`, `score_fn`,
    `named_parameters()`, `save(path)`, and be callable with the keyword arguments
    `q_input_ids`, `q_attention_masks`, `d_input_ids`, `d_attention_masks`.
    """
    def __init__(self, 
                 model: nn.Module,
                 loss_fn: nn.Module,
                 queries_filepath: str,
                 documents_filepath: str,
                 batch_size: int, 
                 epochs: int,
                 learning_rate: float = 2e-5, 
                 weight_decay: float = 0.01,
                 scheduler_type: str = 'warmuplinear',
                 warmup_steps: int = 0,
                 log_steps: int = 10,
                 seed: int = 42, 
                 output_path: str = "output/training"):
        """Store the hyper-parameters and build every object needed by `fit`.

        :param model: the bi-encoder to train.
        :param loss_fn: loss module called as `loss_fn(scores, labels)`.
        :param queries_filepath: path of the CSV file holding the queries (needs a 'question' column).
        :param documents_filepath: path of the CSV file holding the documents.
        :param batch_size: number of (query, document) pairs per training batch.
        :param epochs: number of passes over the training set.
        :param learning_rate: learning rate given to the optimizer.
        :param weight_decay: weight decay applied to all parameters except biases and LayerNorm terms.
        :param scheduler_type: one of the names accepted by `get_scheduler` (case-insensitive).
        :param warmup_steps: number of warmup steps for the schedulers that support it.
        :param log_steps: log training scalars every `log_steps` batches (<= 0 disables logging).
        :param seed: seed used for the random number generators and for the train/val split.
        :param output_path: root directory; a timestamped sub-directory is appended to it.
        """
        # Init trainer modules and parameters.
        self.model = model
        self.loss_fn = loss_fn
        self.batch_size = batch_size
        self.epochs = epochs
        self.lr = learning_rate
        self.weight_decay = weight_decay
        self.scheduler_type = scheduler_type.lower()
        self.warmup_steps = warmup_steps
        self.log_steps = log_steps
        self.seed = seed
        self.output_path = os.path.join(output_path, datetime.today().strftime('%b%d-%H-%M-%S'))

        # Seed, device, tensorboard writer (created with its default arguments).
        self.set_seed()
        self.device = self.set_device()
        self.writer = SummaryWriter()

        # Datasets.
        documents_df = pd.read_csv(documents_filepath)
        train_queries_df, val_queries_df = self.split_train_val(queries_filepath, train_frac=0.8)

        # Training Dataloader.
        train_dataset = BSARDataset(train_queries_df, documents_df)
        self.train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.model.collate_batch)

        # Evaluator (runs on the validation queries).
        eval_dataset = BSARDataset(val_queries_df, documents_df)
        self.evaluator = BiEncoderEvaluator(queries=eval_dataset.queries, 
                                            documents=eval_dataset.documents, 
                                            relevant_pairs=eval_dataset.one_to_many_pairs, 
                                            score_fn=self.model.score_fn)

        # Optimizer and scheduler (the scheduler needs the total number of optimizer steps).
        self.optimizer = self.get_optimizer()
        self.scheduler = self.get_scheduler(t_total=len(self.train_dataloader)*self.epochs)
    
    def set_seed(self):
        """Seed the Python, NumPy and PyTorch random number generators and request deterministic cuDNN kernels.
        """
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)
        # Some operations on a GPU are implemented stochastic for efficiency, change that.
        # NB: the flag is spelled 'deterministic'; a misspelled attribute is silently ignored.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def set_device(self):
        """Return the CUDA device when one is available, the CPU device otherwise.
        """
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def split_train_val(self, queries_filepath: str, train_frac: float):
        """Split the queries into a training and a validation dataframe.

        Questions whose text appears more than once are all placed in the training set,
        so that no validation question has a copy in the training data. The remaining
        (unique) questions are sampled so that the training set holds about `train_frac`
        of all rows.

        :param queries_filepath: path of the queries CSV file (needs a 'question' column).
        :param train_frac: target fraction of rows placed in the training set.
        :returns: (train, val) dataframes, both shuffled and re-indexed from 0.
        """
        # Load queries dataframe.
        df = pd.read_csv(queries_filepath)
        
        # Extract the duplicated questions to put them in the training set only.
        duplicates = df[df.duplicated(['question'], keep=False)]
        uniques = df.drop(duplicates.index)

        # Compute the fraction of unique questions to place in the training set so that these
        # questions, complemented by the duplicates, sum up to the given 'train_frac' ratio.
        num_uniques = uniques.shape[0]
        if num_uniques == 0:
            # Every question is duplicated: everything goes to training, nothing to sample.
            train_frac_unique = 0.0
        else:
            train_frac_unique = (train_frac * df.shape[0] - duplicates.shape[0]) / num_uniques
            # Duplicates alone may already exceed (or 'train_frac' may overshoot) the quota;
            # clamp so that pandas is never asked for a negative or >100% sample.
            train_frac_unique = min(max(train_frac_unique, 0.0), 1.0)

        # Split the unique questions in train and val sets accordingly.
        train_unique = uniques.sample(frac=train_frac_unique, random_state=self.seed)
        val = uniques.drop(train_unique.index).sample(frac=1.0, random_state=self.seed)

        # Add the duplicated questions to the training set.
        train = pd.concat([train_unique, duplicates]).sample(frac=1.0, random_state=self.seed)

        # Reset indices and return.
        train.reset_index(drop=True, inplace=True)
        val.reset_index(drop=True, inplace=True)
        return train, val

    def get_optimizer(self):
        """Returns the AdamW optimizer that implements weight decay to all parameters other than bias and layer normalization terms.
        """
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        decayed_params, undecayed_params = [], []
        for name, param in self.model.named_parameters():
            if any(nd in name for nd in no_decay):
                undecayed_params.append(param)
            else:
                decayed_params.append(param)
        optimizer_grouped_params = [
            {'params': decayed_params, 'weight_decay': self.weight_decay},
            {'params': undecayed_params, 'weight_decay': 0.0}
        ]
        return transformers.AdamW(optimizer_grouped_params, lr=self.lr)

    def get_scheduler(self, t_total: int):
        """Returns the correct learning rate scheduler. 
        Available scheduler are: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts.

        :param t_total: total number of training steps (used by the decaying schedulers).
        :raises ValueError: if `self.scheduler_type` is not one of the names above.
        """
        if self.scheduler_type == 'constantlr':
            return transformers.get_constant_schedule(self.optimizer)
        elif self.scheduler_type == 'warmupconstant':
            return transformers.get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=self.warmup_steps)
        elif self.scheduler_type == 'warmuplinear':
            return transformers.get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=t_total)
        elif self.scheduler_type == 'warmupcosine':
            return transformers.get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=t_total)
        elif self.scheduler_type == 'warmupcosinewithhardrestarts':
            return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=t_total)
        else:
            raise ValueError("Unknown scheduler {}".format(self.scheduler_type))

    def fit(self):
        """Run the training loop for `self.epochs` epochs.

        After each epoch the validation evaluator is run, the model is saved under
        `<output_path>/<epoch>` and the epoch's average loss and accuracy are printed.
        Training loss, accuracy and learning rate are written to tensorboard every
        `self.log_steps` batches.
        """
        # Move model and loss to device.
        self.model.to(self.device)
        self.loss_fn.to(self.device)

        # Init variables.
        global_step = 0
        num_batches = len(self.train_dataloader)
        num_samples = len(self.train_dataloader.dataset)

        # Training loop.
        for epoch in tqdm(range(self.epochs),  desc="Epoch"):

            # Running totals for the epoch, and their values at the last logging point.
            train_loss, log_loss = 0.0, 0.0
            train_correct, log_correct = 0, 0
            train_seen, log_seen = 0, 0

            self.model.train()
            for step, batch in enumerate(self.train_dataloader):

                # Step 1: Move input data to device.
                q_input_ids = batch['q_input_ids'].to(self.device)
                q_attention_masks = batch['q_attention_masks'].to(self.device)
                d_input_ids = batch['d_input_ids'].to(self.device)
                d_attention_masks = batch['d_attention_masks'].to(self.device)

                # Step 2: Run the model on the input data.
                scores = self.model(q_input_ids=q_input_ids, q_attention_masks=q_attention_masks, 
                                    d_input_ids=d_input_ids, d_attention_masks=d_attention_masks)

                # Step 3: Calculate the loss.
                # labels[i] = i, as query q[i] should match with document d[i] of the same batch.
                labels = torch.arange(len(scores), dtype=torch.long, device=scores.device)
                loss = self.loss_fn(scores, labels)
                train_loss += loss.item()

                # Step 3': Calculate the number of correct predictions in the batch.
                _, max_idxs = torch.max(scores, dim=1)
                train_correct += (max_idxs == labels).sum().item()
                train_seen += len(labels)

                # Step 4: Perform backpropagation to calculate the gradients.
                self.optimizer.zero_grad()  #Always clear any previously calculated gradients before performing the backward pass
                loss.backward()

                # Step 5: Update the parameters and take a step using the computed gradients.
                nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  #Clip the gradients to 1.0 to prevent the "exploding gradients" problem
                self.optimizer.step()

                # Step 6: Update the learning rate.
                self.scheduler.step()

                # Update global step (number of optimizer steps performed so far).
                global_step += 1

                # Log loss and accuracy every 'log_steps' batches. Testing (step + 1) makes every
                # window hold exactly 'log_steps' batches, matching the divisor used for the loss;
                # accuracy is divided by the number of samples actually seen in the window.
                if self.log_steps > 0 and (step + 1) % self.log_steps == 0:
                    loss_scalar = (train_loss - log_loss) / self.log_steps
                    acc_scalar = (train_correct - log_correct) / max(train_seen - log_seen, 1)

                    self.writer.add_scalar('Train/loss', loss_scalar, global_step)
                    self.writer.add_scalar('Train/acc', acc_scalar, global_step)
                    self.writer.add_scalar("Train/lr", self.scheduler.get_last_lr()[0], global_step)

                    log_loss = train_loss
                    log_correct = train_correct
                    log_seen = train_seen

            # Evaluate model after each epoch.
            self.evaluator(model=self.model, device=self.device, batch_size=self.batch_size*3, epoch=epoch, writer=self.writer)
            
            # Save the model.
            self.model.save(os.path.join(self.output_path, f"{epoch}"))

            # Report average loss and number of correct predictions.
            print(f'Epoch {epoch}: Train loss {(train_loss/num_batches):>8f} - Accuracy {(train_correct/num_samples*100):>0.1f}%')


class BiEncoderEvaluator(object):
    """Compute retrieval metrics (recall@k, MAP@k, MRR@k) of a bi-encoder on a set of queries.

    Every query and every document is embedded with the model's encoders, the top
    documents are retrieved for each query, and the retrieved document ids are compared
    against the ground-truth ids given in `relevant_pairs`.
    """
    def __init__(self, 
                 queries: Dict[int, str], #qid -> query
                 documents: Dict[int, str],  #doc_id -> doc
                 relevant_pairs: Dict[int, List[int]], # qid -> List[doc_id]
                 score_fn: str,
                 recall_range: Optional[List[int]] = None,
                 map_range: Optional[List[int]] = None,
                 mrr_range: Optional[List[int]] = None,
                 ):
        """Store the evaluation data and the metric cut-offs.

        :param queries: mapping qid -> query text.
        :param documents: mapping doc_id -> document text.
        :param relevant_pairs: mapping qid -> list of relevant doc_ids.
        :param score_fn: 'dot' or 'cos', the similarity used for retrieval.
        :param recall_range: cut-offs k for recall@k (defaults to [10, 20, 50, 100]).
        :param map_range: cut-offs k for MAP@k (defaults to [100]).
        :param mrr_range: cut-offs k for MRR@k (defaults to [100]).
        """
        assert score_fn in ['dot', 'cos'], f"Unknown score function: {score_fn}"
        self.score_fn = util.dot_score if score_fn == 'dot' else util.cos_sim

        # Keep ids and texts as aligned lists: position i of the embeddings <-> ids[i].
        self.query_ids = list(queries.keys())
        self.queries = [queries[qid] for qid in self.query_ids]
        self.document_ids = list(documents.keys())
        self.documents = [documents[doc_id] for doc_id in self.document_ids]
        self.relevant_pairs = relevant_pairs

        # Defaults are built per instance (never share a mutable default between instances).
        self.recall_range = [10, 20, 50, 100] if recall_range is None else list(recall_range)
        self.map_range = [100] if map_range is None else list(map_range)
        self.mrr_range = [100] if mrr_range is None else list(mrr_range)
    
    def __call__(self, 
                 model: nn.Module, 
                 device: Union[str, torch.device], 
                 batch_size: int, 
                 writer: Optional["SummaryWriter"] = None, 
                 epoch: Optional[int] = None
                 ):
        """Evaluate `model` and return a dict metric name -> value.

        :param model: bi-encoder exposing `q_encoder.encode` and `d_encoder.encode`.
        :param device: device forwarded to the encoders.
        :param batch_size: encoding batch size forwarded to the encoders.
        :param writer: optional tensorboard writer; when given, each metric is also logged.
        :param epoch: step value used when logging to `writer`.
        :returns: dict with keys 'recall@k', 'map@k' and 'mrr@k' for the configured cut-offs.
        """
        # Encode queries and documents.
        query_embeddings = model.q_encoder.encode(texts=self.queries, device=device, batch_size=batch_size)
        document_embeddings = model.d_encoder.encode(texts=self.documents, device=device, batch_size=batch_size)

        # Retrieve top candidates.
        all_results = self.retrieve_documents(query_embeddings, document_embeddings)
        
        # Get ground truths.
        all_ground_truths = [self.relevant_pairs[qid] for qid in self.query_ids]

        # Compute metrics.
        scores = dict()
        for k in self.recall_range:
            recall_scalar = self.compute_mean_score(self.recall, all_ground_truths, all_results, k)
            if writer is not None:
                writer.add_scalar(f'Val/recall/recall_at_{k}', recall_scalar, epoch)
            scores[f'recall@{k}'] = recall_scalar

        for k in self.map_range:
            map_scalar = self.compute_mean_score(self.average_precision, all_ground_truths, all_results, k)
            if writer is not None:
                writer.add_scalar(f'Val/map/map_at_{k}', map_scalar, epoch)
            scores[f'map@{k}'] = map_scalar

        for k in self.mrr_range:
            mrr_scalar = self.compute_mean_score(self.reciprocal_rank, all_ground_truths, all_results, k)
            if writer is not None:
                writer.add_scalar(f'Val/mrr/mrr_at_{k}', mrr_scalar, epoch)
            scores[f'mrr@{k}'] = mrr_scalar
        
        return scores

    def retrieve_documents(self, query_embeddings: T, document_embeddings: T):
        """Return, for every query, the ids of its top-scoring documents (best first).

        The number of documents retrieved is the largest cut-off among all metric ranges.
        """
        # Pool the cut-offs first so that a single empty range does not break max().
        cutoffs = [*self.recall_range, *self.map_range, *self.mrr_range]
        max_k = max(cutoffs)
        all_hits = util.semantic_search(query_embeddings, document_embeddings, top_k=max_k, score_function=self.score_fn) #Returns a List[List[Dict]] with keys 'corpus_id' and 'score'
        return self._to_document_ids(all_hits)

    def _to_document_ids(self, all_hits):
        """Translate search hits into document ids.

        'corpus_id' is the position of the document in the embedded corpus, not its id,
        so it is mapped through `self.document_ids` before being compared to ground truths.
        """
        return [[self.document_ids[hit['corpus_id']] for hit in hits] for hits in all_hits]

    def compute_mean_score(self, score_func, all_ground_truths: List[List[int]], all_results: List[List[int]],  k: int = None):
        """Average `score_func(truths, results, k)` over all queries.
        """
        return mean([score_func(truths, res, k) for truths, res in zip(all_ground_truths, all_results)])

    def precision(self, ground_truths: List[int], results: List[int], k: int = None):
        """Fraction of the top-k results that are relevant (0.0 when there is no result).
        """
        top = results if k is None else results[:k]
        if not top:
            return 0.0
        relevant = set(ground_truths)
        return sum(1 for d in top if d in relevant)/len(top)

    def recall(self,ground_truths: List[int], results: List[int], k: int = None):
        """Number of relevant results in the top-k divided by the number of ground truths (0.0 when there is none).
        """
        if not ground_truths:
            return 0.0
        top = results if k is None else results[:k]
        relevant = set(ground_truths)
        return sum(1 for d in top if d in relevant)/len(ground_truths)

    def fscore(self, ground_truths: List[int], results: List[int], k: int = None):
        """Harmonic mean of precision@k and recall@k (0.0 when both are zero).
        """
        p = self.precision(ground_truths, results, k)
        r = self.recall(ground_truths, results, k)
        return (2*p*r)/(p+r) if (p != 0.0 or r != 0.0) else 0.0

    def reciprocal_rank(self, ground_truths: List[int], results: List[int], k: int = None):
        """Inverse rank of the first relevant result within the top-k (0.0 when there is none).
        """
        top = results if k is None else results[:k]
        relevant = set(ground_truths)
        for rank, d in enumerate(top, start=1):
            if d in relevant:
                # The first hit has the highest reciprocal rank, no need to look further.
                return 1/rank
        return 0.0

    def average_precision(self, ground_truths: List[int], results: List[int], k: int = None):
        """Sum of precision@i over the ranks i <= k holding a relevant result, divided by the number of ground truths.
        """
        if not ground_truths:
            return 0.0
        top = results if k is None else results[:k]
        relevant = set(ground_truths)
        num_hits, precision_sum = 0, 0.0
        for rank, d in enumerate(top, start=1):
            if d in relevant:
                # precision@rank = relevant results seen so far / rank.
                num_hits += 1
                precision_sum += num_hits/rank
        return precision_sum/len(ground_truths)
        



if __name__ == '__main__':
    # Script entry point: train a bi-encoder, then score it on the test queries and
    # write the resulting metrics to 'test_scores.json' in the trainer's output directory.

    #---------------------------------------------------------------#
    #                       TRAINING
    #---------------------------------------------------------------#
    # 1. Initialize a new BiEncoder model to train.
    encoder_config = dict(is_siamese=True,
                          q_model_name_or_path='camembert-base',
                          truncation=True,
                          max_input_len=1000,
                          chunk_size=200,
                          window_size=20,
                          pooling_mode='cls',
                          score_fn='dot')
    model = BiEncoder(**encoder_config)

    ## 1'. OR load an already-trained BiEncoder.
    #model = BiEncoder.load("output/Oct27-17-56-22_siamese-flaubert-small-cased-512-512/99")

    # 2. Initialize the BiEncoder Trainer.
    #    NB (original author's note): there are ~4500 training samples, so the number
    #    of steps per epoch is roughly 4500 / batch_size.
    trainer = BiEncoderTrainer(model=model,
                               loss_fn=nn.CrossEntropyLoss(),
                               queries_filepath="../../data/final/questions_fr_train.csv",
                               documents_filepath="../../data/final/articles_fr.csv",
                               batch_size=22,
                               epochs=100,
                               warmup_steps=500,
                               log_steps=10)

    # 3. Launch training.
    trainer.fit()

    #---------------------------------------------------------------#
    #                       TESTING
    #---------------------------------------------------------------#
    # 1. Load the test set.
    test_queries_df = pd.read_csv("../../data/final/questions_fr_test.csv")
    documents_df = pd.read_csv("../../data/final/articles_fr.csv")
    test_dataset = BSARDataset(test_queries_df, documents_df)

    # 2. Initialize the Evaluator.
    evaluator = BiEncoderEvaluator(queries=test_dataset.queries,
                                   documents=test_dataset.documents,
                                   relevant_pairs=test_dataset.one_to_many_pairs,
                                   score_fn=model.score_fn)

    # 3. Run trained model and compute scores (on GPU when one is available).
    if torch.cuda.is_available():
        test_device = torch.device("cuda")
    else:
        test_device = torch.device("cpu")
    scores = evaluator(model=model, device=test_device, batch_size=512)

    # 4. Save results next to the trained checkpoints.
    os.makedirs(trainer.output_path, exist_ok=True)
    scores_filepath = os.path.join(trainer.output_path, 'test_scores.json')
    with open(scores_filepath, 'w') as scores_file:
        json.dump(scores, scores_file, indent=2)






# Models:
# ---------------
# - 'flaubert/flaubert_small_cased'            : 54M     -> is_siamese: batch_size =    | is_dual_tower: batch_size = 
# - 'Geotrend/distilbert-base-fr-cased'        : 90-108M -> is_siamese: batch_size =    | is_dual_tower: batch_size = 
# - 'camembert-base'                           : 110M    -> is_siamese: batch_size = 22 | is_dual_tower: batch_size = 20
# - 'camembert/camembert-base-ccnet'           : 110M    -> is_siamese: batch_size = 22 | is_dual_tower: batch_size = 20
# - 'camembert/camembert-base-wikipedia-4gb'   : 110M    -> is_siamese: batch_size = 22 | is_dual_tower: batch_size = 20
# - 'flaubert/flaubert_base_cased'             : 138M    -> is_siamese: batch_size = 22 | is_dual_tower: batch_size = 20
