"""
Clustering baselines. The baselines currently planned:

1. TF-IDF weighted mean pooling of word embeddings, instead of plain mean pooling.

2. GloVe + max/mean pooling + k-means.

3. BERT + max/mean pooling + k-means (mean pooling is the default).

4. BERT sentence encoder + k-means.
   Reference: https://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/
"""

import argparse
import os
import pandas as pd
import numpy as np
import torch
import tensorflow as tf
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import kneighbors_graph
from sklearn import cluster
import logging
from sklearn.feature_extraction.text import TfidfVectorizer
import metrics
import utils
# Special token strings. ``unknown_token`` is the vocabulary entry that
# glove_get_text_encoder falls back to for out-of-vocabulary words.
UCI_TOKEN = 'UCI_TOKEN'
pad_token = '<pad>'
unknown_token = '<unk>'

# WordPiece tokenizer shared by bert_get_text_encoder; it is created once,
# at import time, from the 'bert-base-uncased' vocabulary.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


def glove_get_text_encoder(text, w2n, emb_matrix, pooling, weight=None):
    """
    Pool the word vectors of ``text`` into one sentence vector.

    Words missing from ``w2n`` fall back to ``unknown_token`` when the vocabulary
    contains it and are skipped otherwise. A skipped word's weight is dropped as
    well, so the remaining weights stay aligned with the remaining vectors.

    :param text: text utterance (converted with ``str`` and split on whitespace)
    :param w2n: word to number (row index into ``emb_matrix``)
    :param emb_matrix: embedding matrix, indexed by the ids stored in ``w2n``
    :param pooling: 'mean' (weighted average) or 'max' (element-wise maximum)
    :param weight: optional per-word weights (e.g. tf-idf scores), one entry per
        whitespace-separated word of ``text``; only used by 'mean' pooling.
        ``None`` means every word weighs 1.
    :return: the pooled sentence vector (numpy array)
    :raises ValueError: if ``pooling`` is unsupported, ``weight`` does not have one
        entry per word, or no word of ``text`` can be embedded
    """
    if pooling not in ('mean', 'max'):
        raise ValueError("Unsupported pooling '{}'; use 'mean' or 'max'.".format(pooling))

    words = str(text).split()
    if weight is None:
        weight = [1] * len(words)
    elif len(weight) != len(words):
        raise ValueError('Expected {} weights (one per word), got {}.'.format(len(words), len(weight)))

    word_embedding = []
    kept_weight = []
    for word, word_weight in zip(words, weight):
        if word in w2n:
            word_id = w2n[word]
        elif unknown_token in w2n:
            word_id = w2n[unknown_token]
        else:
            # No vector for this word: drop it together with its weight so the
            # weights passed to np.average keep matching the vectors.
            continue
        word_embedding.append(emb_matrix[word_id])
        kept_weight.append(word_weight)

    if not word_embedding:
        raise ValueError('None of the words in {!r} has an embedding.'.format(text))
    if pooling == 'mean':
        return np.average(word_embedding, weights=kept_weight, axis=0)
    return np.max(word_embedding, axis=0)


def bert_get_text_encoder(text, bert_model, pooling, weight=None, layer=11, max_len=512):
    """
    Encode ``text`` with BERT and pool one hidden layer over its tokens.

    :param text: text utterance (converted with ``str``)
    :param bert_model: ``pytorch_pretrained_bert`` ``BertModel``; the input tensors are
        created on the device of its parameters, so CPU and CUDA models both work
    :param pooling: 'mean' or 'max' pooling over the token axis
    :param weight: unused; kept so the signature mirrors ``glove_get_text_encoder``
    :param layer: index into the list of hidden layers returned by the model
        (11 is the last layer of bert-base)
    :param max_len: maximum number of word pieces fed to the model; longer texts are
        truncated because BERT's position embeddings are limited (512 for bert-base)
    :return: 1-D numpy array with one value per hidden unit, for both poolings
    :raises ValueError: if ``pooling`` is unsupported or ``text`` yields no tokens
    """
    if pooling not in ('mean', 'max'):
        raise ValueError("Unsupported pooling '{}'; use 'mean' or 'max'.".format(pooling))
    tokenized_text = tokenizer.tokenize(str(text))[:max_len]
    if not tokenized_text:
        raise ValueError('Text {!r} produced no BERT tokens.'.format(text))
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

    # Build the inputs where the model lives (a fine-tuned model may be on CUDA).
    device = next(bert_model.parameters()).device
    tokens_tensor = torch.tensor([indexed_tokens], device=device)
    # Every token is tagged with segment id 1, as in the original baseline.
    segments_tensors = torch.ones_like(tokens_tensor)
    with torch.no_grad():
        encoded_layers, _ = bert_model(tokens_tensor, segments_tensors)
    hidden = encoded_layers[layer]  # (batch=1, seq_len, hidden_size)
    if pooling == 'mean':
        pooled = hidden.mean(dim=1)
    else:
        # torch.max with a dim returns (values, indices); keep only the values.
        pooled = hidden.max(dim=1)[0]
    return pooled[0].cpu().numpy()


# Names accepted by --algorithm, in the order they are offered on the CLI.
_CLUSTERING_NAMES = ('kmeans', 'AffinityPropagation', 'SpectralClustering', 'Ward',
                     'AgglomerativeClustering', 'DBSCAN', 'Birch')


def _load_bert_model(bert_dir):
    """Load bert-base-uncased, or the fine-tuned model saved at ``bert_dir``, in eval mode."""
    if bert_dir is None:
        model = BertModel.from_pretrained('bert-base-uncased')
    else:
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        map_location = None if torch.cuda.is_available() else 'cpu'
        model = torch.load(bert_dir, map_location=map_location)
        if torch.cuda.is_available():
            model.cuda()
        logging.info('Loaded the fine-tuned bert model.')
    # Disable dropout so that every run produces the same sentence vectors.
    model.eval()
    return model


def _glove_tfidf_vectors(templates, w2n, emb_matrix, pooling):
    """Pool the GloVe vectors of each template's tf-idf terms, weighted by their tf-idf scores."""
    logging.info('Calculating Tf-Idf scores...')
    sent_list = list(templates.values.astype('U'))
    vectorizer = TfidfVectorizer(lowercase=False, token_pattern='(?u)\\b\\w+\\b')
    tfidf_matrix = vectorizer.fit_transform(sent_list).tocsr()
    # get_feature_names() was replaced by get_feature_names_out() in newer sklearn.
    if hasattr(vectorizer, 'get_feature_names_out'):
        feature_names = vectorizer.get_feature_names_out()
    else:
        feature_names = vectorizer.get_feature_names()

    vectors = []
    # Walk the rows by position: row i of the matrix belongs to the i-th template,
    # whatever the DataFrame's index labels are.
    for pos in range(tfidf_matrix.shape[0]):
        start, end = tfidf_matrix.indptr[pos], tfidf_matrix.indptr[pos + 1]
        terms = ' '.join(feature_names[i] for i in tfidf_matrix.indices[start:end])
        scores = tfidf_matrix.data[start:end]
        vectors.append(glove_get_text_encoder(terms, w2n, emb_matrix, pooling, weight=scores))
    return np.vstack(vectors)


def _log_metrics(title, gold, pred):
    """Log the project ``metrics`` scores (acc, mi, nmi, ari), rounded to 5 decimals."""
    acc = np.round(metrics.acc(gold, pred), 5)
    mi = np.round(metrics.mi(gold, pred), 5)
    nmi = np.round(metrics.nmi(gold, pred), 5)
    ari = np.round(metrics.ari(gold, pred), 5)
    logging.info('=' * 50)
    logging.info(title)
    logging.info('Acc=' + str(acc) + ', mi=' + str(mi) + ', nmi=' + str(nmi) + ', ari=' + str(ari))
    logging.info('=' * 50)


def _build_clusterer(alg_name, num_clusters, features):
    """Instantiate only the requested sklearn estimator (``features`` feeds Ward/average connectivity)."""
    if alg_name == 'kmeans':
        return cluster.KMeans(n_clusters=num_clusters, random_state=0)
    if alg_name == 'AffinityPropagation':
        return cluster.AffinityPropagation(damping=.9, preference=-200)
    if alg_name == 'SpectralClustering':
        return cluster.SpectralClustering(n_clusters=num_clusters, eigen_solver='arpack',
                                          affinity="nearest_neighbors")
    if alg_name == 'DBSCAN':
        return cluster.DBSCAN(eps=.2)
    if alg_name == 'Birch':
        return cluster.Birch(n_clusters=num_clusters)
    if alg_name in ('Ward', 'AgglomerativeClustering'):
        # connectivity matrix for structured Ward / average linkage
        connectivity = kneighbors_graph(features, n_neighbors=10, include_self=False)
        print("finish construct connectivity")
        # make connectivity symmetric
        connectivity = 0.5 * (connectivity + connectivity.T)
        if alg_name == 'Ward':
            return cluster.AgglomerativeClustering(n_clusters=num_clusters, linkage='ward',
                                                   connectivity=connectivity)
        try:
            return cluster.AgglomerativeClustering(linkage="average", metric="cityblock",
                                                   n_clusters=num_clusters, connectivity=connectivity)
        except TypeError:
            # scikit-learn < 1.2 only knows this parameter as ``affinity``.
            return cluster.AgglomerativeClustering(linkage="average", affinity="cityblock",
                                                   n_clusters=num_clusters, connectivity=connectivity)
    raise ValueError('The algorithm {} is not supported; choose one of: {}'.format(
        alg_name, '\t'.join(_CLUSTERING_NAMES)))


def _main():
    """
    Build sentence vectors with the chosen method, cluster them and log the scores.

    Example invocations:
        python baselines.py --method bert --nrows 50000 [--pooling max]
        python baselines.py --method bert --nrows 50000 --pooling mean --encoder_dir results/baselines/bert_encoder_mean50000.tsv
        python baselines.py --method glove --nrows 50000 --n_clusters 50 --pooling mean [--tfidf] [--rm_stop_words]
    """
    # setting the hyper parameters, default path for cancel prime
    parser = argparse.ArgumentParser(description='baseline')
    parser.add_argument('--train_file', default='', type=str)
    parser.add_argument('--emb_file', default='', type=str)
    parser.add_argument('--method', default='bert', type=str, choices=['bert', 'glove', 'dsec', 'other'])
    parser.add_argument('--tfidf', action="store_true", dest='tfidf')
    parser.add_argument('--rm_stop_words', action="store_true", dest='rm_stop_words')
    parser.add_argument('--save_dir', default='results/baselines')
    parser.add_argument('--n_clusters', default=20, type=int)
    parser.add_argument('--nrows', default=None, type=int)
    parser.add_argument('--pooling', default='mean', type=str)
    parser.add_argument('--encoder_dir', default=None, type=str)
    parser.add_argument('--bert_dir', default=None, type=str)
    parser.add_argument('--algorithm', default="kmeans", type=str, choices=list(_CLUSTERING_NAMES))
    args = parser.parse_args()
    if args.method in ('dsec', 'other') and args.encoder_dir is None:
        parser.error('--method {} requires --encoder_dir'.format(args.method))

    os.makedirs(args.save_dir, exist_ok=True)
    log_file = os.path.join(args.save_dir, args.method + '_' + args.algorithm + '-cluster.log')
    logging.basicConfig(filename=log_file, filemode='a', level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info(args)

    # nrows=None reads the whole file
    train_df = pd.read_csv(args.train_file, sep='\t', nrows=args.nrows)

    if args.rm_stop_words:
        from nltk.corpus import stopwords
        stop_words = set(stopwords.words('english'))
        train_df['Template'] = train_df['Template'].apply(
            lambda x: ' '.join(w for w in str(x).split() if w not in stop_words))
        # Drop rows that became empty and renumber so labels match positions again.
        train_df = train_df[train_df['Template'] != ''].reset_index(drop=True)

    cluster_id = np.array(train_df['Gold'])

    if args.tfidf and args.method != 'glove':
        logging.warning('--tfidf is only used by --method glove; ignoring it.')

    if args.method == 'bert':
        if args.encoder_dir is None:
            # call bert to calculate the sentence embedding of each line
            logging.info('Loading Bert model...')
            bert_model = _load_bert_model(args.bert_dir)
            vectors = [bert_get_text_encoder(text, bert_model, args.pooling) for text in train_df['Template']]
            sen_encoder = pd.DataFrame(np.vstack(vectors))
            cache_file = os.path.join(args.save_dir,
                                      args.method + '_encoder_' + args.pooling + str(args.nrows) + '.tsv')
            sen_encoder.to_csv(cache_file, sep='\t', index=False, header=False)
            logging.info('Loaded Bert model successfully.')
        else:
            sen_encoder = pd.read_csv(args.encoder_dir, sep='\t', header=None)
    elif args.method == 'glove':
        logging.info('Loading Glove embedding matrix...')
        _, _, emb_matrix, w2n, _, _ = utils.get_data(train_df['Template'].to_frame(), 100, args.emb_file)
        if args.tfidf:
            sen_encoder = _glove_tfidf_vectors(train_df['Template'], w2n, emb_matrix, args.pooling)
        else:
            sen_encoder = np.vstack([glove_get_text_encoder(text, w2n, emb_matrix, args.pooling)
                                     for text in train_df['Template']])
        logging.info('Calculated the sentence encoder successfully.')
    elif args.method == 'dsec':
        logging.info('Loading the DSEC auto-encoder vector...')
        encoder_matrix = pd.read_csv(args.encoder_dir, sep='\t')
        dsec_pred = np.asarray(encoder_matrix['Pred'])
        dsec_gold = np.asarray(encoder_matrix['Gold'])
        eval_idx = np.where(~pd.isnull(dsec_gold))[0]
        # NOTE(review): rows are selected with DSEC's own Gold column but scored against
        # the train-file labels; this assumes both files list the same rows in order.
        _log_metrics('DSEC: ', cluster_id[eval_idx], dsec_pred[eval_idx])
        sen_encoder = encoder_matrix.loc[:, encoder_matrix.columns[3:]]  # remove 'Gold', 'Pred', 'q_ij'
        logging.info('Loaded the sentence encoder successfully.')
    else:
        # 'other': use a pre-calculated embedding matrix (tab-separated, no header,
        # the same layout this script writes for the BERT encoder cache).
        sen_encoder = pd.read_csv(args.encoder_dir, sep='\t', header=None)

    if len(sen_encoder) != len(cluster_id):
        raise ValueError('Got {} sentence vectors for {} rows of {}; the encoder and the train '
                         'file must describe the same rows.'.format(len(sen_encoder), len(cluster_id),
                                                                    args.train_file))

    sen_encoder = StandardScaler().fit_transform(sen_encoder)
    print("Using algorithm {} for clustering".format(args.algorithm))
    alg = _build_clusterer(args.algorithm, args.n_clusters, sen_encoder)
    print("finish initialization")

    # Predict cluster id for each utterance.
    alg.fit(sen_encoder)
    cluster_pred = alg.labels_ if hasattr(alg, 'labels_') else alg.predict(sen_encoder)

    eval_idx = np.where(~pd.isnull(cluster_id))[0]
    _log_metrics(args.method + '+' + args.pooling + ':', cluster_id[eval_idx], cluster_pred[eval_idx])

    clust_df = pd.concat([pd.Series(cluster_id), pd.Series(cluster_pred)], axis=1)
    clust_df.columns = ['Gold', args.method + '_Pred']
    out_file = os.path.join(args.save_dir, args.method + '_' + args.pooling + '_' + args.algorithm
                            + '_cluster_prediction.tsv')
    clust_df.to_csv(out_file, sep='\t', index=False)


if __name__ == "__main__":
    _main()







