# -*- coding: utf-8 -*-
"""
Created on Wed Dec 11 09:54:53 2019

@author: Administrator
"""
import os
from gensim.models import FastText
import _pickle as cPickle
from nltk import sent_tokenize, word_tokenize, ngrams, FreqDist
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import numpy as np


def load_model(path, word2idx):
    """Load a saved FastText model and build an index -> vector lookup.

    Args:
        path: Location of a model previously saved with gensim's FastText.
        word2idx: Mapping from vocabulary token to integer index.

    Returns:
        A ``(model, Word2Vec)`` tuple, where ``Word2Vec`` maps every index of
        ``word2idx`` to the model's vector for the corresponding token.
    """
    model = FastText.load(path)
    # Every '__' occurrence is stripped from the token before the lookup
    # (presumably a vocabulary prefix marker -- confirm against the vocab builder).
    Word2Vec = {
        idx: model.wv[token.replace('__', '')]
        for token, idx in word2idx.items()
    }
    return model, Word2Vec

def cal_metrics(sentences, embed_model, embeddings, word2idx, config, loss, ppl, epoch, mode):
    """Compute, print and log the evaluation metrics for one epoch.

    Args:
        sentences: Indexable collection of three sentence lists.
            ``sentences[2]`` holds the generated responses. In mode 'AB'
            ``sentences[0]`` is used both as coherence context and as the
            reference; in mode 'BB' ``sentences[1]`` is the reference.
        embed_model: Embedding model handed to the embedding-based metrics.
        embeddings: Index -> vector lookup handed to ``cal_coherence``.
        word2idx: Token -> index mapping handed to ``cal_coherence``.
        config: Object exposing ``logsdir`` (directory of the metric logs).
        loss: Loss value, written to the log row unchanged.
        ppl: Perplexity value, written to the log row unchanged.
        epoch: Zero-based epoch number; logged as ``epoch + 1``.
        mode: 'AB' (adds coherence) or 'BB' (adds distinct-1/2/3).

    Returns:
        List of metrics. 'AB': BLEU-1..4, embedding average, greedy matching,
        vector extrema, BLEU mean, embedding mean, coherence. 'BB': the same
        list without coherence, prefixed by distinct-1/2/3.

    Raises:
        ValueError: If ``mode`` is neither 'AB' nor 'BB'.
    """
    # Fail fast: an unknown mode used to surface much later as an
    # UnboundLocalError on the reference index.
    if mode not in ('AB', 'BB'):
        raise ValueError("mode must be 'AB' or 'BB', got {!r}".format(mode))

    generated = sentences[2]
    if mode == 'AB':
        coherence = cal_coherence(sentences[0], generated, embeddings, word2idx, config)
        references = sentences[0]
    else:
        distinct_1 = get_distinct(generated, 1)
        distinct_2 = get_distinct(generated, 2)
        distinct_3 = get_distinct(generated, 3)
        references = sentences[1]

    avge_embedding = get_embedding_average(references, generated, embed_model)
    greedy_embeddi = get_greedy_matching(references, generated, embed_model)
    vector_extrema = get_vector_extrema(references, generated, embed_model)
    embed_avg = sum([avge_embedding, greedy_embeddi, vector_extrema]) / 3

    bleu_1 = get_bleu(references, generated, 1)
    bleu_2 = get_bleu(references, generated, 2)
    bleu_3 = get_bleu(references, generated, 3)
    bleu_4 = get_bleu(references, generated, 4)
    bleu_avg = sum([bleu_1, bleu_2, bleu_3, bleu_4]) / 4

    # Metrics reported by both modes, in log-column order.
    shared = [bleu_1, bleu_2, bleu_3, bleu_4,
              avge_embedding, greedy_embeddi, vector_extrema,
              bleu_avg, embed_avg]

    if mode == 'AB':
        print('Metrics-Result-of Coherent:')
        print('  coherence: {:.4f}'.format(coherence))
        log_name = 'coherent_metrics_logs.txt'
        metrics = shared + [coherence]
    else:
        print('Metrics-Result-of Model:')
        print('  distinct: {:.4f}, {:.4f}, {:.4f}'.format(distinct_1, distinct_2, distinct_3))
        log_name = 'model_metrics_logs.txt'
        metrics = [distinct_1, distinct_2, distinct_3] + shared

    # One tab-separated row per call: epoch, loss, ppl, then the metrics.
    # '{}' formatting is kept so every value renders as it did before.
    row = [epoch + 1, loss, ppl] + metrics
    with open(os.path.join(config.logsdir, log_name), 'a', encoding='utf-8') as f:
        f.write(('\t'.join(['{}'] * len(row)) + '\n').format(*row))

    print('  bleu: {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}'
          .format(bleu_avg, bleu_1, bleu_2, bleu_3, bleu_4))
    print('  embedding: {:.4f}, {:.4f}, {:.4f}, {:.4f}'
          .format(embed_avg, avge_embedding, greedy_embeddi, vector_extrema))

    return metrics

def _response_tokenize(response):
    """
    Function: tokenize a single response. The text is split into sentences
              with ``sent_tokenize`` and each sentence into words with
              ``word_tokenize``.
    Return: [token1, token2, ......] in order of appearance
    """
    return [
        token
        for sentence in sent_tokenize(response)
        for token in word_tokenize(sentence)
    ]
    
    
def get_dp_gan_metrics(gen_responses):
    """
    Function: count tokens and distinct n-grams over all gen_responses.
    Return: (total number of tokens,
             number of distinct unigrams,
             number of distinct bigrams,
             number of distinct trigrams,
             number of distinct whole responses)
    """
    token_total = 0
    unigram_types = set()
    bigram_types = set()
    trigram_types = set()
    response_types = set()

    for response in gen_responses:
        tokens = _response_tokenize(response)
        token_total += len(tokens)
        unigram_types.update(ngrams(tokens, 1))
        bigram_types.update(ngrams(tokens, 2))
        trigram_types.update(ngrams(tokens, 3))
        response_types.add(response)

    return (token_total, len(unigram_types), len(bigram_types),
            len(trigram_types), len(response_types))


def get_distinct(gen_responses, n):
    """
    Function: compute the n-gram type-token ratio (distinct-n) over all
              gen_responses, i.e. unique n-grams / total n-grams.
    Return: the ratio, or 0 when no n-gram could be extracted
    """
    all_ngrams = []
    for response in gen_responses:
        all_ngrams.extend(ngrams(_response_tokenize(response), n))

    if not all_ngrams:
        return 0
    return len(set(all_ngrams)) / len(all_ngrams)


def get_response_length(gen_responses):
    """Average number of tokens per generated response.

    Reference:
         1. paper : Iulian V. Serban,et al. A Deep Reinforcement Learning Chatbot

    Return: mean token count, or 0 when gen_responses is empty
    """
    lengths = [len(_response_tokenize(response)) for response in gen_responses]
    if not lengths:
        return 0
    return sum(lengths) / len(lengths)


def get_bleu(true_responses, gen_responses, n_gram):
    """
    Function: compute the mean sentence-level BLEU-n between paired
              true_responses and gen_responses.

    parameters:
        true_responses : reference responses, paired by position with
                 gen_responses (pairs are formed with zip).
        gen_responses : generated responses.
        n_gram : which BLEU-n to compute (1 to 4). BLEU-n is cumulative with
                 uniform weights over the 1..n-gram precisions, e.g. BLEU-4
                 uses 1/4 (25%) for each of the 1-gram, 2-gram, 3-gram and
                 4-gram scores.

    Reference:
        1. https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
        2. https://cloud.tencent.com/developer/article/1042161

    Return: mean BLEU-n score, or 0 when there are no pairs. A generated
            response with at most one token contributes a score of 0.
    """
    weights = {1: (1.0, 0.0, 0.0, 0.0),
               2: (1/2, 1/2, 0.0, 0.0),
               3: (1/3, 1/3, 1/3, 0.0),
               4: (1/4, 1/4, 1/4, 1/4)}
    # One smoothing object serves every pair; building it per pair was wasted work.
    smoothing = SmoothingFunction().method7

    total_score = []
    for true_response, gen_response in zip(true_responses, gen_responses):
        # Tokenize the hypothesis once; it feeds both the length check and BLEU.
        gen_tokens = _response_tokenize(gen_response)
        if len(gen_tokens) <= 1:
            total_score.append(0)
            continue
        score = sentence_bleu(
            [_response_tokenize(true_response)],
            gen_tokens,
            weights[n_gram],
            smoothing_function=smoothing)
        total_score.append(score)

    if not total_score:
        return 0
    return sum(total_score) / len(total_score)


def _consine(v1, v2):
    """
    Function: cosine similarity of two vectors
    Return: dot(v1, v2) / (|v1| * |v2|), or 0 when the product of the norms is 0
    """
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
    if norm_product == 0:
        return 0
    return np.dot(v1, v2) / norm_product
    
    
def get_greedy_matching(true_responses, gen_responses, word_vec):
    """
    Function: compute the mean greedy-matching score between paired
              true_responses and gen_responses.

    For each pair, every token of one response is matched to its most similar
    token (by cosine) in the other response. The per-token maxima are averaged
    in each direction and the two directions are averaged.

    parameters:
        true_responses : reference responses, paired by position with
                 gen_responses (pairs are formed with zip).
        gen_responses : generated responses.
        word_vec : embedding model; token vectors are read as word_vec.wv[token].

    Return: mean greedy-matching score, or 0 when there are no pairs. A pair
            in which either response has no tokens contributes 0.0.
    """
    total_cosine = []
    for true_response, gen_response in zip(true_responses, gen_responses):
        true_vecs = [word_vec.wv[token] for token in _response_tokenize(true_response)]
        gen_vecs = [word_vec.wv[token] for token in _response_tokenize(gen_response)]

        # Nothing to match against: score the pair 0.0 explicitly instead of
        # relying on np.max raising ValueError for an empty array.
        if not true_vecs or not gen_vecs:
            total_cosine.append(0.0)
            continue

        # Cosine is symmetric, so a single (true x gen) matrix serves both
        # directions: rows are reference tokens, columns are generated tokens.
        similarity = np.array([[_consine(gen_vec, true_vec) for gen_vec in gen_vecs]
                               for true_vec in true_vecs])
        best_for_true = np.max(similarity, 1)
        best_for_gen = np.max(similarity, 0)
        cosine = (np.sum(best_for_true) / len(best_for_true)
                  + np.sum(best_for_gen) / len(best_for_gen)) / 2
        total_cosine.append(cosine)

    if not total_cosine:
        return 0
    return sum(total_cosine) / len(total_cosine)


def get_embedding_average(true_responses, gen_responses, word_vec):
    """
    Function: compute the mean embedding-average score between paired
              true_responses and gen_responses. Each response is represented
              by the sum of its token vectors and the pair is scored by the
              cosine of the two sums.

    parameters:
        true_responses : reference responses, paired by position with
                 gen_responses (pairs are formed with zip).
        gen_responses : generated responses.
        word_vec : embedding model; token vectors are read as word_vec.wv[token].

    Return: mean cosine as a Python float, or 0 when there are no pairs. A
            pair in which either response has no tokens (or a zero sentence
            vector) contributes 0.0.
    """
    total_cosine = []
    for true_response, gen_response in zip(true_responses, gen_responses):
        true_vecs = [word_vec.wv[token] for token in _response_tokenize(true_response)]
        gen_vecs = [word_vec.wv[token] for token in _response_tokenize(gen_response)]

        # An empty response has no sentence vector. It used to yield 0/0 = NaN,
        # which either poisoned the mean or made the whole corpus score 0.0.
        if not true_vecs or not gen_vecs:
            total_cosine.append(0.0)
            continue

        # Cosine is scale-invariant, so the sums need no normalisation first;
        # _consine returns 0 for a zero-norm vector.
        true_sentence_vec = np.sum(np.array(true_vecs), 0)
        gen_sentence_vec = np.sum(np.array(gen_vecs), 0)
        # float() keeps the running sum a plain Python float whatever the
        # embedding dtype is (np.float32 is not a subclass of float).
        total_cosine.append(float(_consine(true_sentence_vec, gen_sentence_vec)))

    if not total_cosine:
        return 0
    return sum(total_cosine) / len(total_cosine)


def get_vector_extrema(true_responses, gen_responses, word_vec):
    """
    Function: compute the mean vector-extrema score between paired
              true_responses and gen_responses. Each response is represented
              by taking, per dimension, the maximum token value when it
              exceeds the absolute value of the minimum and the minimum
              otherwise; a pair is scored by the cosine of the two vectors.
    Return: mean score, or 0 when there are no pairs. A pair for which the
            extrema cannot be computed (ValueError, e.g. a response without
            tokens) contributes 0.0.
    """
    def extrema(token_vectors):
        # Per-dimension extreme value across the tokens of one response.
        highest = np.max(token_vectors, 0)
        lowest = np.min(token_vectors, 0)
        return np.where(highest > np.abs(lowest), highest, lowest)

    scores = []
    for true_response, gen_response in zip(true_responses, gen_responses):
        true_vectors = np.array([word_vec.wv[token]
                                 for token in _response_tokenize(true_response)])
        gen_vectors = np.array([word_vec.wv[token]
                                for token in _response_tokenize(gen_response)])

        try:
            pair_score = _consine(extrema(true_vectors), extrema(gen_vectors))
        except ValueError:
            # np.max / np.min reject zero-size input.
            pair_score = 0.0
        scores.append(pair_score)

    if not scores:
        return 0
    return sum(scores) / len(scores)
    
    
def get_language_model(model_name, config):
    """
    Load a cached n-gram language model, or build and cache it.

    :param model_name: file name of the model inside the 'language_models'
        directory. Only 'unigrams' can be built when no cached file exists.
    :param config: object exposing train_data_path1, train_data_path2,
        maxlen1 and maxlen2; only read when the model has to be built.
    :return: the cached object when the file exists; otherwise
        [model_Prob, unigramsFreqDist], where model_Prob holds the additive-1
        smoothed unigram probabilities and unigramsFreqDist the raw counts.
    :raises ValueError: if no cached file exists and model_name is not
        'unigrams'.
    :function: build an n-gram language model from counts, using additive-1
        smoothing
    """
    model_dir = 'language_models'
    model_path = os.path.join(model_dir, model_name)

    if os.path.exists(model_path):
        # model has been saved
        # NOTE(security): unpickling runs arbitrary code from the file; only
        # load cache files this project wrote itself.
        with open(model_path, 'rb') as f:
            return cPickle.load(f, encoding='bytes')

    # Reject unsupported names before doing any work. This used to fail with
    # a NameError only after an empty cache file had been created, which then
    # broke every later load.
    if model_name != 'unigrams':
        raise ValueError(
            "cannot build language model {!r}: only 'unigrams' is supported".format(model_name))

    # no model file, create a new model
    print('there no exist language model, creating...')
    unigramsFreqDist = FreqDist()
    with open(config.train_data_path1, 'r', encoding='utf-8') as f1, \
            open(config.train_data_path2, 'r', encoding='utf-8') as f2:
        for sent1, sent2 in zip(f1, f2):
            # Skip pairs exceeding the configured maximum lengths; the +2 / +1
            # presumably reserve room for special tokens such as </s> -- TODO confirm.
            if len(sent1.split()) + 2 > config.maxlen1:
                continue
            if len(sent2.split()) + 1 > config.maxlen2:
                continue
            for sent in (sent1.strip(), sent2.strip()):
                unigramsFreqDist.update(_response_tokenize(sent))

    # Additive 1 smoothing: (count + 1) / (total tokens + vocabulary size).
    denominator = unigramsFreqDist.N() + unigramsFreqDist.B()
    model_Prob = FreqDist()
    for token, count in unigramsFreqDist.items():
        model_Prob[token] = (count + 1) / denominator
    Model = [model_Prob, unigramsFreqDist]

    # save model
    print('new model is created over')
    os.makedirs(model_dir, exist_ok=True)  # the cache directory may not exist yet
    with open(model_path, 'wb') as f:
        cPickle.dump(Model, f)

    return Model
        
        
def weighted_average_sim_rmpc(We, x1, x2, w1, w2):
    """
    Compute the scores between pairs of sentences as the cosine similarity of
    their weighted-average embeddings.

    NOTE: despite the name, no principal-component projection is removed here;
    only the weighted averages are compared.

    :param We: We[i] is the vector for word i (index 0 must exist; it sets the
        embedding size)
    :param x1: x1[i, :] are the indices of the words in the first sentence in pair i
    :param x2: x2[i, :] are the indices of the words in the second sentence in pair i
    :param w1: w1[i, :] are the weights for the words in the first sentence in pair i
    :param w2: w2[i, :] are the weights for the words in the second sentence in pair i
    :return: scores, scores[i] is the matching score of the pair i; a pair in
        which either sentence embedding is all zeros (e.g. an empty sentence)
        scores 0.0
    """
    def get_weighted_average(We, x, w):
        """
        Compute the weighted average vectors
        :param We: We[i] is the vector for word i
        :param x: x[i, :] are the indices of the words in sentence i
        :param w: w[i, :] are the weights for the words in sentence i
        :return: emb[i, :] are the weighted average vector for sentence i
        """
        n_samples = x.shape[0]
        emb = np.zeros((n_samples, We[0].shape[0]))
        for i in range(n_samples):
            n_weighted = np.count_nonzero(w[i, :])
            if n_weighted == 0:
                # No weighted word: keep the zero vector rather than divide by zero.
                continue
            word_vectors = np.array([We[token] for token in x[i]])
            emb[i, :] = w[i, :].dot(word_vectors) / n_weighted
        return emb

    emb1 = get_weighted_average(We, x1, w1)
    emb2 = get_weighted_average(We, x2, w2)

    inn = (emb1 * emb2).sum(axis=1)
    emb1norm = np.sqrt((emb1 * emb1).sum(axis=1))
    emb2norm = np.sqrt((emb2 * emb2).sum(axis=1))

    # Divide only where both norms are non-zero so a degenerate pair scores
    # 0.0 instead of NaN, which would poison any average taken over the scores.
    scores = np.zeros(inn.shape)
    valid = (emb1norm != 0) & (emb2norm != 0)
    scores[valid] = inn[valid] / emb1norm[valid] / emb2norm[valid]
    return scores


def prepare_parameters(data, data_type, smooth_a, word2idx, config):
    """Turn raw sentences into padded index and weight matrices.

    Args:
        data: Iterable of sentences (strings) to convert.
        data_type: Label of the data ('context', 'gen_responses', ...). It is
            accepted for interface compatibility only; every kind of data is
            tokenized the same way.
        smooth_a: Smoothing constant ``a`` of the weight ``a / (a + p(w))``.
        word2idx: Token -> index mapping. Tokens are looked up with a '__'
            prefix and fall back to index 1 when missing (presumably the
            unknown-word index -- TODO confirm).
        config: Passed through to ``get_language_model``.

    Returns:
        ``(X, weights)``: two arrays of shape (number of sentences, longest
        sentence length). ``X`` holds word indices and ``weights`` the word
        weights; positions past a sentence's length stay 0. Empty ``data``
        yields two arrays of shape (0, 0).
    """
    # 1. tokenize
    tokens = [_response_tokenize(sent) for sent in data]
    tokens_lengths = [len(sent_tokens) for sent_tokens in tokens]
    # default=0 lets empty data produce empty matrices instead of a ValueError,
    # so callers can handle the "no sentences" case themselves.
    max_length = max(tokens_lengths, default=0)

    # 2. unigram statistics (only the probabilities are needed here)
    unigramsProb, _ = get_language_model('unigrams', config)

    # 3. weights and 4. word indices, filled sentence by sentence
    weights = np.zeros([len(tokens), max_length])
    X = np.zeros([len(tokens), max_length], dtype=int)
    for i, sent_tokens in enumerate(tokens):
        length = tokens_lengths[i]
        weights[i, :length] = [smooth_a / (smooth_a + unigramsProb[token])
                               for token in sent_tokens]
        X[i, :length] = [word2idx.get('__' + token, 1) for token in sent_tokens]

    return X, weights
    

def cal_coherence(context, gen_responses, word_vec, word2idx, config, smooth_a = 10e-3):
    """Mean coherence between each context and its generated response.

    Both sides are converted to index/weight matrices with
    ``prepare_parameters`` and scored pair by pair with
    ``weighted_average_sim_rmpc``.

    Reference:
         1. paper : Better Conversations by Modeling, Filtering, and Optimizing for Coherence and Diversity
         2. paper : A SIMPLE BUT TOUGH-TO-BEAT BASELINE FOR SENTENCE EMBEDDINGS
         3. github : https://github.com/PrincetonML/SIF

    Return: average of the per-pair scores, or 0 when there are no scores
    """
    context_idx, context_w = prepare_parameters(
        context, 'context', smooth_a, word2idx, config)
    response_idx, response_w = prepare_parameters(
        gen_responses, 'gen_responses', smooth_a, word2idx, config)

    pair_scores = weighted_average_sim_rmpc(
        word_vec, context_idx, response_idx, context_w, response_w)

    n_pairs = len(pair_scores)
    if n_pairs == 0:
        return 0
    return sum(pair_scores) / n_pairs