import re
from pprint import pprint
import os
import gc
from typing import Dict, List, Union, Tuple, Type
import argparse
import logging
import ast
from concurrent.futures import ThreadPoolExecutor, as_completed

from flair.embeddings import (WordEmbeddings,
                              ELMoEmbeddings,
                              BertEmbeddings,
                              DocumentPoolEmbeddings,
                              Embeddings)
from flair.data import Sentence

import torch
from torch.nn.modules.distance import CosineSimilarity
from torch.nn import Module

import numpy as np
import pandas as pd

from tqdm import tqdm

# Module-level logger; records propagate to the root logger, which main()
# configures with logging.basicConfig.
logger = logging.getLogger(__name__)

class SimilarityComputer():
    """
    Compute the similarity between embeddings of phrases stored in dataset columns.

    For every row and every group of columns, the cells are embedded with the
    configured model and two scores are produced: one over the masked tokens only
    (column `<cols joined by '-'>-cs-mwe`) and one over all tokens of the sentence
    (column `<cols joined by '-'>-cs-sent`).

    Parameters:

    - `input_dataset_path`: path of the CSV dataset; its first column is the index.
    - `sep`: separator char used in `input_dataset_path`.
    - `columns`: a tuple of column names, or a list of such tuples. In `pair` mode
      the first two columns of each tuple are compared; in `difference` mode the
      first three are used.
    - `config`: dictionary describing the embedding model. Keys: `model`
      (`glove`, `bert` or `elmo`), `pretrained`, and for BERT optionally `layers`
      and `pooling`.
    - `fn_comp_sim`: torch module class (e.g. `CosineSimilarity`); it is
      instantiated without arguments and called on two `(n, embedding)` tensors.
    - `mask_column_suffix`: the tag column suffix. If passed, the column
      `<column><suffix>` must hold a positional boolean mask selecting the tokens
      used for the `-cs-mwe` score. If not passed, every token is selected.
    - `batch_size`: number of rows embedded at once; `0` processes row by row.
    - `mode`: `pair` or `difference` (see `calculate_by_mode`).
    - `verbose`: show a progress bar in `compare_similarity`.
    - `consider_not_found_tokens`: when False (GloVe only), tokens without a vector
      are dropped after trying to split hyphenated tokens in two.
    """

    RESULT_COLUMN_MWE_SUFFIX = '-cs-mwe'
    RESULT_COLUMN_SENT_SUFFIX = '-cs-sent'

    def __init__(self, input_dataset_path: str, sep: str,
                 columns: Union[Tuple[str, str],
                                List[Tuple[str, str]],
                                Tuple[str, str, str],
                                List[Tuple[str, str, str]]],
                 config: dict, fn_comp_sim: Type[Module],
                 mask_column_suffix: str = None, batch_size: int = 0,
                 mode: str = "pair",
                 verbose=False,
                 consider_not_found_tokens=True):
        # Validate the cheap argument before reading the CSV and loading the model.
        if mode != "pair" and mode != "difference":
            raise ValueError(f"mode should not be {mode}, it has to be either pair or difference")
        self.mode = mode

        self.dataset = pd.read_csv(input_dataset_path, sep=sep, index_col=0)
        self.columns = columns
        self.model = self.get_model_from_config(config)

        self.use_mask = bool(mask_column_suffix)
        # The fallback suffix only keeps get_mask_column() callable; mask columns
        # are not read anywhere in this class when use_mask is False.
        self.mask_column_suffix = mask_column_suffix if self.use_mask else "_mask"

        if self.use_mask:
            # Mask cells arrive from the CSV as text and are parsed with
            # ast.literal_eval (presumably list literals such as "[False, True]").
            # Mask columns are exactly those *ending* with the suffix (see
            # get_mask_column), so other columns merely containing it are untouched.
            mask_columns = [c for c in self.dataset.columns
                            if isinstance(c, str) and c.endswith(self.mask_column_suffix)]
            for mask_column in mask_columns:
                self.dataset[mask_column] = self.dataset[mask_column].map(ast.literal_eval)

        self.verbose = verbose
        self.fn_comp_sim = fn_comp_sim
        self.batch_size = batch_size
        self.consider_not_found_tokens = consider_not_found_tokens  # only for GloVe
        if not self.consider_not_found_tokens and not isinstance(self.model, WordEmbeddings):
            raise ValueError("code is not yet prepared to not consider unknown tokens from other models besides GloVe")

        self.guarantee_dataset_consistency()

        super().__init__()

    def _column_groups(self) -> List[Tuple[str, ...]]:
        """Return `self.columns` normalised to a list of column-name tuples.

        A single tuple of strings is one group; anything else is treated as an
        iterable of groups.
        """
        if isinstance(self.columns, tuple) and all(isinstance(c, str) for c in self.columns):
            return [self.columns]
        return list(self.columns)

    def get_mask_column(self, column: str) -> str:
        """Return the name of the mask column associated with `column`."""
        return f'{column}{self.mask_column_suffix}'

    def check_columns_in_dataset(self, column: str):
        """Raise `ValueError` if `column` (or its mask column, when masking) is missing."""
        if column not in self.dataset.columns:
            raise ValueError(f"Column {column} is not present in dataset")

        if self.use_mask and self.get_mask_column(column) not in self.dataset.columns:
            raise ValueError(f"Column {column}{self.mask_column_suffix} is not present in dataset")

    def guarantee_dataset_consistency(self):
        """Check that every configured column (and mask column) exists in the dataset."""
        for group in self._column_groups():
            for column in group:
                self.check_columns_in_dataset(column)

    def get_sentence_and_mask(self, row: pd.Series, column: str) -> Tuple[Sentence, List[bool]]:
        """
        Build the flair `Sentence` for `row[column]` together with its token mask.

        The mask is read from `<column><suffix>` when a mask suffix was configured;
        otherwise every whitespace token is selected. When unknown tokens must be
        ignored (GloVe only), an unknown token with a single inner '-' is split in
        two (its mask entry is duplicated), and tokens still without a vector are
        removed from both the sentence and the mask.
        """
        str_sent = row[column]
        if self.use_mask:
            # Copy so the insertions below never mutate the cell stored in the dataset.
            mask_sent = list(row[self.get_mask_column(column)])
        else:
            mask_sent = [True] * len(str_sent.split())

        zeroed_tokens = []
        if not self.consider_not_found_tokens and isinstance(self.model, WordEmbeddings):
            str_sent_splited = str_sent.split()
            original_size = len(str_sent_splited)
            i = 0
            while i < len(str_sent_splited):
                token = str_sent_splited[i]
                # The code treats a vector summing to zero as a not-found token.
                if self.model.get_cached_vec(token).sum() == 0.0:
                    parts = token.split('-')
                    # only treat cases where it's an NC separated by a single '-'
                    if len(parts) == 2 and parts[0] != "" and parts[1] != "":
                        logger.debug(f'after splitting {parts}')
                        # Replace the token by its two halves and repeat its tag.
                        str_sent_splited[i:i + 1] = parts
                        mask_sent.insert(i + 1, mask_sent[i])
                        # Do not advance: the first half is re-checked on the next
                        # pass (it may be unknown too), then the second half.
                        continue
                    zeroed_tokens.append(i)
                i += 1

            if original_size < len(str_sent_splited):
                str_sent_expanded = ' '.join(str_sent_splited)
                logger.debug(f'Find a token split case hence sentence went from "{str_sent}" of size {len(str_sent.split())} to '
                             f'"{str_sent_expanded}" of size {len(str_sent_splited)}')
                str_sent = str_sent_expanded

            if len(zeroed_tokens) > 0:
                logger.debug(f'From sentence {str_sent} the indices {zeroed_tokens} are zero')
                str_sent_nozeroes = ' '.join(np.delete(str_sent.split(), zeroed_tokens).tolist())
                logger.debug(f'Now the sentence is {str_sent_nozeroes}')
                str_sent = str_sent_nozeroes

        sent = Sentence(str_sent)

        if len(zeroed_tokens) > 0:
            logger.debug(f'From mask sentence {mask_sent} the indices {zeroed_tokens} are zero')
            mask_sent = np.delete(mask_sent, zeroed_tokens).tolist()
            logger.debug(f'Now the sentence is {mask_sent}')

        assert len(mask_sent) == len(sent), f'The mask {mask_sent} don\'t match size ' \
                                                f'of the sentence "{str_sent}"'

        if logger.isEnabledFor(logging.DEBUG):
            # dtype=bool so the mask is applied as a boolean mask, not as indices.
            str_sent_mwe_np = np.asarray(str_sent.split())[np.asarray(mask_sent, dtype=bool)]
            logger.debug(f'the following words from the sentence "{str_sent}" '
                         f'were the result of masking: {str_sent_mwe_np}')
            logger.debug(f'Sentence object {sent}')

        return sent, mask_sent

    def get_embedding_size(self) -> int:
        """Return the embedding dimension reported by the model."""
        return self.model.embedding_length

    def list_sentences_mask(self, low: int, high: int, column: str) -> Tuple[List[Sentence], List[List[bool]]]:
        """Build sentences and masks of `column` for dataset rows `low:high` (positional)."""
        sents = []
        sents_mask = []
        for _, row in self.dataset.iloc[low:high].iterrows():
            sent, mask_sent = self.get_sentence_and_mask(row, column)
            sents.append(sent)
            sents_mask.append(mask_sent)

        return sents, sents_mask

    def transform_to_torch(self, sents: List[Sentence],
                           sents_mask: List[List[bool]]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Pack embedded sentences into zero-padded tensors.

        Returns `(phrases, masks, lengths)` with shapes `(batch, longest, embedding)`,
        `(batch, longest)` (1.0 for masked-in tokens, 0.0 otherwise and in the
        padding) and `(batch,)` (number of tokens of each sentence).
        """
        embedding_size = self.get_embedding_size()
        total_number_batch = len(sents)
        biggest_sentence = max((len(sent) for sent in sents), default=0)

        to_phrases = torch.zeros(total_number_batch, biggest_sentence, embedding_size)
        to_phrases_mask = torch.zeros(total_number_batch, biggest_sentence)
        to_phrases_length = torch.zeros(total_number_batch)

        for x, (sent, mask_sent) in enumerate(zip(sents, sents_mask)):
            for y, token in enumerate(sent):
                to_phrases[x, y] = token.embedding

            # Convert the bools to 1.0/0.0 values explicitly instead of using the
            # Python list as a tensor index.
            to_phrases_mask[x, :len(mask_sent)] = torch.tensor(mask_sent, dtype=torch.float)
            to_phrases_length[x] = len(mask_sent)

        return to_phrases, to_phrases_mask, to_phrases_length

    def compute_similarity(self, output_dataset: pd.DataFrame, columns):
        """
        Compare the columns in `columns` for every row and store the scores in
        `output_dataset` (modified in place).

        Two columns are appended: `<cols joined by '-'>-cs-mwe`, comparing the mean
        embedding of the masked tokens, and `<cols joined by '-'>-cs-sent`,
        comparing the mean embedding of all tokens. The columns are combined
        according to `self.mode` (see `calculate_by_mode`).

        Raises `ValueError` if `columns` has fewer entries than the mode needs
        (2 for `pair`, 3 for `difference`).
        """
        required = 3 if self.mode == "difference" else 2
        if len(columns) < required:
            raise ValueError(f"mode {self.mode} needs at least {required} columns, got {columns}")

        resulting_column_mwe = '-'.join(columns) + self.RESULT_COLUMN_MWE_SUFFIX
        output_dataset.insert(len(output_dataset.columns),
                              resulting_column_mwe,
                              float('nan'))

        resulting_column_sent = '-'.join(columns) + self.RESULT_COLUMN_SENT_SUFFIX
        output_dataset.insert(len(output_dataset.columns),
                              resulting_column_sent,
                              float('nan'))

        if self.batch_size > 0:
            logger.debug('calculating in batch mode')
            embedding_size = self.get_embedding_size()
            for i in range(0, len(self.dataset), self.batch_size):
                total_number_batch = min(len(self.dataset), i + self.batch_size) - i

                columns_sent_data = [self.list_sentences_mask(i, i + self.batch_size, column)
                                     for column in columns]

                # Embed the sentences of every column in a single model call.
                model_input = [sent for sents, _ in columns_sent_data for sent in sents]
                self.model.embed(model_input)

                columns_sent_torch = [self.transform_to_torch(sents, sents_mask)
                                      for sents, sents_mask in columns_sent_data]

                # The embeddings were copied into the tensors above.
                for sent in model_input:
                    sent.clear_embeddings()

                # Masked mean: zero the unmasked tokens and divide by the masked count.
                columns_mwe_total = [
                    (to_phrase * to_phrase_mask.unsqueeze(2)).sum(dim=1) / to_phrase_mask.sum(dim=1, keepdim=True)
                    for to_phrase, to_phrase_mask, _ in columns_sent_torch]

                # mean cannot be calculated directly as we may have empty spots due to
                # phrase length being lesser than the biggest sentence used as dimension
                columns_sent_total = [to_phrase.sum(dim=1) / to_phrase_length.unsqueeze(1)
                                      for to_phrase, _, to_phrase_length in columns_sent_torch]

                for to_mwe_total in columns_mwe_total:
                    assert to_mwe_total.shape == (total_number_batch, embedding_size)

                for to_sent_total in columns_sent_total:
                    assert to_sent_total.shape == (total_number_batch, embedding_size), to_sent_total.shape

                comp_results_mwe, comp_results_sent = self.calculate_by_mode(columns_mwe_total, columns_sent_total)

                batch_index = self.dataset.index[i:i + total_number_batch]
                for x, compound in enumerate(batch_index):
                    output_dataset.at[compound, resulting_column_mwe] = comp_results_mwe[x]
                    output_dataset.at[compound, resulting_column_sent] = comp_results_sent[x]

        else:
            logger.debug('calculating in straightforward mode')
            for i, row in self.dataset.iterrows():
                cs_mwe, cs_sent = self._row_similarity(row, columns)
                output_dataset.at[i, resulting_column_mwe] = cs_mwe
                output_dataset.at[i, resulting_column_sent] = cs_sent

    def _row_similarity(self, row: pd.Series, columns) -> Tuple[float, float]:
        """Embed one row's cells of `columns` and score them with `calculate_by_mode`."""
        sents_and_masks = [self.get_sentence_and_mask(row, column) for column in columns]
        sentences = [sent for sent, _ in sents_and_masks]
        self.model.embed(sentences)

        mwe_vectors = []
        sent_vectors = []
        for sent, mask_sent in sents_and_masks:
            # Transform to matrix (num_tokens, hidden_units)
            tc_sent = torch.stack([token.embedding for token in sent], dim=0)
            mwe_vectors.append(tc_sent[mask_sent].mean(dim=0, keepdim=True))
            sent_vectors.append(tc_sent.mean(dim=0, keepdim=True))

        for sent in sentences:
            sent.clear_embeddings()

        comp_results_mwe, comp_results_sent = self.calculate_by_mode(mwe_vectors, sent_vectors)
        return comp_results_mwe[0], comp_results_sent[0]

    def calculate_by_mode(self, cols_mwe_total, cols_sent_total):
        """
        Apply `fn_comp_sim` to per-column vectors according to `self.mode`.

        - `pair`: compare column 0 with column 1.
        - `difference`: compare `c0 - (c1 + c2) / 2` with `c1 - c2`.

        Each argument is a list (one entry per column) of `(n, embedding)` tensors.
        Returns two numpy arrays: the masked-token scores and the sentence scores.
        """
        logger.debug(f'{cols_mwe_total[0].shape}, {cols_mwe_total[1].shape}')
        comparator = self.fn_comp_sim()
        if self.mode == "pair":
            logger.debug('executing comparison in pair mode')
            mwe_a, mwe_b = cols_mwe_total[0], cols_mwe_total[1]
            sent_a, sent_b = cols_sent_total[0], cols_sent_total[1]
        elif self.mode == "difference":
            logger.debug('executing comparison in difference mode')
            mwe_a = cols_mwe_total[0] - 1/2 * (cols_mwe_total[1] + cols_mwe_total[2])
            mwe_b = cols_mwe_total[1] - cols_mwe_total[2]
            sent_a = cols_sent_total[0] - 1/2 * (cols_sent_total[1] + cols_sent_total[2])
            sent_b = cols_sent_total[1] - cols_sent_total[2]
        else:
            raise ValueError(f"mode should not be {self.mode}, it has to be either pair or difference")

        comp_results_mwe = comparator(mwe_a, mwe_b).detach().cpu().numpy()
        comp_results_sent = comparator(sent_a, sent_b).detach().cpu().numpy()
        return comp_results_mwe, comp_results_sent

    def get_cosine(self, sent_a: Sentence, mask_sent_a: List[bool],
                   sent_b: Sentence, mask_sent_b: List[bool]) -> float:
        """Compare (with `fn_comp_sim`) the mean masked-token embeddings of two embedded sentences."""
        # Transform to matrix (num_tokens, hidden_units)
        tc_sent_a = torch.stack([token.embedding for token in sent_a], dim=0)
        tc_sent_b = torch.stack([token.embedding for token in sent_b], dim=0)

        # Filter for the mask only
        tc_masked_sent_a = tc_sent_a[mask_sent_a].mean(dim=0, keepdim=True)
        tc_masked_sent_b = tc_sent_b[mask_sent_b].mean(dim=0, keepdim=True)

        logger.debug(f'tc_masked_sent_a has shape {tc_masked_sent_a.shape}')

        assert tc_masked_sent_a.shape == tc_masked_sent_b.shape, \
            f'{tc_masked_sent_a.shape} doesn\'t match shape {tc_masked_sent_b.shape}'

        cs = self.fn_comp_sim()(tc_masked_sent_a,
                                tc_masked_sent_b).detach().cpu().numpy()[0]
        return cs

    def get_similarity(self, row, str_column_a, str_column_b) -> Tuple[float, float]:
        """Embed two columns of `row`; return (masked-token score, whole-sentence score)."""
        sent_a, mask_sent_a = self.get_sentence_and_mask(row, str_column_a)
        sent_b, mask_sent_b = self.get_sentence_and_mask(row, str_column_b)

        self.model.embed([sent_a, sent_b])
        cs_mwe = self.get_cosine(sent_a, mask_sent_a,
                                 sent_b, mask_sent_b)

        cs_sent = self.get_cosine(sent_a, [True] * len(sent_a),
                                  sent_b, [True] * len(sent_b))

        sent_a.clear_embeddings()
        sent_b.clear_embeddings()

        del sent_a, sent_b

        return cs_mwe, cs_sent

    def get_model_from_config(self, config: dict) -> Embeddings:
        """Instantiate the flair embedding named by `config['model']`; raise `ValueError` otherwise."""
        if config['model'] == 'glove':
            return WordEmbeddings(config.get('pretrained', 'glove'))
        elif config['model'] == 'bert':
            return BertEmbeddings(config['pretrained'],
                                  config.get('layers', '-1,-2,-3,-4'),
                                  config.get('pooling', 'mean'))
        elif config['model'] == 'elmo':
            return ELMoEmbeddings(config['pretrained'])
        else:
            raise ValueError(f'The model {config["model"]} is not supported')

    def get_unique_columns(self) -> List[str]:
        """
        Return every configured column (followed by its mask column when masking),
        without duplicates and in first-seen order so the output is deterministic.
        """
        unique_col = {}  # dict keys keep insertion order, unlike a set
        for group in self._column_groups():
            for column in group:
                unique_col.setdefault(column)
                if self.use_mask:
                    unique_col.setdefault(self.get_mask_column(column))

        return list(unique_col)

    def compare_similarity(self, copy_input: bool = False) -> pd.DataFrame:
        """
        Compare the embeddings of one or more groups of string columns.

        For each group in `columns` the sentences of every row are embedded and
        compared (see `compute_similarity`). The output gets the columns
        `<column A>-<column B>-cs-mwe`, computed only over the tokens selected by
        the mask columns (`mask_column_suffix`, e.g. `_tag`), and
        `<column A>-<column B>-cs-sent`, computed over the whole sentence.

        Parameters:

        - `copy_input`: copies the compared columns (and their mask columns) to
          the output `DataFrame`

        Returns a new `DataFrame` indexed like the input dataset.
        """
        if copy_input:
            output_dataset = self.dataset[self.get_unique_columns()].copy()
        else:
            output_dataset = pd.DataFrame(index=self.dataset.index)

        for input_cols in tqdm(self._column_groups(),
                               desc="Processing column pair ",
                               disable=not self.verbose):
            self.compute_similarity(output_dataset, input_cols)

        return output_dataset

    def clear_memory(self):
        """
        Drop the references to the model's underlying embedding objects so they can
        be garbage collected (presumably leaving the model unusable afterwards).
        """
        if isinstance(self.model, WordEmbeddings):
            self.model.precomputed_word_embeddings = None
        elif isinstance(self.model, BertEmbeddings):
            self.model.model = None
        elif isinstance(self.model, ELMoEmbeddings):
            self.model.ee = None

def define_arguments() -> argparse.ArgumentParser:
    """
    Build the command-line parser for this script.

    Positional arguments, in order: `dataset_path`, one or more `columns` groups,
    `model`, `pretrained` and `output`.
    """
    parser = argparse.ArgumentParser(description="Compare the embedding of pairs of columns in "
                                                 "a dataset using a similarity function.")

    parser.add_argument("dataset_path", help="Path to CSV dataset file.")
    parser.add_argument("-s", "--separator", help="Separator character of the dataset file.", default=",")
    parser.add_argument("-m", "--mask", help="Use mask column and use the suffix to reach the mask column.", default="_tag")
    parser.add_argument("-sim", "--similarity", help="Function to use as similarity comparison on the embeddings.",
                        choices=['cosine'],
                        default='cosine')
    parser.add_argument("columns", nargs="+",
                        help="Group of columns separated by commas (two for pair mode, three for "
                             "difference mode). Repeat to compare several groups.")

    parser.add_argument("model",
                        help="Which model to use to generate the embeddings (glove, bert or elmo).")
    parser.add_argument("pretrained",
                        help="Name of the pretrained model. "
                             "Can be a path pointing to the weights.")
    parser.add_argument("-l", "--layers",
                        help="Layers to consider of the model to represent a token (BERT only).",
                        type=str,
                        default='-1,-2,-3,-4')
    parser.add_argument("-po", "--pooling",
                        help="Pooling strategy of the layers (BERT only).",
                        default='mean')

    parser.add_argument("output",
                        help="Path of the output CSV file.")

    parser.add_argument("-c", "--copy-input", action="store_true",
                        help="Copy the input columns used in the comparison to the output result.")

    parser.add_argument("-b", "--batch", type=int, default=0,
                        help="Batch size. If 0, it executes sequentially each pair of sentences.")

    parser.add_argument("-M", "--mode", type=str, default="pair", choices=["pair", "difference"],
                        help="Comparison mode: 'pair' compares the first two columns of each group; "
                             "'difference' compares A - (B + C) / 2 against B - C.")

    parser.add_argument("-U", "--consider-unknown", action="store_true",
                        help="If specified, unknown (out-of-vocabulary) tokens are not considered "
                             "in the calculation (GloVe only).")

    parser.add_argument("-v", "--verbose", action="store_true",
                        help="Activate verbose mode to output debug functionality.")

    return parser



def main():
    """
    Command-line entry point: parse the arguments, compute the similarities and
    write them as CSV to the `output` path, creating its directory when needed.
    """
    args = define_arguments().parse_args()

    supported_comp_functions = {'cosine': CosineSimilarity}

    # Each positional `columns` entry is a comma-separated group, e.g. "a,b" or "a,b,c".
    columns = [tuple(comma_pair.split(',')) for comma_pair in args.columns]

    config = {"model": args.model,
              "pretrained": args.pretrained,
              "layers": args.layers,
              "pooling": args.pooling}

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    comp_function = supported_comp_functions[args.similarity]

    simcomp = SimilarityComputer(input_dataset_path=args.dataset_path,
                                 sep=args.separator,
                                 columns=columns,
                                 config=config,
                                 fn_comp_sim=comp_function,
                                 mask_column_suffix=args.mask,
                                 batch_size=args.batch,
                                 verbose=args.verbose,
                                 mode=args.mode,
                                 # Per its help text, passing -U/--consider-unknown means
                                 # unknown tokens are NOT considered, so the class flag is
                                 # its negation. The default (flag absent) keeps unknown
                                 # tokens, which the class requires for non-GloVe models.
                                 consider_not_found_tokens=not args.consider_unknown)

    output = simcomp.compare_similarity(args.copy_input)

    dir_file_output = os.path.dirname(args.output)
    if dir_file_output and not os.path.exists(dir_file_output):
        os.makedirs(dir_file_output, exist_ok=True)
        logger.debug(f"Directory {dir_file_output} created for file {args.output}.")
    elif dir_file_output:
        logger.debug(f"Directory {dir_file_output} already exists.")

    output.to_csv(args.output)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
