from nltk.translate.bleu_score import sentence_bleu, corpus_bleu
from nltk.translate.bleu_score import SmoothingFunction
import numpy as np
import tqdm
import argparse
import time
import operator
from functools import reduce
from datasets import load_metric
from bert_score import BERTScorer
from fast_bleu import SelfBLEU, BLEU
import rouge
import json

import sys   
# Raise the interpreter's recursion limit to 10000 for the whole process.
# NOTE(review): nothing in this file recurses; presumably one of the imported
# metric libraries needs the headroom - confirm before removing.
sys.setrecursionlimit(10000)

def geometric_mean(values):
    """Return the geometric mean of ``values``.

    All elements are multiplied together and the ``len(values)``-th root of
    the product is returned. ``values`` must be a non-empty, sized iterable;
    an empty one makes ``reduce`` raise ``TypeError``.
    """
    product = reduce(operator.mul, values)
    root = 1.0 / len(values)
    return product ** root

class Dist(object):
    """Distinct n-gram ratio over a collection of sentences."""

    def __init__(self):
        super(Dist, self).__init__()

    def _get_ngram(self, line, n):
        """Return all contiguous ``n``-grams of ``line`` as space-joined strings.

        The line is stripped and split on single spaces, so consecutive
        spaces produce empty tokens. A line with fewer than ``n`` tokens
        yields an empty list.
        """
        tokens = line.strip().split(" ")
        window_count = len(tokens) - (n - 1)
        return [" ".join(tokens[start:start + n]) for start in range(window_count)]

    def get_ngram_ratio(self, sents, n):
        """Return (#distinct n-grams) / (#n-grams) over all of ``sents``."""
        seen = set()
        # The tiny starting value keeps the division defined when no n-gram
        # is found at all (the ratio is then 0.0).
        total = 1e-12
        for sent in sents:
            grams = self._get_ngram(sent.strip(), n)
            total += len(grams)
            seen.update(grams)

        return len(seen) / float(total)

    def calculate(self, sents, max_order=4):
        """Compute the distinct ratio for orders 1..``max_order``.

        Returns:
            A tuple ``(geometric mean of the ratios, list of per-order ratios)``.
        """
        ratios = []
        for order in range(1, max_order + 1):
            ratios.append(self.get_ngram_ratio(sents, order))

        return (geometric_mean(ratios), list(ratios))

class Jaccard(object):
    """Average pairwise n-gram Jaccard similarity within a set of sentences."""

    def __init__(self, max_order=4):
        super(Jaccard, self).__init__()
        self._n = max_order  # n-gram orders 1..max_order are evaluated

    def _get_ngram(self, line, n):
        """Return all contiguous ``n``-grams of ``line`` as space-joined strings."""
        tokens = line.strip().split(" ")
        window_count = len(tokens) - (n - 1)
        return [" ".join(tokens[start:start + n]) for start in range(window_count)]

    def _jaccardsim(self, set1, set2):
        """Return |set1 & set2| / |set1 | set2|.

        Raises ``ZeroDivisionError`` when both sets are empty.
        """
        shared = len(set1 & set2)
        combined = len(set1 | set2)
        return shared / float(combined)

    def get_ngram_jaccard(self, sents, n):
        """Return the mean per-sentence Jaccard score for ``n``-grams.

        For each sentence ``i`` except the last, the Jaccard similarities with
        every later sentence that shares at least one n-gram are summed and
        divided by ``N - i``; the result is the ``np.mean`` of these values.
        """
        gram_sets = [set(self._get_ngram(sent.strip(), n)) for sent in sents]
        sizes = [len(grams) for grams in gram_sets]

        # Inverted index: n-gram -> indices of the sentences containing it, so
        # only sentence pairs with a non-empty intersection are ever compared.
        postings = {}
        for idx, grams in enumerate(gram_sets):
            for gram in grams:
                postings.setdefault(gram, []).append(idx)

        total = len(gram_sets)
        per_sentence = []
        for i in range(total - 1):
            current = gram_sets[i]

            candidates = set()
            for gram in current:
                candidates.update(postings[gram])

            acc = 0.0
            for j in candidates:
                if j <= i:
                    continue
                overlap = len(current & gram_sets[j])
                acc += overlap / (sizes[i] + sizes[j] - overlap)

            # NOTE(review): the divisor is (total - i) although only
            # (total - i - 1) later sentences exist - confirm this is intended.
            per_sentence.append(acc / (total - i))

        return np.mean(per_sentence)

    def calculate(self, data):
        """Compute the Jaccard score for orders 1..``max_order``.

        Returns:
            A tuple ``(geometric mean of the scores, list of per-order scores)``.
        """
        per_order = [self.get_ngram_jaccard(data, order)
                     for order in range(1, self._n + 1)]

        return geometric_mean(per_order), list(per_order)

def add2dic(key, dic):
    """Increment the counter stored under ``key`` in ``dic`` in place.

    A key that is not present yet starts at 1.
    """
    dic[key] = dic.get(key, 0) + 1

class CND(object):
    """N-gram distribution metrics (CR, NRR, CND) of generated text vs. a reference.

    For every order 1..``max_order`` an empirical n-gram distribution is built
    for the reference corpus (``P``) and for the generated corpus (``Q``):

    * CR  - sum over n-grams of ``P * Q``
    * NRR - negative sum over n-grams of ``Q ** 2``
    * CND - sum over n-grams of ``(Q - P) ** 2``
    """

    def __init__(self, max_order=4):
        super(CND, self).__init__()
        self._n = max_order
        # Reference distributions (one dict per order). Initialised to None so
        # the "no reference yet" checks below work instead of raising
        # AttributeError; filled in by build_reference().
        self._Pn = None

    def build_reference(self, ref_sents):
        """Build and store the empirical n-gram distributions of ``ref_sents``."""
        self._Pn = self._build_all_dist(ref_sents)

    def _build_all_dist(self, sents):
        """Return a list with the n-gram distribution of ``sents`` for each order."""
        return [self._build_dist_n(sents, gram_n) for gram_n in range(1, self._n + 1)]

    def _build_dist_n(self, sents, n):
        """Return ``{ngram: relative frequency}`` for the ``n``-grams of ``sents``."""
        counts = {}
        total = 0
        for line in sents:
            ngrams = self._get_ngram(line.strip(), n)
            total += len(ngrams)
            for ng in ngrams:
                counts[ng] = counts.get(ng, 0) + 1

        # When no n-gram exists, counts is empty and nothing is divided.
        total = float(total)
        return {ng: freq / total for ng, freq in counts.items()}

    def _get_ngram(self, line, n):
        """Return all contiguous ``n``-grams of ``line`` as space-joined strings."""
        words = line.strip().split(" ")
        return [" ".join(words[i:i + n]) for i in range(len(words) - (n - 1))]

    def get_CR(self, Qn):
        """Return an array with ``sum(P * Q)`` for every n-gram order.

        Raises:
            AssertionError: if build_reference() has not been called.
        """
        assert self._Pn is not None, "call build_reference() first"
        all_CR = []
        for n in range(1, self._n + 1):
            P = self._Pn[n - 1]
            Q = Qn[n - 1]
            # Only n-grams present in both distributions contribute, so it is
            # enough to walk P. The 0.0 start keeps the array dtype float.
            all_CR.append(sum((p * Q.get(key, 0) for key, p in P.items()), 0.0))
        return np.array(all_CR)

    def get_NRR(self, Qn):
        """Return an array with ``-sum(Q ** 2)`` for every n-gram order.

        Works with or without a reference: n-grams that occur only in the
        reference have ``Q == 0`` and add nothing to the sum.
        """
        all_NRR = []
        for n in range(1, self._n + 1):
            Q = Qn[n - 1]
            all_NRR.append(-sum((q ** 2 for q in Q.values()), 0.0))
        return np.array(all_NRR)

    def get_CND(self, Qn):
        """Return an array with ``sum((Q - P) ** 2)`` for every n-gram order.

        Raises:
            AssertionError: if build_reference() has not been called.
        """
        assert self._Pn is not None, "call build_reference() first"
        all_CND = []
        for n in range(1, self._n + 1):
            P = self._Pn[n - 1]
            Q = Qn[n - 1]
            keys = set(P) | set(Q)
            all_CND.append(sum(((Q.get(key, 0) - P.get(key, 0)) ** 2 for key in keys), 0.0))
        return np.array(all_CND)

    def calculate(self, sents, metrics=['cr, nrr, cnd'], scale=1e4):
        """Compute CND, CR and NRR for the generated sentences ``sents``.

        Args:
            sents: generated sentences (strings).
            metrics: only echoed to stdout; it does not select which metrics
                are computed. The default list is never mutated.
            scale: multiplier applied to the per-order CR and CND values.

        Returns:
            ``{'CND': ..., 'CR': ..., 'NRR': ...}`` where CR and CND are the
            geometric means of the scaled per-order values and NRR is the
            geometric mean of ``exp(100 * per-order NRR)``.
        """
        print(metrics)
        # Distributions of the generated text, one per order.
        Qn = self._build_all_dist(sents)

        CR_vec = self.get_CR(Qn) * float(scale)
        NRR_vec = self.get_NRR(Qn) * 100
        CND_vec = self.get_CND(Qn) * float(scale)

        CR = geometric_mean(CR_vec)
        # The per-order NRR values are <= 0, so they are exponentiated before
        # the geometric mean is taken.
        NRR = geometric_mean(np.exp(NRR_vec))
        CND = geometric_mean(CND_vec)

        results = {'CND': CND, 'CR': CR, 'NRR': NRR}
        return results

def filter_sen(sen):
    """Remove the control tokens ``'<sep> '``, ``'<eos>'`` and ``'<|endoftext|>'``.

    The tokens are removed one after another in that order. ``'<sep>'`` is
    only removed together with the space that follows it.
    """
    for marker in ('<sep> ', '<eos>', '<|endoftext|>'):
        sen = sen.replace(marker, '')
    return sen

def preprocess(hpy_file, ref_file):
    """Load hypothesis and reference sentences, one per line.

    Args:
        hpy_file: path of the generated-text file. Every line is stripped and
            passed through ``filter_sen``.
        ref_file: path of the reference file, or ``None``. Every line is
            stripped and only the text after the last tab is kept.

    Returns:
        ``(hpy, ref)`` - two lists of strings. ``ref`` is empty when
        ``ref_file`` is ``None``, so reference-free metrics can run without a
        reference file instead of failing in ``open(None)``.
    """
    with open(hpy_file) as f:
        hpy = [filter_sen(line.strip()) for line in f]

    ref = []
    if ref_file is not None:
        with open(ref_file) as f:
            ref = [line.strip().split('\t')[-1] for line in f]

    return hpy, ref

def clean_text(hpy_file):
    """Rewrite ``hpy_file`` in place with cleaned lines.

    Each line is stripped, passed through ``filter_sen`` and has its runs of
    whitespace collapsed to single spaces. The lines are written back joined
    by newlines, without a trailing newline.
    """
    with open(hpy_file) as f:
        cleaned = [' '.join(filter_sen(raw.strip()).split()) for raw in f]

    with open(hpy_file, 'w') as f:
        f.write('\n'.join(cleaned))

def test_unconditional_bleu(hpy0, ref0, args):
    """Score every hypothesis against the whole reference set with fast_bleu.

    The per-sentence scores of each weight setting (bigram, trigram, 4gram)
    are averaged, appended to ``<stem>_metric.txt`` and printed.

    NOTE(review): ``<stem>`` is the part of ``args.hpy_file`` before the first
    '.', so a path such as './out/gen.txt' gives an empty stem - confirm
    callers never pass paths with earlier dots.
    """
    hyp_tokens = [text.split() for text in hpy0]
    ref_tokens = [text.split() for text in ref0]

    # Uniform weights over the first k n-gram orders.
    order_weights = {label: (1.0 / k,) * k
                     for label, k in (('bigram', 2), ('trigram', 3), ('4gram', 4))}
    scorer = BLEU(ref_tokens, order_weights)
    averaged = {label: sum(values) / len(values)
                for label, values in scorer.get_score(hyp_tokens).items()}

    stem = args.hpy_file.split('.')[0]
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    with open(stem + '_metric.txt', 'a') as f:
        # Layout: blank line, timestamp, last two path components, metric, blank line.
        report = '\n'.join(["", stamp] + stem.split('/')[-2:]
                           + ['bleu: ' + json.dumps(averaged), ""])
        f.write(report)
    print(report)

def test_self_bleu(hpy0, args):
    """Compute Self-BLEU of the hypotheses with fast_bleu.

    The per-sentence scores of each weight setting (bigram, trigram, 4gram)
    are averaged, appended to ``<stem>_metric.txt`` and printed.

    NOTE(review): ``<stem>`` is the part of ``args.hpy_file`` before the first
    '.', so paths with earlier dots are truncated - confirm this is acceptable.
    """
    tokenized = [text.split() for text in hpy0]

    # Uniform weights over the first k n-gram orders.
    order_weights = {label: (1.0 / k,) * k
                     for label, k in (('bigram', 2), ('trigram', 3), ('4gram', 4))}
    scorer = SelfBLEU(tokenized, order_weights)
    averaged = {label: sum(values) / len(values)
                for label, values in scorer.get_score().items()}

    stem = args.hpy_file.split('.')[0]
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    with open(stem + '_metric.txt', 'a') as f:
        # Layout: blank line, timestamp, last two path components, metric, blank line.
        report = '\n'.join(["", stamp] + stem.split('/')[-2:]
                           + ['self-bleu: ' + json.dumps(averaged), ""])
        f.write(report)
    print(report)

def test_dist(hpy0, args):
    """Compute the Dist score (orders 1-4, geometric mean, times 100).

    The value is appended to ``<stem>_metric.txt`` and printed.

    NOTE(review): ``<stem>`` is the part of ``args.hpy_file`` before the first
    '.', so paths with earlier dots are truncated - confirm this is acceptable.
    """
    mean_ratio, _ = Dist().calculate(hpy0)
    score = mean_ratio * 100

    stem = args.hpy_file.split('.')[0]
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    with open(stem + '_metric.txt', 'a') as f:
        # Layout: blank line, timestamp, last two path components, metric, blank line.
        report = '\n'.join(["", stamp] + stem.split('/')[-2:]
                           + ['dist-4: {}'.format(score), ""])
        f.write(report)
    print(report)

def test_jaccard(hpy0, args):
    """Compute the Jaccard score (default orders, geometric mean, times 100).

    The value is appended to ``<stem>_metric.txt`` and printed.

    NOTE(review): ``<stem>`` is the part of ``args.hpy_file`` before the first
    '.', so paths with earlier dots are truncated - confirm this is acceptable.
    """
    mean_similarity, _ = Jaccard().calculate(hpy0)
    score = mean_similarity * 100

    stem = args.hpy_file.split('.')[0]
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    with open(stem + '_metric.txt', 'a') as f:
        # Layout: blank line, timestamp, last two path components, metric, blank line.
        report = '\n'.join(["", stamp] + stem.split('/')[-2:]
                           + ['Jaccard: {}'.format(score), ""])
        f.write(report)
    print(report)

def test_rouge(hpy, ref, args):
    """Compute ROUGE-1/2/3/L (``datasets`` metric) and ROUGE-W (``rouge`` package).

    Only the first ``len(hpy)`` references are used. The mid F-measures
    (times 100) are appended to ``<stem>_metric.txt`` and printed.

    NOTE(review): ``<stem>`` is the part of ``args.hpy_file`` before the first
    '.', so paths with earlier dots are truncated - confirm this is acceptable.
    """
    rouge_metric = load_metric("rouge")
    rougew_metric = rouge.Rouge(metrics=['rouge-w'], limit_length=False, alpha=0.5,
                                weight_factor=1.2, stemming=True)

    sents = [text.strip() for text in hpy]
    refs = [text.strip() for text in ref[:len(hpy)]]

    variants = ["rouge1", "rouge2", "rouge3", "rougeL"]
    rouge_results = rouge_metric.compute(predictions=sents, references=refs,
                                         use_stemmer=True, rouge_types=variants)
    f_scores = [rouge_results[variant].mid.fmeasure * 100 for variant in variants]

    rougew = rougew_metric.get_scores(sents, refs)['rouge-w']['f'] * 100

    metric = 'rouge-1: {}, rouge-2: {}, rouge-3: {}, rouge-l: {}, rouge-w: {}'.format(
        f_scores[0], f_scores[1], f_scores[2], f_scores[3], rougew)

    stem = args.hpy_file.split('.')[0]
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    with open(stem + '_metric.txt', 'a') as f:
        # Layout: blank line, timestamp, last two path components, metric, blank line.
        report = '\n'.join(["", stamp] + stem.split('/')[-2:] + [metric, ""])
        f.write(report)
    print(report)

def test_bleu(hpy, ref, args):
    """Compute corpus-level BLEU-2/3/4 with NLTK, one reference per hypothesis.

    Only the first ``len(hpy)`` references are used; hypothesis ``i`` is scored
    against reference ``i``. Sentences are tokenised by splitting on single
    spaces. The scores are appended to ``<stem>_metric.txt`` and printed.

    NOTE(review): ``<stem>`` is the part of ``args.hpy_file`` before the first
    '.', so paths with earlier dots are truncated - confirm this is acceptable.
    """
    refs = [[sent.strip().split(" ")] for sent in ref[0:len(hpy)]]
    sents = [sent.strip().split(" ") for sent in hpy]

    # The number of weights sets the highest n-gram order NLTK evaluates, so
    # BLEU-n needs n uniform weights. With only two weights, BLEU-3 and BLEU-4
    # would stop at bigrams and use weights that do not sum to 1.
    bleu_2 = corpus_bleu(refs, sents, weights=(1/2, 1/2))
    bleu_3 = corpus_bleu(refs, sents, weights=(1/3, 1/3, 1/3))
    bleu_4 = corpus_bleu(refs, sents, weights=(1/4, 1/4, 1/4, 1/4))

    metric = 'bleu score bleu-2: {} bleu-3: {} bleu-4: {}'.format(bleu_2, bleu_3, bleu_4)
    write_name = args.hpy_file.split('.')[0] + '_metric.txt'
    now = time.strftime("%Y-%m-%d %H:%M:%S")
    with open(write_name, 'a') as f:
        # Layout: blank line, timestamp, last two path components, metric, blank line.
        output = args.hpy_file.split('.')[0].split('/')[-2:]
        output.append(metric)
        output.insert(0, now)
        output.insert(0, "")
        output.append("")
        f.write('\n'.join(output))
    print('\n'.join(output))

def test_bertscore(hpy, ref, args, batch_size=8):
    """Compute mean BERTScore precision, recall and F1 (times 100).

    Only the first ``len(hpy)`` references are used, each wrapped in a
    single-element list. The result line is appended to ``<stem>_metric.txt``
    and printed.

    NOTE(review): the model/version tag in the result line is a fixed string,
    not read from the scorer - confirm it matches the installed setup.
    NOTE(review): ``<stem>`` is the part of ``args.hpy_file`` before the first
    '.', so paths with earlier dots are truncated - confirm this is acceptable.
    """
    scorer = BERTScorer(lang="en")
    refs = [[text.strip()] for text in ref[:len(hpy)]]
    cands = [text.strip() for text in hpy]

    raw_scores = scorer.score(cands=cands, refs=refs, batch_size=batch_size)
    precision, recall, f1 = (tensor.mean().item()*100 for tensor in raw_scores)

    metric = 'bert_score: roberta-large_L17_no-idf_version=0.3.10(hug_trans=4.10.0)-rescaled_fast-tokenizer P: {} R: {} F1: {}'.format(precision, recall, f1)

    stem = args.hpy_file.split('.')[0]
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    with open(stem + '_metric.txt', 'a') as f:
        # Layout: blank line, timestamp, last two path components, metric, blank line.
        report = '\n'.join(["", stamp] + stem.split('/')[-2:] + [metric, ""])
        f.write(report)
    print(report)

def test_cnd(hpy, ref, args):
    """Compute the CND metrics of ``hpy`` against the reference ``ref``.

    The result dict is JSON-encoded, appended to ``<stem>_metric.txt`` and
    printed.

    NOTE(review): ``<stem>`` is the part of ``args.hpy_file`` before the first
    '.', so paths with earlier dots are truncated - confirm this is acceptable.
    """
    scorer = CND()
    scorer.build_reference(ref)
    results = scorer.calculate(hpy)

    stem = args.hpy_file.split('.')[0]
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    with open(stem + '_metric.txt', 'a') as f:
        # Layout: blank line, timestamp, last two path components, metric, blank line.
        report = '\n'.join(["", stamp] + stem.split('/')[-2:]
                           + [json.dumps(results), ""])
        f.write(report)
    print(report)

def test_mauve(hpy, ref, args):
    """Compute MAUVE between the reference (``p``) and hypothesis (``q``) texts.

    ``mauve`` is imported lazily, so it is only required when this metric is
    run. The score is appended to ``<stem>_metric.txt`` and printed.

    NOTE(review): ``<stem>`` is the part of ``args.hpy_file`` before the first
    '.', so paths with earlier dots are truncated - confirm this is acceptable.
    """
    import mauve as mauve_lib

    result = mauve_lib.compute_mauve(p_text=ref, q_text=hpy, device_id=0,
                                     max_text_length=256, verbose=False)
    metric = 'mauve: {}'.format(result.mauve)

    stem = args.hpy_file.split('.')[0]
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    with open(stem + '_metric.txt', 'a') as f:
        # Layout: blank line, timestamp, last two path components, metric, blank line.
        report = '\n'.join(["", stamp] + stem.split('/')[-2:] + [metric, ""])
        f.write(report)
    print(report)

def get_args():
    """Parse the command line.

    Options:
        --hpy_file / --gd_file: string paths, default ``None``.
        --test_<metric>: one boolean switch per metric, default ``False``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()

    # Path options; the help text is the option name without the dashes.
    for path_option in ("--hpy_file", "--gd_file"):
        parser.add_argument(path_option, default=None, type=str,
                            help=path_option.lstrip("-"))

    # One on/off switch per metric, registered in the original order.
    metric_switches = (
        "--test_unconditional_bleu",
        "--test_self_bleu",
        "--test_cnd",
        "--test_rouge",
        "--test_bleu",
        "--test_bertscore",
        "--test_mauve",
        "--test_dist",
        "--test_jaccard",
    )
    for switch in metric_switches:
        parser.add_argument(switch, action='store_true')

    return parser.parse_args()

def main():
    """Parse the arguments, load the data and run every selected metric.

    The metrics run in a fixed order; each one appends its result to the
    metric file derived from ``args.hpy_file`` and prints it.
    """
    args = get_args()
    hpy, ref = preprocess(args.hpy_file, args.gd_file)

    # (enabled, runner, needs_reference), listed in execution order.
    schedule = (
        (args.test_unconditional_bleu, test_unconditional_bleu, True),
        (args.test_self_bleu, test_self_bleu, False),
        (args.test_jaccard, test_jaccard, False),
        (args.test_dist, test_dist, False),
        (args.test_cnd, test_cnd, True),
        (args.test_bleu, test_bleu, True),
        (args.test_bertscore, test_bertscore, True),
        (args.test_rouge, test_rouge, True),
        (args.test_mauve, test_mauve, True),
    )
    for enabled, runner, needs_reference in schedule:
        if not enabled:
            continue
        if needs_reference:
            runner(hpy, ref, args)
        else:
            runner(hpy, args)


if __name__ == "__main__":
    main()



